diff --git a/Cargo.lock b/Cargo.lock index 83d01df..5014842 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -247,6 +247,26 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-streams" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a3e367d27d8c1ce16fbd0d96ddf05105fd1147f5d35ffc55e254dab914e72e8" +dependencies = [ + "axum", + "bytes", + "cargo-husky", + "futures", + "futures-util", + "http", + "mime", + "serde", + "serde_json", + "tokio", + "tokio-stream", + "tokio-util", +] + [[package]] name = "axum-tracing-opentelemetry" version = "0.10.0" @@ -417,6 +437,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663" +[[package]] +name = "cargo-husky" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" + [[package]] name = "cc" version = "1.0.79" @@ -666,11 +692,13 @@ dependencies = [ name = "ctranslate2-bindings" version = "0.1.0" dependencies = [ + "async-stream", "async-trait", "cmake", "cxx", "cxx-build", "derive_builder", + "futures", "rust-cxx-cmake-bridge", "stop-words", "tabby-inference", @@ -1295,6 +1323,7 @@ name = "http-api-bindings" version = "0.1.0" dependencies = [ "async-trait", + "futures", "reqwest", "serde", "serde_json", @@ -1625,11 +1654,13 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" name = "llama-cpp-bindings" version = "0.1.0" dependencies = [ + "async-stream", "async-trait", "cmake", "cxx", "cxx-build", "derive_builder", + "futures", "stop-words", "tabby-inference", "tokenizers", @@ -3012,10 +3043,13 @@ name = "tabby" version = "0.1.1" dependencies = [ "anyhow", + "async-stream", "axum", + "axum-streams", "axum-tracing-opentelemetry", "clap", "ctranslate2-bindings", + "futures", "http-api-bindings", "hyper", "lazy_static", @@ -3086,8 +3120,10 @@ dependencies = [ name = "tabby-inference" version = "0.1.0" dependencies = [ + "async-stream", "async-trait", "derive_builder", + "futures", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8b4df43..a54eb4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,5 @@ async-trait = "0.1.72" reqwest = { version = "0.11.18" } derive_builder = "0.12.0" tokenizers = "0.13.4-rc3" +futures = "0.3.28" +async-stream = "0.3.5" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 6753902..9c96506 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -12,6 +12,8 @@ tokio-util = { workspace = true } tabby-inference = { path = "../tabby-inference" } async-trait = { workspace = true } stop-words = { path = "../stop-words" } +futures.workspace = true +async-stream.workspace = true [build-dependencies] cxx-build = "1.0" diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 73a953b..afb8b42 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,10 +1,13 @@ use std::sync::Arc; +use async_stream::stream; use async_trait::async_trait; use derive_builder::Builder; +use futures::stream::BoxStream; use stop_words::{StopWords, StopWordsCondition}; -use tabby_inference::{TextGeneration, TextGenerationOptions}; +use tabby_inference::{helpers, TextGeneration, TextGenerationOptions}; use tokenizers::tokenizer::Tokenizer; +use tokio::sync::mpsc::{channel, Sender}; use tokio_util::sync::CancellationToken; #[cxx::bridge(namespace = "tabby")] @@ -67,13 +70,19 @@ pub struct CTranslate2EngineOptions { } pub struct InferenceContext { + sender: Sender, stop_condition: StopWordsCondition, cancel: CancellationToken, } impl InferenceContext { - fn new(stop_condition: StopWordsCondition, cancel: CancellationToken) -> Self { + fn new( + sender: Sender, + stop_condition: StopWordsCondition, + cancel: CancellationToken, + ) -> Self { InferenceContext { + sender, stop_condition, cancel, } @@ -108,30 +117,45 @@ impl CTranslate2Engine { #[async_trait] impl TextGeneration for CTranslate2Engine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { + let s = self.generate_stream(prompt, options).await; + helpers::stream_to_string(s).await + } + + async fn generate_stream( + &self, + prompt: &str, + options: TextGenerationOptions, + ) -> BoxStream { let encoding = self.tokenizer.encode(prompt, true).unwrap(); let engine = self.engine.clone(); + let s = stream! { + let cancel = CancellationToken::new(); + let cancel_for_inference = cancel.clone(); + let _guard = cancel.drop_guard(); - let cancel = CancellationToken::new(); - let cancel_for_inference = cancel.clone(); - let _guard = cancel.drop_guard(); + let stop_condition = self + .stop_words + .create_condition(self.tokenizer.clone(), options.stop_words); - let stop_condition = self - .stop_words - .create_condition(self.tokenizer.clone(), options.stop_words); - let context = InferenceContext::new(stop_condition, cancel_for_inference); - let output_ids = tokio::task::spawn_blocking(move || { - let context = Box::new(context); - engine.inference( - context, - inference_callback, - truncate_tokens(encoding.get_tokens(), options.max_input_length), - options.max_decoding_length, - options.sampling_temperature, - ) - }) - .await - .expect("Inference failed"); - self.tokenizer.decode(&output_ids, true).unwrap() + let (sender, mut receiver) = channel::(8); + let context = InferenceContext::new(sender, stop_condition, cancel_for_inference); + tokio::task::spawn(async move { + let context = Box::new(context); + engine.inference( + context, + inference_callback, + truncate_tokens(encoding.get_tokens(), options.max_input_length), + options.max_decoding_length, + options.sampling_temperature, + ); + }); + + while let Some(next_token_id) = receiver.recv().await { + let text = self.tokenizer.decode(&[next_token_id], true).unwrap(); + yield text; + } + }; + Box::pin(s) } } @@ -150,6 +174,7 @@ fn inference_callback( token_id: u32, _token: String, ) -> bool { + let _ = context.sender.blocking_send(token_id); if context.cancel.is_cancelled() { true } else { diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 92ab725..a74266b 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] async-trait.workspace = true +futures.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } diff --git a/crates/http-api-bindings/src/fastchat.rs b/crates/http-api-bindings/src/fastchat.rs index a97a1df..f71e048 100644 --- a/crates/http-api-bindings/src/fastchat.rs +++ b/crates/http-api-bindings/src/fastchat.rs @@ -1,8 +1,9 @@ use async_trait::async_trait; +use futures::stream::BoxStream; use reqwest::header; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tabby_inference::{TextGeneration, TextGenerationOptions}; +use tabby_inference::{helpers, TextGeneration, TextGenerationOptions}; #[derive(Serialize)] struct Request { @@ -87,4 +88,12 @@ impl TextGeneration for FastChatEngine { resp.choices[0].text[0].clone() } + + async fn generate_stream( + &self, + prompt: &str, + options: TextGenerationOptions, + ) -> BoxStream { + helpers::string_to_stream(self.generate(prompt, options).await).await + } } diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs index c2dd226..1d74b59 100644 --- a/crates/http-api-bindings/src/vertex_ai.rs +++ b/crates/http-api-bindings/src/vertex_ai.rs @@ -1,8 +1,9 @@ use async_trait::async_trait; +use futures::stream::BoxStream; use reqwest::header; use serde::{Deserialize, Serialize}; use serde_json::Value; -use tabby_inference::{TextGeneration, TextGenerationOptions}; +use tabby_inference::{helpers, TextGeneration, TextGenerationOptions}; #[derive(Serialize)] struct Request { @@ -107,4 +108,12 @@ impl TextGeneration for VertexAIEngine { resp.predictions[0].content.clone() } + + async fn generate_stream( + &self, + prompt: &str, + options: TextGenerationOptions, + ) -> BoxStream { + helpers::string_to_stream(self.generate(prompt, options).await).await + } } diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml index ba2a09d..65d7d13 100644 --- a/crates/llama-cpp-bindings/Cargo.toml +++ b/crates/llama-cpp-bindings/Cargo.toml @@ -16,3 +16,5 @@ derive_builder = { workspace = true } tokenizers = { workspace = true } stop-words = { version = "0.1.0", path = "../stop-words" } tokio-util = { workspace = true } +futures.workspace = true +async-stream.workspace = true diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index 176d7c7..c9aab91 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -1,12 +1,13 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use async_stream::stream; use async_trait::async_trait; use derive_builder::Builder; use ffi::create_engine; +use futures::{lock::Mutex, stream::BoxStream}; use stop_words::StopWords; -use tabby_inference::{TextGeneration, TextGenerationOptions}; +use tabby_inference::{helpers, TextGeneration, TextGenerationOptions}; use tokenizers::tokenizer::Tokenizer; -use tokio_util::sync::CancellationToken; #[cxx::bridge(namespace = "llama")] mod ffi { @@ -35,7 +36,7 @@ pub struct LlamaEngineOptions { } pub struct LlamaEngine { - engine: Arc>>, + engine: Mutex>, tokenizer: Arc, stop_words: StopWords, } @@ -43,7 +44,7 @@ pub struct LlamaEngine { impl LlamaEngine { pub fn create(options: LlamaEngineOptions) -> Self { LlamaEngine { - engine: Arc::new(Mutex::new(create_engine(&options.model_path))), + engine: Mutex::new(create_engine(&options.model_path)), tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), stop_words: StopWords::default(), } @@ -53,51 +54,49 @@ impl LlamaEngine { #[async_trait] impl TextGeneration for LlamaEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { - let cancel = CancellationToken::new(); - let cancel_for_inference = cancel.clone(); - let _guard = cancel.drop_guard(); + let s = self.generate_stream(prompt, options).await; + helpers::stream_to_string(s).await + } + async fn generate_stream( + &self, + prompt: &str, + options: TextGenerationOptions, + ) -> BoxStream { let prompt = prompt.to_owned(); - let engine = self.engine.clone(); let mut stop_condition = self .stop_words .create_condition(self.tokenizer.clone(), options.stop_words); - let output_ids = tokio::task::spawn_blocking(move || { - let engine = engine.lock().unwrap(); + let s = stream! { + let engine = self.engine.lock().await; let eos_token = engine.eos_token(); let mut next_token_id = engine.start(&prompt, options.max_input_length); if next_token_id == eos_token { - return Vec::new(); - } + yield "".to_owned(); + } else { + let mut n_remains = options.max_decoding_length - 1; - let mut n_remains = options.max_decoding_length - 1; - let mut output_ids = vec![next_token_id]; + while n_remains > 0 { + next_token_id = engine.step(next_token_id); + if next_token_id == eos_token { + break; + } - while n_remains > 0 { - if cancel_for_inference.is_cancelled() { - // The token was cancelled - break; + if stop_condition.next_token(next_token_id) { + break; + } + + let text = self.tokenizer.decode(&[next_token_id], true).unwrap(); + yield text; + n_remains -= 1; } - - next_token_id = engine.step(next_token_id); - if next_token_id == eos_token { - break; - } - - if stop_condition.next_token(next_token_id) { - break; - } - output_ids.push(next_token_id); - n_remains -= 1; } engine.end(); - output_ids - }) - .await - .expect("Inference failed"); - self.tokenizer.decode(&output_ids, true).unwrap() + }; + + Box::pin(s) } } diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index 9be1df2..fa29afa 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -6,5 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-stream = { workspace = true } async-trait = { workspace = true } derive_builder = "0.12.0" +futures = { workspace = true } diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 1622dc0..04cad0d 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use derive_builder::Builder; +use futures::stream::BoxStream; #[derive(Builder, Debug)] pub struct TextGenerationOptions { @@ -21,4 +22,33 @@ static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; #[async_trait] pub trait TextGeneration: Sync + Send { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String; + async fn generate_stream( + &self, + prompt: &str, + options: TextGenerationOptions, + ) -> BoxStream; +} + +pub mod helpers { + use async_stream::stream; + use futures::{pin_mut, stream::BoxStream, Stream, StreamExt}; + + pub async fn stream_to_string(s: impl Stream) -> String { + pin_mut!(s); + + let mut text = "".to_owned(); + while let Some(value) = s.next().await { + text += &value; + } + + text + } + + pub async fn string_to_stream(s: String) -> BoxStream<'static, String> { + let stream = stream! { + yield s + }; + + Box::pin(stream) + } } diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 5976206..624473d 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -36,6 +36,9 @@ anyhow = { workspace = true } sysinfo = "0.29.8" nvml-wrapper = "0.9.0" http-api-bindings = { path = "../http-api-bindings" } +futures = { workspace = true } +async-stream = { workspace = true } +axum-streams = { version = "0.9.1", features = ["json"] } [target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies] llama-cpp-bindings = { path = "../llama-cpp-bindings" } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 1849e5d..73113e3 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -1,21 +1,17 @@ mod languages; mod prompt; -use std::{path::Path, sync::Arc}; +use std::sync::Arc; use axum::{extract::State, Json}; -use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder}; -use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tabby_common::{config::Config, events, path::ModelDir}; +use tabby_common::{config::Config, events}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; use self::languages::get_stop_words; -use crate::fatal; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ @@ -124,14 +120,16 @@ pub async fn completion( } pub struct CompletionState { - engine: Box, + engine: Arc>, prompt_builder: prompt::PromptBuilder, } impl CompletionState { - pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self { - let (engine, prompt_template) = create_engine(args); - + pub fn new( + engine: Arc>, + prompt_template: Option, + config: &Config, + ) -> Self { Self { engine, prompt_builder: prompt::PromptBuilder::new( @@ -141,120 +139,3 @@ impl CompletionState { } } } - -fn get_param(params: &Value, key: &str) -> String { - params - .get(key) - .unwrap_or_else(|| panic!("Missing {} field", key)) - .as_str() - .expect("Type unmatched") - .to_string() -} - -fn create_engine(args: &crate::serve::ServeArgs) -> (Box, Option) { - if args.device != super::Device::ExperimentalHttp { - let model_dir = get_model_dir(&args.model); - let metadata = read_metadata(&model_dir); - let engine = create_local_engine(args, &model_dir, &metadata); - (engine, metadata.prompt_template) - } else { - let params: Value = - serdeconv::from_json_str(&args.model).expect("Failed to parse model string"); - - let kind = get_param(¶ms, "kind"); - - if kind == "vertex-ai" { - let api_endpoint = get_param(¶ms, "api_endpoint"); - let authorization = get_param(¶ms, "authorization"); - let engine = Box::new(VertexAIEngine::create( - api_endpoint.as_str(), - authorization.as_str(), - )); - (engine, Some(VertexAIEngine::prompt_template())) - } else if kind == "fastchat" { - let model_name = get_param(¶ms, "model_name"); - let api_endpoint = get_param(¶ms, "api_endpoint"); - let authorization = get_param(¶ms, "authorization"); - let engine = Box::new(FastChatEngine::create( - api_endpoint.as_str(), - model_name.as_str(), - authorization.as_str(), - )); - (engine, Some(FastChatEngine::prompt_template())) - } else { - fatal!("Only vertex_ai and fastchat are supported for http backend"); - } - } -} - -#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] -fn create_local_engine( - args: &crate::serve::ServeArgs, - model_dir: &ModelDir, - metadata: &Metadata, -) -> Box { - create_ctranslate2_engine(args, model_dir, metadata) -} - -#[cfg(all(target_os = "macos", target_arch = "aarch64"))] -fn create_local_engine( - args: &crate::serve::ServeArgs, - model_dir: &ModelDir, - metadata: &Metadata, -) -> Box { - if args.device != super::Device::Metal { - create_ctranslate2_engine(args, model_dir, metadata) - } else { - create_llama_engine(model_dir) - } -} - -fn create_ctranslate2_engine( - args: &crate::serve::ServeArgs, - model_dir: &ModelDir, - metadata: &Metadata, -) -> Box { - let device = format!("{}", args.device); - let compute_type = format!("{}", args.compute_type); - let options = CTranslate2EngineOptionsBuilder::default() - .model_path(model_dir.ctranslate2_dir()) - .tokenizer_path(model_dir.tokenizer_file()) - .device(device) - .model_type(metadata.auto_model.clone()) - .device_indices(args.device_indices.clone()) - .num_replicas_per_device(args.num_replicas_per_device) - .compute_type(compute_type) - .build() - .unwrap(); - Box::new(CTranslate2Engine::create(options)) -} - -#[cfg(all(target_os = "macos", target_arch = "aarch64"))] -fn create_llama_engine(model_dir: &ModelDir) -> Box { - let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default() - .model_path(model_dir.ggml_q8_0_file()) - .tokenizer_path(model_dir.tokenizer_file()) - .build() - .unwrap(); - - Box::new(llama_cpp_bindings::LlamaEngine::create(options)) -} - -fn get_model_dir(model: &str) -> ModelDir { - if Path::new(model).exists() { - ModelDir::from(model) - } else { - ModelDir::new(model) - } -} - -#[derive(Deserialize)] -struct Metadata { - auto_model: String, - prompt_template: Option, -} - -fn read_metadata(model_dir: &ModelDir) -> Metadata { - serdeconv::from_json_file(model_dir.metadata_file()) - .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file())) -} diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs new file mode 100644 index 0000000..d46a450 --- /dev/null +++ b/crates/tabby/src/serve/engine.rs @@ -0,0 +1,127 @@ +use std::path::Path; + +use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder}; +use http_api_bindings::{fastchat::FastChatEngine, vertex_ai::VertexAIEngine}; +use serde::Deserialize; +use serde_json::Value; +use tabby_common::path::ModelDir; +use tabby_inference::TextGeneration; + +use crate::fatal; + +fn get_param(params: &Value, key: &str) -> String { + params + .get(key) + .unwrap_or_else(|| panic!("Missing {} field", key)) + .as_str() + .expect("Type unmatched") + .to_string() +} + +pub fn create_engine(args: &crate::serve::ServeArgs) -> (Box, Option) { + if args.device != super::Device::ExperimentalHttp { + let model_dir = get_model_dir(&args.model); + let metadata = read_metadata(&model_dir); + let engine = create_local_engine(args, &model_dir, &metadata); + (engine, metadata.prompt_template) + } else { + let params: Value = + serdeconv::from_json_str(&args.model).expect("Failed to parse model string"); + + let kind = get_param(¶ms, "kind"); + + if kind == "vertex-ai" { + let api_endpoint = get_param(¶ms, "api_endpoint"); + let authorization = get_param(¶ms, "authorization"); + let engine = Box::new(VertexAIEngine::create( + api_endpoint.as_str(), + authorization.as_str(), + )); + (engine, Some(VertexAIEngine::prompt_template())) + } else if kind == "fastchat" { + let model_name = get_param(¶ms, "model_name"); + let api_endpoint = get_param(¶ms, "api_endpoint"); + let authorization = get_param(¶ms, "authorization"); + let engine = Box::new(FastChatEngine::create( + api_endpoint.as_str(), + model_name.as_str(), + authorization.as_str(), + )); + (engine, Some(FastChatEngine::prompt_template())) + } else { + fatal!("Only vertex_ai and fastchat are supported for http backend"); + } + } +} + +#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] +fn create_local_engine( + args: &crate::serve::ServeArgs, + model_dir: &ModelDir, + metadata: &Metadata, +) -> Box { + create_ctranslate2_engine(args, model_dir, metadata) +} + +#[cfg(all(target_os = "macos", target_arch = "aarch64"))] +fn create_local_engine( + args: &crate::serve::ServeArgs, + model_dir: &ModelDir, + metadata: &Metadata, +) -> Box { + if args.device != super::Device::Metal { + create_ctranslate2_engine(args, model_dir, metadata) + } else { + create_llama_engine(model_dir) + } +} + +fn create_ctranslate2_engine( + args: &crate::serve::ServeArgs, + model_dir: &ModelDir, + metadata: &Metadata, +) -> Box { + let device = format!("{}", args.device); + let compute_type = format!("{}", args.compute_type); + let options = CTranslate2EngineOptionsBuilder::default() + .model_path(model_dir.ctranslate2_dir()) + .tokenizer_path(model_dir.tokenizer_file()) + .device(device) + .model_type(metadata.auto_model.clone()) + .device_indices(args.device_indices.clone()) + .num_replicas_per_device(args.num_replicas_per_device) + .compute_type(compute_type) + .build() + .unwrap(); + Box::new(CTranslate2Engine::create(options)) +} + +#[cfg(all(target_os = "macos", target_arch = "aarch64"))] +fn create_llama_engine(model_dir: &ModelDir) -> Box { + let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default() + .model_path(model_dir.ggml_q8_0_file()) + .tokenizer_path(model_dir.tokenizer_file()) + .build() + .unwrap(); + + Box::new(llama_cpp_bindings::LlamaEngine::create(options)) +} + +fn get_model_dir(model: &str) -> ModelDir { + if Path::new(model).exists() { + ModelDir::from(model) + } else { + ModelDir::new(model) + } +} + +#[derive(Deserialize)] +struct Metadata { + auto_model: String, + prompt_template: Option, +} + +fn read_metadata(model_dir: &ModelDir) -> Metadata { + serdeconv::from_json_file(model_dir.metadata_file()) + .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file())) +} diff --git a/crates/tabby/src/serve/generate.rs b/crates/tabby/src/serve/generate.rs new file mode 100644 index 0000000..4dc2f8a --- /dev/null +++ b/crates/tabby/src/serve/generate.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use async_stream::stream; +use axum::{extract::State, response::IntoResponse, Json}; +use axum_streams::StreamBodyAs; +use serde::{Deserialize, Serialize}; +use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; +use tracing::instrument; +use utoipa::ToSchema; + +pub struct GenerateState { + engine: Arc>, +} + +impl GenerateState { + pub fn new(engine: Arc>) -> Self { + Self { engine } + } +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct GenerateRequest { + #[schema( + example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef" + )] + prompt: String, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct GenerateResponse { + text: String, +} + +#[utoipa::path( + post, + path = "/v1/generate", + request_body = GenerateRequest, + operation_id = "generate", + tag = "v1", + responses( + (status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"), + ) +)] +#[instrument(skip(state, request))] +pub async fn generate( + State(state): State>, + Json(request): Json, +) -> impl IntoResponse { + let options = build_options(&request); + Json(GenerateResponse { + text: state.engine.generate(&request.prompt, options).await, + }) +} + +#[utoipa::path( + post, + path = "/v1/generate_stream", + request_body = GenerateRequest, + operation_id = "generate_stream", + tag = "v1", + responses( + (status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"), + ) +)] +#[instrument(skip(state, request))] +pub async fn generate_stream( + State(state): State>, + Json(request): Json, +) -> impl IntoResponse { + let options = build_options(&request); + let s = stream! { + for await text in state.engine.generate_stream(&request.prompt, options).await { + yield GenerateResponse { text } + } + }; + + StreamBodyAs::json_nl(s) +} + +fn build_options(_request: &GenerateRequest) -> TextGenerationOptions { + TextGenerationOptionsBuilder::default() + .max_input_length(2048) + .max_decoding_length(usize::MAX) + .sampling_temperature(0.1) + .build() + .unwrap() +} diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 9691952..51deaf4 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -1,5 +1,7 @@ mod completions; +mod engine; mod events; +mod generate; mod health; use std::{ @@ -19,7 +21,7 @@ use tracing::{info, warn}; use utoipa::{openapi::ServerBuilder, OpenApi}; use utoipa_swagger_ui::SwaggerUi; -use self::health::HealthState; +use self::{engine::create_engine, health::HealthState}; use crate::fatal; #[derive(OpenApi)] @@ -39,13 +41,15 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi servers( (url = "https://playground.app.tabbyml.com", description = "Playground server"), ), - paths(events::log_event, completions::completion, health::health), + paths(events::log_event, completions::completion, generate::generate, generate::generate_stream, health::health), components(schemas( events::LogEventRequest, completions::CompletionRequest, completions::CompletionResponse, completions::Segments, completions::Choice, + generate::GenerateRequest, + generate::GenerateResponse, health::HealthState, health::Version, )) @@ -171,6 +175,8 @@ pub async fn main(config: &Config, args: &ServeArgs) { } fn api_router(args: &ServeArgs, config: &Config) -> Router { + let (engine, prompt_template) = create_engine(args); + let engine = Arc::new(engine); Router::new() .route("/events", routing::post(events::log_event)) .route( @@ -179,8 +185,19 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router { ) .route( "/completions", - routing::post(completions::completion) - .with_state(Arc::new(completions::CompletionState::new(args, config))), + routing::post(completions::completion).with_state(Arc::new( + completions::CompletionState::new(engine.clone(), prompt_template, config), + )), + ) + .route( + "/generate", + routing::post(generate::generate) + .with_state(Arc::new(generate::GenerateState::new(engine.clone()))), + ) + .route( + "/generate_stream", + routing::post(generate::generate_stream) + .with_state(Arc::new(generate::GenerateState::new(engine.clone()))), ) .layer(CorsLayer::permissive()) .layer(opentelemetry_tracing_layer())