diff --git a/Cargo.lock b/Cargo.lock
index c1e821c..39de936 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1565,6 +1565,7 @@ dependencies = [
  "tabby-inference",
  "tokenizers",
  "tokio",
+ "tokio-util",
 ]
 
 [[package]]
diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml
index 395829e..ba2a09d 100644
--- a/crates/llama-cpp-bindings/Cargo.toml
+++ b/crates/llama-cpp-bindings/Cargo.toml
@@ -15,3 +15,4 @@ tabby-inference = { path = "../tabby-inference" }
 derive_builder = { workspace = true }
 tokenizers = { workspace = true }
 stop-words = { version = "0.1.0", path = "../stop-words" }
+tokio-util = { workspace = true }
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index d953f2e..9325c5b 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -6,6 +6,7 @@ use ffi::create_engine;
 use stop_words::StopWords;
 use tabby_inference::{TextGeneration, TextGenerationOptions};
 use tokenizers::tokenizer::Tokenizer;
+use tokio_util::sync::CancellationToken;
 
 #[cxx::bridge(namespace = "llama")]
 mod ffi {
@@ -31,7 +32,7 @@ pub struct LlamaEngineOptions {
 }
 
 pub struct LlamaEngine {
-    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
+    engine: Arc<Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>>,
     tokenizer: Arc<Tokenizer>,
     stop_words: StopWords,
 }
@@ -39,7 +40,7 @@ pub struct LlamaEngine {
 impl LlamaEngine {
     pub fn create(options: LlamaEngineOptions) -> Self {
         LlamaEngine {
-            engine: Mutex::new(create_engine(&options.model_path)),
+            engine: Arc::new(Mutex::new(create_engine(&options.model_path))),
             tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
             stop_words: StopWords::default(),
         }
@@ -49,24 +50,40 @@ impl LlamaEngine {
 #[async_trait]
 impl TextGeneration for LlamaEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let engine = self.engine.lock().unwrap();
-        let mut next_token_id = engine.start(prompt);
-        let mut n_remains = options.max_decoding_length - 1;
-        let mut output_ids = vec![next_token_id];
+        let cancel = CancellationToken::new();
+        let cancel_for_inference = cancel.clone();
+        let _guard = cancel.drop_guard();
 
+        let prompt = prompt.to_owned();
+        let engine = self.engine.clone();
         let mut stop_condition = self
             .stop_words
             .create_condition(self.tokenizer.clone(), options.stop_words);
 
-        // FIXME(meng): supports cancellation.
-        while n_remains > 0 {
-            next_token_id = engine.step(next_token_id);
-            if stop_condition.next_token(next_token_id) {
-                break;
+        let output_ids = tokio::task::spawn_blocking(move || {
+            let engine = engine.lock().unwrap();
+            let mut next_token_id = engine.start(&prompt);
+            let mut n_remains = options.max_decoding_length - 1;
+            let mut output_ids = vec![next_token_id];
+
+            while n_remains > 0 {
+                if cancel_for_inference.is_cancelled() {
+                    // The request was cancelled; stop decoding.
+                    break;
+                }
+
+                next_token_id = engine.step(next_token_id);
+                if stop_condition.next_token(next_token_id) {
+                    break;
+                }
+                output_ids.push(next_token_id);
+                n_remains -= 1;
             }
-            output_ids.push(next_token_id);
-            n_remains -= 1;
-        }
+
+            output_ids
+        })
+        .await
+        .expect("Inference failed");
 
         self.tokenizer.decode(&output_ids, true).unwrap()
     }
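
For reviewers, a minimal self-contained sketch (not part of the patch) of the pattern the last hunk relies on: a CancellationToken whose DropGuard lives in the async future, polled cooperatively by a spawn_blocking worker. The generate/main names, loop bounds, and timings below are invented for illustration; only the tokio and tokio-util APIs are real.

// Sketch only: illustrates CancellationToken + drop_guard() + spawn_blocking.
use std::time::Duration;
use tokio_util::sync::CancellationToken;

async fn generate() -> Vec<u32> {
    let cancel = CancellationToken::new();
    let cancel_for_worker = cancel.clone();
    // If this future is dropped (e.g. the client disconnects), the guard
    // cancels the token as it unwinds.
    let _guard = cancel.drop_guard();

    tokio::task::spawn_blocking(move || {
        let mut output_ids = Vec::new();
        for token_id in 0u32..32 {
            // spawn_blocking tasks cannot be aborted from outside, so the
            // worker must poll the token between decoding steps.
            if cancel_for_worker.is_cancelled() {
                break;
            }
            std::thread::sleep(Duration::from_millis(10)); // stand-in for engine.step()
            output_ids.push(token_id);
        }
        output_ids
    })
    .await
    .expect("worker panicked")
}

#[tokio::main]
async fn main() {
    // Dropping generate() early (here via timeout) fires the drop guard;
    // the blocking loop sees the cancellation at its next check and exits.
    let result = tokio::time::timeout(Duration::from_millis(35), generate()).await;
    println!("{result:?}"); // Err(Elapsed): caller gave up, worker winds down
}

The design point the patch depends on: a spawn_blocking thread runs to completion even after the awaiting future is gone, so cooperatively checking is_cancelled() inside the decode loop is what actually frees the engine lock when a caller disappears.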