feat: add /generate and /generate_streaming (#482)
* feat: add generate_stream interface * extract engine::create_engine * feat add generate::generate * support streaming in llama.cpp * support streaming in ctranslate2 * update * fix formatting * refactor: extract helpers functions
release-0.2
parent
1d6ac7836b
commit
44f013f26e
|
|
@ -247,6 +247,26 @@ dependencies = [
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-streams"
|
||||||
|
version = "0.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4a3e367d27d8c1ce16fbd0d96ddf05105fd1147f5d35ffc55e254dab914e72e8"
|
||||||
|
dependencies = [
|
||||||
|
"axum",
|
||||||
|
"bytes",
|
||||||
|
"cargo-husky",
|
||||||
|
"futures",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"mime",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
"tokio-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-tracing-opentelemetry"
|
name = "axum-tracing-opentelemetry"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
|
|
@ -417,6 +437,12 @@ version = "0.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
|
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cargo-husky"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.0.79"
|
version = "1.0.79"
|
||||||
|
|
@ -666,11 +692,13 @@ dependencies = [
|
||||||
name = "ctranslate2-bindings"
|
name = "ctranslate2-bindings"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"cmake",
|
"cmake",
|
||||||
"cxx",
|
"cxx",
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
|
"futures",
|
||||||
"rust-cxx-cmake-bridge",
|
"rust-cxx-cmake-bridge",
|
||||||
"stop-words",
|
"stop-words",
|
||||||
"tabby-inference",
|
"tabby-inference",
|
||||||
|
|
@ -1295,6 +1323,7 @@ name = "http-api-bindings"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"futures",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
@ -1625,11 +1654,13 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
|
||||||
name = "llama-cpp-bindings"
|
name = "llama-cpp-bindings"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"cmake",
|
"cmake",
|
||||||
"cxx",
|
"cxx",
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
|
"futures",
|
||||||
"stop-words",
|
"stop-words",
|
||||||
"tabby-inference",
|
"tabby-inference",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
|
|
@ -3012,10 +3043,13 @@ name = "tabby"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-stream",
|
||||||
"axum",
|
"axum",
|
||||||
|
"axum-streams",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap",
|
"clap",
|
||||||
"ctranslate2-bindings",
|
"ctranslate2-bindings",
|
||||||
|
"futures",
|
||||||
"http-api-bindings",
|
"http-api-bindings",
|
||||||
"hyper",
|
"hyper",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
|
|
@ -3086,8 +3120,10 @@ dependencies = [
|
||||||
name = "tabby-inference"
|
name = "tabby-inference"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
|
"futures",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
|
|
@ -35,3 +35,5 @@ async-trait = "0.1.72"
|
||||||
reqwest = { version = "0.11.18" }
|
reqwest = { version = "0.11.18" }
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
tokenizers = "0.13.4-rc3"
|
tokenizers = "0.13.4-rc3"
|
||||||
|
futures = "0.3.28"
|
||||||
|
async-stream = "0.3.5"
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ tokio-util = { workspace = true }
|
||||||
tabby-inference = { path = "../tabby-inference" }
|
tabby-inference = { path = "../tabby-inference" }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
stop-words = { path = "../stop-words" }
|
stop-words = { path = "../stop-words" }
|
||||||
|
futures.workspace = true
|
||||||
|
async-stream.workspace = true
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
cxx-build = "1.0"
|
cxx-build = "1.0"
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
use stop_words::{StopWords, StopWordsCondition};
|
use stop_words::{StopWords, StopWordsCondition};
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokio::sync::mpsc::{channel, Sender};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "tabby")]
|
#[cxx::bridge(namespace = "tabby")]
|
||||||
|
|
@ -67,13 +70,19 @@ pub struct CTranslate2EngineOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
|
sender: Sender<u32>,
|
||||||
stop_condition: StopWordsCondition,
|
stop_condition: StopWordsCondition,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferenceContext {
|
impl InferenceContext {
|
||||||
fn new(stop_condition: StopWordsCondition, cancel: CancellationToken) -> Self {
|
fn new(
|
||||||
|
sender: Sender<u32>,
|
||||||
|
stop_condition: StopWordsCondition,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> Self {
|
||||||
InferenceContext {
|
InferenceContext {
|
||||||
|
sender,
|
||||||
stop_condition,
|
stop_condition,
|
||||||
cancel,
|
cancel,
|
||||||
}
|
}
|
||||||
|
|
@ -108,9 +117,18 @@ impl CTranslate2Engine {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for CTranslate2Engine {
|
impl TextGeneration for CTranslate2Engine {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
|
let s = self.generate_stream(prompt, options).await;
|
||||||
|
helpers::stream_to_string(s).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
options: TextGenerationOptions,
|
||||||
|
) -> BoxStream<String> {
|
||||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
let engine = self.engine.clone();
|
let engine = self.engine.clone();
|
||||||
|
let s = stream! {
|
||||||
let cancel = CancellationToken::new();
|
let cancel = CancellationToken::new();
|
||||||
let cancel_for_inference = cancel.clone();
|
let cancel_for_inference = cancel.clone();
|
||||||
let _guard = cancel.drop_guard();
|
let _guard = cancel.drop_guard();
|
||||||
|
|
@ -118,8 +136,10 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
let stop_condition = self
|
let stop_condition = self
|
||||||
.stop_words
|
.stop_words
|
||||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
.create_condition(self.tokenizer.clone(), options.stop_words);
|
||||||
let context = InferenceContext::new(stop_condition, cancel_for_inference);
|
|
||||||
let output_ids = tokio::task::spawn_blocking(move || {
|
let (sender, mut receiver) = channel::<u32>(8);
|
||||||
|
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
|
||||||
|
tokio::task::spawn(async move {
|
||||||
let context = Box::new(context);
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
context,
|
context,
|
||||||
|
|
@ -127,11 +147,15 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
||||||
options.max_decoding_length,
|
options.max_decoding_length,
|
||||||
options.sampling_temperature,
|
options.sampling_temperature,
|
||||||
)
|
);
|
||||||
})
|
});
|
||||||
.await
|
|
||||||
.expect("Inference failed");
|
while let Some(next_token_id) = receiver.recv().await {
|
||||||
self.tokenizer.decode(&output_ids, true).unwrap()
|
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
||||||
|
yield text;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Box::pin(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -150,6 +174,7 @@ fn inference_callback(
|
||||||
token_id: u32,
|
token_id: u32,
|
||||||
_token: String,
|
_token: String,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
|
let _ = context.sender.blocking_send(token_id);
|
||||||
if context.cancel.is_cancelled() {
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
|
futures.workspace = true
|
||||||
reqwest = { workspace = true, features = ["json"] }
|
reqwest = { workspace = true, features = ["json"] }
|
||||||
serde = { workspace = true, features = ["derive"] }
|
serde = { workspace = true, features = ["derive"] }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
use reqwest::header;
|
use reqwest::header;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct Request {
|
struct Request {
|
||||||
|
|
@ -87,4 +88,12 @@ impl TextGeneration for FastChatEngine {
|
||||||
|
|
||||||
resp.choices[0].text[0].clone()
|
resp.choices[0].text[0].clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
options: TextGenerationOptions,
|
||||||
|
) -> BoxStream<String> {
|
||||||
|
helpers::string_to_stream(self.generate(prompt, options).await).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
use reqwest::header;
|
use reqwest::header;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct Request {
|
struct Request {
|
||||||
|
|
@ -107,4 +108,12 @@ impl TextGeneration for VertexAIEngine {
|
||||||
|
|
||||||
resp.predictions[0].content.clone()
|
resp.predictions[0].content.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
options: TextGenerationOptions,
|
||||||
|
) -> BoxStream<String> {
|
||||||
|
helpers::string_to_stream(self.generate(prompt, options).await).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,3 +16,5 @@ derive_builder = { workspace = true }
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
stop-words = { version = "0.1.0", path = "../stop-words" }
|
stop-words = { version = "0.1.0", path = "../stop-words" }
|
||||||
tokio-util = { workspace = true }
|
tokio-util = { workspace = true }
|
||||||
|
futures.workspace = true
|
||||||
|
async-stream.workspace = true
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use ffi::create_engine;
|
use ffi::create_engine;
|
||||||
|
use futures::{lock::Mutex, stream::BoxStream};
|
||||||
use stop_words::StopWords;
|
use stop_words::StopWords;
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "llama")]
|
#[cxx::bridge(namespace = "llama")]
|
||||||
mod ffi {
|
mod ffi {
|
||||||
|
|
@ -35,7 +36,7 @@ pub struct LlamaEngineOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LlamaEngine {
|
pub struct LlamaEngine {
|
||||||
engine: Arc<Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>>,
|
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
stop_words: StopWords,
|
stop_words: StopWords,
|
||||||
}
|
}
|
||||||
|
|
@ -43,7 +44,7 @@ pub struct LlamaEngine {
|
||||||
impl LlamaEngine {
|
impl LlamaEngine {
|
||||||
pub fn create(options: LlamaEngineOptions) -> Self {
|
pub fn create(options: LlamaEngineOptions) -> Self {
|
||||||
LlamaEngine {
|
LlamaEngine {
|
||||||
engine: Arc::new(Mutex::new(create_engine(&options.model_path))),
|
engine: Mutex::new(create_engine(&options.model_path)),
|
||||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||||
stop_words: StopWords::default(),
|
stop_words: StopWords::default(),
|
||||||
}
|
}
|
||||||
|
|
@ -53,34 +54,31 @@ impl LlamaEngine {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for LlamaEngine {
|
impl TextGeneration for LlamaEngine {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
let cancel = CancellationToken::new();
|
let s = self.generate_stream(prompt, options).await;
|
||||||
let cancel_for_inference = cancel.clone();
|
helpers::stream_to_string(s).await
|
||||||
let _guard = cancel.drop_guard();
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
options: TextGenerationOptions,
|
||||||
|
) -> BoxStream<String> {
|
||||||
let prompt = prompt.to_owned();
|
let prompt = prompt.to_owned();
|
||||||
let engine = self.engine.clone();
|
|
||||||
let mut stop_condition = self
|
let mut stop_condition = self
|
||||||
.stop_words
|
.stop_words
|
||||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
.create_condition(self.tokenizer.clone(), options.stop_words);
|
||||||
|
|
||||||
let output_ids = tokio::task::spawn_blocking(move || {
|
let s = stream! {
|
||||||
let engine = engine.lock().unwrap();
|
let engine = self.engine.lock().await;
|
||||||
let eos_token = engine.eos_token();
|
let eos_token = engine.eos_token();
|
||||||
|
|
||||||
let mut next_token_id = engine.start(&prompt, options.max_input_length);
|
let mut next_token_id = engine.start(&prompt, options.max_input_length);
|
||||||
if next_token_id == eos_token {
|
if next_token_id == eos_token {
|
||||||
return Vec::new();
|
yield "".to_owned();
|
||||||
}
|
} else {
|
||||||
|
|
||||||
let mut n_remains = options.max_decoding_length - 1;
|
let mut n_remains = options.max_decoding_length - 1;
|
||||||
let mut output_ids = vec![next_token_id];
|
|
||||||
|
|
||||||
while n_remains > 0 {
|
while n_remains > 0 {
|
||||||
if cancel_for_inference.is_cancelled() {
|
|
||||||
// The token was cancelled
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
next_token_id = engine.step(next_token_id);
|
next_token_id = engine.step(next_token_id);
|
||||||
if next_token_id == eos_token {
|
if next_token_id == eos_token {
|
||||||
break;
|
break;
|
||||||
|
|
@ -89,15 +87,16 @@ impl TextGeneration for LlamaEngine {
|
||||||
if stop_condition.next_token(next_token_id) {
|
if stop_condition.next_token(next_token_id) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
output_ids.push(next_token_id);
|
|
||||||
|
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
||||||
|
yield text;
|
||||||
n_remains -= 1;
|
n_remains -= 1;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
engine.end();
|
engine.end();
|
||||||
output_ids
|
};
|
||||||
})
|
|
||||||
.await
|
Box::pin(s)
|
||||||
.expect("Inference failed");
|
|
||||||
self.tokenizer.decode(&output_ids, true).unwrap()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,5 +6,7 @@ edition = "2021"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
async-stream = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
|
futures = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
|
|
||||||
#[derive(Builder, Debug)]
|
#[derive(Builder, Debug)]
|
||||||
pub struct TextGenerationOptions {
|
pub struct TextGenerationOptions {
|
||||||
|
|
@ -21,4 +22,33 @@ static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait TextGeneration: Sync + Send {
|
pub trait TextGeneration: Sync + Send {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
options: TextGenerationOptions,
|
||||||
|
) -> BoxStream<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod helpers {
|
||||||
|
use async_stream::stream;
|
||||||
|
use futures::{pin_mut, stream::BoxStream, Stream, StreamExt};
|
||||||
|
|
||||||
|
pub async fn stream_to_string(s: impl Stream<Item = String>) -> String {
|
||||||
|
pin_mut!(s);
|
||||||
|
|
||||||
|
let mut text = "".to_owned();
|
||||||
|
while let Some(value) = s.next().await {
|
||||||
|
text += &value;
|
||||||
|
}
|
||||||
|
|
||||||
|
text
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn string_to_stream(s: String) -> BoxStream<'static, String> {
|
||||||
|
let stream = stream! {
|
||||||
|
yield s
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(stream)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,9 @@ anyhow = { workspace = true }
|
||||||
sysinfo = "0.29.8"
|
sysinfo = "0.29.8"
|
||||||
nvml-wrapper = "0.9.0"
|
nvml-wrapper = "0.9.0"
|
||||||
http-api-bindings = { path = "../http-api-bindings" }
|
http-api-bindings = { path = "../http-api-bindings" }
|
||||||
|
futures = { workspace = true }
|
||||||
|
async-stream = { workspace = true }
|
||||||
|
axum-streams = { version = "0.9.1", features = ["json"] }
|
||||||
|
|
||||||
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
|
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
|
||||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,17 @@
|
||||||
mod languages;
|
mod languages;
|
||||||
mod prompt;
|
mod prompt;
|
||||||
|
|
||||||
use std::{path::Path, sync::Arc};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
|
||||||
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
|
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use tabby_common::{config::Config, events};
|
||||||
use tabby_common::{config::Config, events, path::ModelDir};
|
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use self::languages::get_stop_words;
|
use self::languages::get_stop_words;
|
||||||
use crate::fatal;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
#[schema(example=json!({
|
#[schema(example=json!({
|
||||||
|
|
@ -124,14 +120,16 @@ pub async fn completion(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CompletionState {
|
pub struct CompletionState {
|
||||||
engine: Box<dyn TextGeneration>,
|
engine: Arc<Box<dyn TextGeneration>>,
|
||||||
prompt_builder: prompt::PromptBuilder,
|
prompt_builder: prompt::PromptBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
pub fn new(
|
||||||
let (engine, prompt_template) = create_engine(args);
|
engine: Arc<Box<dyn TextGeneration>>,
|
||||||
|
prompt_template: Option<String>,
|
||||||
|
config: &Config,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: prompt::PromptBuilder::new(
|
prompt_builder: prompt::PromptBuilder::new(
|
||||||
|
|
@ -141,120 +139,3 @@ impl CompletionState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_param(params: &Value, key: &str) -> String {
|
|
||||||
params
|
|
||||||
.get(key)
|
|
||||||
.unwrap_or_else(|| panic!("Missing {} field", key))
|
|
||||||
.as_str()
|
|
||||||
.expect("Type unmatched")
|
|
||||||
.to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
|
||||||
if args.device != super::Device::ExperimentalHttp {
|
|
||||||
let model_dir = get_model_dir(&args.model);
|
|
||||||
let metadata = read_metadata(&model_dir);
|
|
||||||
let engine = create_local_engine(args, &model_dir, &metadata);
|
|
||||||
(engine, metadata.prompt_template)
|
|
||||||
} else {
|
|
||||||
let params: Value =
|
|
||||||
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
|
||||||
|
|
||||||
let kind = get_param(&params, "kind");
|
|
||||||
|
|
||||||
if kind == "vertex-ai" {
|
|
||||||
let api_endpoint = get_param(&params, "api_endpoint");
|
|
||||||
let authorization = get_param(&params, "authorization");
|
|
||||||
let engine = Box::new(VertexAIEngine::create(
|
|
||||||
api_endpoint.as_str(),
|
|
||||||
authorization.as_str(),
|
|
||||||
));
|
|
||||||
(engine, Some(VertexAIEngine::prompt_template()))
|
|
||||||
} else if kind == "fastchat" {
|
|
||||||
let model_name = get_param(&params, "model_name");
|
|
||||||
let api_endpoint = get_param(&params, "api_endpoint");
|
|
||||||
let authorization = get_param(&params, "authorization");
|
|
||||||
let engine = Box::new(FastChatEngine::create(
|
|
||||||
api_endpoint.as_str(),
|
|
||||||
model_name.as_str(),
|
|
||||||
authorization.as_str(),
|
|
||||||
));
|
|
||||||
(engine, Some(FastChatEngine::prompt_template()))
|
|
||||||
} else {
|
|
||||||
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
|
||||||
fn create_local_engine(
|
|
||||||
args: &crate::serve::ServeArgs,
|
|
||||||
model_dir: &ModelDir,
|
|
||||||
metadata: &Metadata,
|
|
||||||
) -> Box<dyn TextGeneration> {
|
|
||||||
create_ctranslate2_engine(args, model_dir, metadata)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
|
||||||
fn create_local_engine(
|
|
||||||
args: &crate::serve::ServeArgs,
|
|
||||||
model_dir: &ModelDir,
|
|
||||||
metadata: &Metadata,
|
|
||||||
) -> Box<dyn TextGeneration> {
|
|
||||||
if args.device != super::Device::Metal {
|
|
||||||
create_ctranslate2_engine(args, model_dir, metadata)
|
|
||||||
} else {
|
|
||||||
create_llama_engine(model_dir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_ctranslate2_engine(
|
|
||||||
args: &crate::serve::ServeArgs,
|
|
||||||
model_dir: &ModelDir,
|
|
||||||
metadata: &Metadata,
|
|
||||||
) -> Box<dyn TextGeneration> {
|
|
||||||
let device = format!("{}", args.device);
|
|
||||||
let compute_type = format!("{}", args.compute_type);
|
|
||||||
let options = CTranslate2EngineOptionsBuilder::default()
|
|
||||||
.model_path(model_dir.ctranslate2_dir())
|
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
|
||||||
.device(device)
|
|
||||||
.model_type(metadata.auto_model.clone())
|
|
||||||
.device_indices(args.device_indices.clone())
|
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
|
||||||
.compute_type(compute_type)
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
Box::new(CTranslate2Engine::create(options))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
|
||||||
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
|
|
||||||
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
|
|
||||||
.model_path(model_dir.ggml_q8_0_file())
|
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Box::new(llama_cpp_bindings::LlamaEngine::create(options))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model_dir(model: &str) -> ModelDir {
|
|
||||||
if Path::new(model).exists() {
|
|
||||||
ModelDir::from(model)
|
|
||||||
} else {
|
|
||||||
ModelDir::new(model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct Metadata {
|
|
||||||
auto_model: String,
|
|
||||||
prompt_template: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
|
||||||
serdeconv::from_json_file(model_dir.metadata_file())
|
|
||||||
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
|
||||||
|
use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tabby_common::path::ModelDir;
|
||||||
|
use tabby_inference::TextGeneration;
|
||||||
|
|
||||||
|
use crate::fatal;
|
||||||
|
|
||||||
|
fn get_param(params: &Value, key: &str) -> String {
|
||||||
|
params
|
||||||
|
.get(key)
|
||||||
|
.unwrap_or_else(|| panic!("Missing {} field", key))
|
||||||
|
.as_str()
|
||||||
|
.expect("Type unmatched")
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box<dyn TextGeneration>, Option<String>) {
|
||||||
|
if args.device != super::Device::ExperimentalHttp {
|
||||||
|
let model_dir = get_model_dir(&args.model);
|
||||||
|
let metadata = read_metadata(&model_dir);
|
||||||
|
let engine = create_local_engine(args, &model_dir, &metadata);
|
||||||
|
(engine, metadata.prompt_template)
|
||||||
|
} else {
|
||||||
|
let params: Value =
|
||||||
|
serdeconv::from_json_str(&args.model).expect("Failed to parse model string");
|
||||||
|
|
||||||
|
let kind = get_param(&params, "kind");
|
||||||
|
|
||||||
|
if kind == "vertex-ai" {
|
||||||
|
let api_endpoint = get_param(&params, "api_endpoint");
|
||||||
|
let authorization = get_param(&params, "authorization");
|
||||||
|
let engine = Box::new(VertexAIEngine::create(
|
||||||
|
api_endpoint.as_str(),
|
||||||
|
authorization.as_str(),
|
||||||
|
));
|
||||||
|
(engine, Some(VertexAIEngine::prompt_template()))
|
||||||
|
} else if kind == "fastchat" {
|
||||||
|
let model_name = get_param(&params, "model_name");
|
||||||
|
let api_endpoint = get_param(&params, "api_endpoint");
|
||||||
|
let authorization = get_param(&params, "authorization");
|
||||||
|
let engine = Box::new(FastChatEngine::create(
|
||||||
|
api_endpoint.as_str(),
|
||||||
|
model_name.as_str(),
|
||||||
|
authorization.as_str(),
|
||||||
|
));
|
||||||
|
(engine, Some(FastChatEngine::prompt_template()))
|
||||||
|
} else {
|
||||||
|
fatal!("Only vertex_ai and fastchat are supported for http backend");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||||
|
fn create_local_engine(
|
||||||
|
args: &crate::serve::ServeArgs,
|
||||||
|
model_dir: &ModelDir,
|
||||||
|
metadata: &Metadata,
|
||||||
|
) -> Box<dyn TextGeneration> {
|
||||||
|
create_ctranslate2_engine(args, model_dir, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||||
|
fn create_local_engine(
|
||||||
|
args: &crate::serve::ServeArgs,
|
||||||
|
model_dir: &ModelDir,
|
||||||
|
metadata: &Metadata,
|
||||||
|
) -> Box<dyn TextGeneration> {
|
||||||
|
if args.device != super::Device::Metal {
|
||||||
|
create_ctranslate2_engine(args, model_dir, metadata)
|
||||||
|
} else {
|
||||||
|
create_llama_engine(model_dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_ctranslate2_engine(
|
||||||
|
args: &crate::serve::ServeArgs,
|
||||||
|
model_dir: &ModelDir,
|
||||||
|
metadata: &Metadata,
|
||||||
|
) -> Box<dyn TextGeneration> {
|
||||||
|
let device = format!("{}", args.device);
|
||||||
|
let compute_type = format!("{}", args.compute_type);
|
||||||
|
let options = CTranslate2EngineOptionsBuilder::default()
|
||||||
|
.model_path(model_dir.ctranslate2_dir())
|
||||||
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
|
.device(device)
|
||||||
|
.model_type(metadata.auto_model.clone())
|
||||||
|
.device_indices(args.device_indices.clone())
|
||||||
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
|
.compute_type(compute_type)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
Box::new(CTranslate2Engine::create(options))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a llama.cpp-backed text generation engine (macOS/aarch64 only).
///
/// Uses the model file returned by `ggml_q8_0_file()` and the tokenizer file
/// from `model_dir`.
///
/// # Panics
///
/// Panics if the options builder fails to produce a value (`build()` is
/// unwrapped).
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
    let engine_options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
        .model_path(model_dir.ggml_q8_0_file())
        .tokenizer_path(model_dir.tokenizer_file())
        .build()
        .unwrap();
    let engine = llama_cpp_bindings::LlamaEngine::create(engine_options);

    Box::new(engine)
}
|
||||||
|
|
||||||
|
fn get_model_dir(model: &str) -> ModelDir {
|
||||||
|
if Path::new(model).exists() {
|
||||||
|
ModelDir::from(model)
|
||||||
|
} else {
|
||||||
|
ModelDir::new(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Metadata {
|
||||||
|
auto_model: String,
|
||||||
|
prompt_template: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
|
serdeconv::from_json_file(model_dir.metadata_file())
|
||||||
|
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_stream::stream;
|
||||||
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
|
use axum_streams::StreamBodyAs;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||||
|
use tracing::instrument;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
pub struct GenerateState {
|
||||||
|
engine: Arc<Box<dyn TextGeneration>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateState {
|
||||||
|
pub fn new(engine: Arc<Box<dyn TextGeneration>>) -> Self {
|
||||||
|
Self { engine }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct GenerateRequest {
|
||||||
|
#[schema(
|
||||||
|
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef"
|
||||||
|
)]
|
||||||
|
prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct GenerateResponse {
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/v1/generate",
|
||||||
|
request_body = GenerateRequest,
|
||||||
|
operation_id = "generate",
|
||||||
|
tag = "v1",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument(skip(state, request))]
|
||||||
|
pub async fn generate(
|
||||||
|
State(state): State<Arc<GenerateState>>,
|
||||||
|
Json(request): Json<GenerateRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let options = build_options(&request);
|
||||||
|
Json(GenerateResponse {
|
||||||
|
text: state.engine.generate(&request.prompt, options).await,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/v1/generate_stream",
|
||||||
|
request_body = GenerateRequest,
|
||||||
|
operation_id = "generate_stream",
|
||||||
|
tag = "v1",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument(skip(state, request))]
|
||||||
|
pub async fn generate_stream(
|
||||||
|
State(state): State<Arc<GenerateState>>,
|
||||||
|
Json(request): Json<GenerateRequest>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let options = build_options(&request);
|
||||||
|
let s = stream! {
|
||||||
|
for await text in state.engine.generate_stream(&request.prompt, options).await {
|
||||||
|
yield GenerateResponse { text }
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
StreamBodyAs::json_nl(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
|
||||||
|
TextGenerationOptionsBuilder::default()
|
||||||
|
.max_input_length(2048)
|
||||||
|
.max_decoding_length(usize::MAX)
|
||||||
|
.sampling_temperature(0.1)
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
mod completions;
|
mod completions;
|
||||||
|
mod engine;
|
||||||
mod events;
|
mod events;
|
||||||
|
mod generate;
|
||||||
mod health;
|
mod health;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
|
|
@ -19,7 +21,7 @@ use tracing::{info, warn};
|
||||||
use utoipa::{openapi::ServerBuilder, OpenApi};
|
use utoipa::{openapi::ServerBuilder, OpenApi};
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
use self::health::HealthState;
|
use self::{engine::create_engine, health::HealthState};
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
|
@ -39,13 +41,15 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
servers(
|
servers(
|
||||||
(url = "https://playground.app.tabbyml.com", description = "Playground server"),
|
(url = "https://playground.app.tabbyml.com", description = "Playground server"),
|
||||||
),
|
),
|
||||||
paths(events::log_event, completions::completion, health::health),
|
paths(events::log_event, completions::completion, generate::generate, generate::generate_stream, health::health),
|
||||||
components(schemas(
|
components(schemas(
|
||||||
events::LogEventRequest,
|
events::LogEventRequest,
|
||||||
completions::CompletionRequest,
|
completions::CompletionRequest,
|
||||||
completions::CompletionResponse,
|
completions::CompletionResponse,
|
||||||
completions::Segments,
|
completions::Segments,
|
||||||
completions::Choice,
|
completions::Choice,
|
||||||
|
generate::GenerateRequest,
|
||||||
|
generate::GenerateResponse,
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
health::Version,
|
health::Version,
|
||||||
))
|
))
|
||||||
|
|
@ -171,6 +175,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
|
let (engine, prompt_template) = create_engine(args);
|
||||||
|
let engine = Arc::new(engine);
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/events", routing::post(events::log_event))
|
.route("/events", routing::post(events::log_event))
|
||||||
.route(
|
.route(
|
||||||
|
|
@ -179,8 +185,19 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/completions",
|
"/completions",
|
||||||
routing::post(completions::completion)
|
routing::post(completions::completion).with_state(Arc::new(
|
||||||
.with_state(Arc::new(completions::CompletionState::new(args, config))),
|
completions::CompletionState::new(engine.clone(), prompt_template, config),
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/generate",
|
||||||
|
routing::post(generate::generate)
|
||||||
|
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/generate_stream",
|
||||||
|
routing::post(generate::generate_stream)
|
||||||
|
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
||||||
)
|
)
|
||||||
.layer(CorsLayer::permissive())
|
.layer(CorsLayer::permissive())
|
||||||
.layer(opentelemetry_tracing_layer())
|
.layer(opentelemetry_tracing_layer())
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue