From b69ad9a532faf1f3e4a256db214156eb4ad26abc Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Mon, 9 Oct 2023 16:30:46 -0700 Subject: [PATCH] fix: should use spawn_blocking for workload without `yield` --- Cargo.lock | 4 +-- Cargo.toml | 2 +- crates/ctranslate2-bindings/Cargo.toml | 2 +- crates/ctranslate2-bindings/src/lib.rs | 40 +++++++++++++------------- crates/tabby/Cargo.toml | 2 +- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1da931..4716f40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -690,7 +690,7 @@ dependencies = [ [[package]] name = "ctranslate2-bindings" -version = "0.2.0" +version = "0.2.2-rc.4" dependencies = [ "async-stream", "async-trait", @@ -3086,7 +3086,7 @@ dependencies = [ [[package]] name = "tabby" -version = "0.2.1" +version = "0.2.2-rc.4" dependencies = [ "anyhow", "async-stream", diff --git a/Cargo.toml b/Cargo.toml index d8ea93c..f62f70f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ ] [workspace.package] -version = "0.2.1" +version = "0.2.2-rc.4" edition = "2021" authors = ["Meng Zhang"] homepage = "https://github.com/TabbyML/tabby" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 922b9a4..9864674 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ctranslate2-bindings" -version = "0.2.0" +version = "0.2.2-rc.4" edition = "2021" [dependencies] diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 25ce843..f04f8ad 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -129,29 +129,29 @@ impl TextGeneration for CTranslate2Engine { options: TextGenerationOptions, ) -> BoxStream { let encoding = self.tokenizer.encode(prompt, true).unwrap(); + let decoding = self.decoding_factory.create_incremental_decoding( + self.tokenizer.clone(), + truncate_tokens(encoding.get_ids(), options.max_input_length), + options.stop_words, + ); + + let cancel = CancellationToken::new(); let engine = self.engine.clone(); + let (sender, mut receiver) = channel::(8); + let context = InferenceContext::new(sender, decoding, cancel.clone()); + tokio::task::spawn_blocking(move || { + let context = Box::new(context); + engine.inference( + context, + inference_callback, + truncate_tokens(encoding.get_tokens(), options.max_input_length), + options.max_decoding_length, + options.sampling_temperature, + ); + }); + let s = stream! { - let cancel = CancellationToken::new(); - let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - - let decoding = self - .decoding_factory - .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words); - - let (sender, mut receiver) = channel::(8); - let context = InferenceContext::new(sender, decoding, cancel_for_inference); - tokio::task::spawn(async move { - let context = Box::new(context); - engine.inference( - context, - inference_callback, - truncate_tokens(encoding.get_tokens(), options.max_input_length), - options.max_decoding_length, - options.sampling_temperature, - ); - }); - while let Some(text) = receiver.recv().await { yield text; } diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index ed2391d..60e4882 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tabby" -version = "0.2.1" +version = "0.2.2-rc.4" edition = "2021" [dependencies]