diff --git a/Cargo.lock b/Cargo.lock index 1f1a9ac..0e55557 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -161,18 +161,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] name = "async-trait" -version = "0.1.68" +version = "0.1.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -443,7 +443,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -613,6 +613,7 @@ dependencies = [ name = "ctranslate2-bindings" version = "0.1.0" dependencies = [ + "async-trait", "cmake", "cxx", "cxx-build", @@ -620,6 +621,7 @@ dependencies = [ "derive_builder", "regex", "rust-cxx-cmake-bridge", + "tabby-inference", "tokenizers", "tokio", "tokio-util", @@ -649,7 +651,7 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -666,7 +668,7 @@ checksum = "4a076022ece33e7686fb76513518e219cca4fce5750a8ae6d1ce6c0f48fd1af9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -1028,7 +1030,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -1683,7 +1685,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -1848,7 +1850,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -2049,7 +2051,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -2127,9 +2129,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.59" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" dependencies = [ "unicode-ident", ] @@ -2190,9 +2192,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.28" +version = "1.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" dependencies = [ "proc-macro2", ] @@ -2555,7 +2557,7 @@ checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -2757,9 +2759,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.18" +version = "2.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" +checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" dependencies = [ "proc-macro2", "quote", @@ -2794,6 +2796,7 @@ dependencies = [ "strum", "tabby-common", "tabby-download", + "tabby-inference", "tabby-scheduler", "tantivy", "tokio", @@ -2833,6 +2836,14 @@ dependencies = [ "tokio-retry", ] +[[package]] +name = "tabby-inference" +version = "0.1.0" +dependencies = [ + "async-trait", + "derive_builder", +] + [[package]] name = "tabby-scheduler" version = "0.1.0" @@ -3002,7 +3013,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -3141,7 +3152,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -3359,7 +3370,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -3664,7 +3675,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -3712,7 +3723,7 @@ checksum = "3f67b459f42af2e6e1ee213cb9da4dbd022d3320788c3fb3e1b893093f1e45da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", ] [[package]] @@ -3786,7 +3797,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", "wasm-bindgen-shared", ] @@ -3820,7 +3831,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.18", + "syn 2.0.28", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 3a221b9..b1ff613 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/tabby-common", "crates/tabby-scheduler", "crates/tabby-download", + "crates/tabby-inference", "crates/ctranslate2-bindings", "crates/rust-cxx-cmake-bridge", ] @@ -25,3 +26,4 @@ tracing-subscriber = "0.3" anyhow = "1.0.71" serde-jsonlines = "0.4.0" tantivy = "0.19.2" +async-trait = "0.1.72" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index b5a45fb..d6484a5 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -11,6 +11,8 @@ regex = "1.8.4" tokenizers = "0.13.3" tokio = { workspace = true, features = ["rt"] } tokio-util = { workspace = true } +tabby-inference = { path = "../tabby-inference" } +async-trait = { workspace = true } [build-dependencies] cxx-build = "1.0" diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 664048c..19dcf5f 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,11 +1,11 @@ +use async_trait::async_trait; use dashmap::DashMap; +use derive_builder::Builder; use regex::Regex; +use tabby_inference::{TextGeneration, TextGenerationOptions}; use tokenizers::tokenizer::Tokenizer; use tokio_util::sync::CancellationToken; -#[macro_use] -extern crate derive_builder; - #[cxx::bridge(namespace = "tabby")] mod ffi { extern "Rust" { @@ -49,7 +49,7 @@ unsafe impl Send for ffi::TextInferenceEngine {} unsafe impl Sync for ffi::TextInferenceEngine {} #[derive(Builder, Debug)] -pub struct TextInferenceEngineCreateOptions { +pub struct CTranslate2EngineOptions { model_path: String, model_type: String, @@ -65,17 +65,6 @@ pub struct TextInferenceEngineCreateOptions { compute_type: String, } -#[derive(Builder, Debug)] -pub struct TextInferenceOptions { - #[builder(default = "256")] - max_decoding_length: usize, - - #[builder(default = "1.0")] - sampling_temperature: f32, - - stop_words: &'static Vec<&'static str>, -} - pub struct InferenceContext { stop_re: Option, cancel: CancellationToken, @@ -92,14 +81,14 @@ impl InferenceContext { } } -pub struct TextInferenceEngine { +pub struct CTranslate2Engine { engine: cxx::SharedPtr, tokenizer: Tokenizer, stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, } -impl TextInferenceEngine { - pub fn create(options: TextInferenceEngineCreateOptions) -> Self where { +impl CTranslate2Engine { + pub fn create(options: CTranslate2EngineOptions) -> Self where { let engine = ffi::create_engine( &options.model_path, &options.model_type, @@ -108,14 +97,18 @@ impl TextInferenceEngine { &options.device_indices, options.num_replicas_per_device, ); - return TextInferenceEngine { + + return Self { engine, stop_regex_cache: DashMap::new(), tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), }; } +} - pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { +#[async_trait] +impl TextGeneration for CTranslate2Engine { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { let encoding = self.tokenizer.encode(prompt, true).unwrap(); let engine = self.engine.clone(); diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml new file mode 100644 index 0000000..9be1df2 --- /dev/null +++ b/crates/tabby-inference/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "tabby-inference" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait = { workspace = true } +derive_builder = "0.12.0" diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs new file mode 100644 index 0000000..8804a58 --- /dev/null +++ b/crates/tabby-inference/src/lib.rs @@ -0,0 +1,18 @@ +use async_trait::async_trait; +use derive_builder::Builder; + +#[derive(Builder, Debug)] +pub struct TextGenerationOptions { + #[builder(default = "256")] + pub max_decoding_length: usize, + + #[builder(default = "1.0")] + pub sampling_temperature: f32, + + pub stop_words: &'static Vec<&'static str>, +} + +#[async_trait] +pub trait TextGeneration { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String; +} diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index be518e8..b614939 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -8,6 +8,7 @@ ctranslate2-bindings = { path = "../ctranslate2-bindings" } tabby-common = { path = "../tabby-common" } tabby-scheduler = { path = "../tabby-scheduler", optional = true } tabby-download = { path = "../tabby-download" } +tabby-inference = { path = "../tabby-inference" } axum = "0.6" hyper = { version = "0.14", features = ["full"] } tokio = { workspace = true, features = ["full"] } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index ed6d22e..5d682ff 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -4,12 +4,11 @@ mod prompt; use std::{path::Path, sync::Arc}; use axum::{extract::State, Json}; -use ctranslate2_bindings::{ - TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder, -}; +use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; use tabby_common::{config::Config, events, path::ModelDir}; +use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; @@ -79,7 +78,7 @@ pub async fn completion( Json(request): Json, ) -> Result, StatusCode> { let language = request.language.unwrap_or("unknown".to_string()); - let options = TextInferenceOptionsBuilder::default() + let options = TextGenerationOptionsBuilder::default() .max_decoding_length(128) .sampling_temperature(0.1) .stop_words(get_stop_words(&language)) @@ -101,7 +100,7 @@ pub async fn completion( let prompt = state.prompt_builder.build(&language, segments); debug!("PROMPT: {}", prompt); let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); - let text = state.engine.inference(&prompt, options).await; + let text = state.engine.generate(&prompt, options).await; events::Event::Completion { completion_id: &completion_id, @@ -122,7 +121,7 @@ pub async fn completion( } pub struct CompletionState { - engine: TextInferenceEngine, + engine: CTranslate2Engine, prompt_builder: prompt::PromptBuilder, } @@ -133,7 +132,7 @@ impl CompletionState { let device = format!("{}", args.device); let compute_type = format!("{}", args.compute_type); - let options = TextInferenceEngineCreateOptionsBuilder::default() + let options = CTranslate2EngineOptionsBuilder::default() .model_path(model_dir.ctranslate2_dir()) .tokenizer_path(model_dir.tokenizer_file()) .device(device) @@ -143,7 +142,7 @@ impl CompletionState { .compute_type(compute_type) .build() .unwrap(); - let engine = TextInferenceEngine::create(options); + let engine = CTranslate2Engine::create(options); Self { engine, prompt_builder: prompt::PromptBuilder::new(