fix: should use spawn_blocking for workload without `yield`

r0.2
Meng Zhang 2023-10-09 16:30:46 -07:00
parent 2d5b3e4ff5
commit b69ad9a532
5 changed files with 25 additions and 25 deletions

4
Cargo.lock generated
View File

@ -690,7 +690,7 @@ dependencies = [
[[package]] [[package]]
name = "ctranslate2-bindings" name = "ctranslate2-bindings"
version = "0.2.0" version = "0.2.2-rc.4"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@ -3086,7 +3086,7 @@ dependencies = [
[[package]] [[package]]
name = "tabby" name = "tabby"
version = "0.2.1" version = "0.2.2-rc.4"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",

View File

@ -13,7 +13,7 @@ members = [
] ]
[workspace.package] [workspace.package]
version = "0.2.1" version = "0.2.2-rc.4"
edition = "2021" edition = "2021"
authors = ["Meng Zhang"] authors = ["Meng Zhang"]
homepage = "https://github.com/TabbyML/tabby" homepage = "https://github.com/TabbyML/tabby"

View File

@ -1,6 +1,6 @@
[package] [package]
name = "ctranslate2-bindings" name = "ctranslate2-bindings"
version = "0.2.0" version = "0.2.2-rc.4"
edition = "2021" edition = "2021"
[dependencies] [dependencies]

View File

@ -129,19 +129,17 @@ impl TextGeneration for CTranslate2Engine {
options: TextGenerationOptions, options: TextGenerationOptions,
) -> BoxStream<String> { ) -> BoxStream<String> {
let encoding = self.tokenizer.encode(prompt, true).unwrap(); let encoding = self.tokenizer.encode(prompt, true).unwrap();
let engine = self.engine.clone(); let decoding = self.decoding_factory.create_incremental_decoding(
let s = stream! { self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words,
);
let cancel = CancellationToken::new(); let cancel = CancellationToken::new();
let cancel_for_inference = cancel.clone(); let engine = self.engine.clone();
let _guard = cancel.drop_guard();
let decoding = self
.decoding_factory
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
let (sender, mut receiver) = channel::<String>(8); let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference); let context = InferenceContext::new(sender, decoding, cancel.clone());
tokio::task::spawn(async move { tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
context, context,
@ -152,6 +150,8 @@ impl TextGeneration for CTranslate2Engine {
); );
}); });
let s = stream! {
let _guard = cancel.drop_guard();
while let Some(text) = receiver.recv().await { while let Some(text) = receiver.recv().await {
yield text; yield text;
} }

View File

@ -1,6 +1,6 @@
[package] [package]
name = "tabby" name = "tabby"
version = "0.2.1" version = "0.2.2-rc.4"
edition = "2021" edition = "2021"
[dependencies] [dependencies]