diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index c45aa01..8cb1a62 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -121,29 +121,29 @@ impl TextGeneration for CTranslate2Engine { options: TextGenerationOptions, ) -> BoxStream { let encoding = self.tokenizer.encode(prompt, true).unwrap(); + let decoding = self.decoding_factory.create_incremental_decoding( + self.tokenizer.clone(), + truncate_tokens(encoding.get_ids(), options.max_input_length), + options.stop_words, + ); + + let cancel = CancellationToken::new(); let engine = self.engine.clone(); + let (sender, mut receiver) = channel::(8); + let context = InferenceContext::new(sender, decoding, cancel.clone()); + tokio::task::spawn_blocking(move || { + let context = Box::new(context); + engine.inference( + context, + inference_callback, + truncate_tokens(encoding.get_tokens(), options.max_input_length), + options.max_decoding_length, + options.sampling_temperature, + ); + }); + let s = stream! { - let cancel = CancellationToken::new(); - let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - - let decoding = self - .decoding_factory - .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words); - - let (sender, mut receiver) = channel::(8); - let context = InferenceContext::new(sender, decoding, cancel_for_inference); - tokio::task::spawn(async move { - let context = Box::new(context); - engine.inference( - context, - inference_callback, - truncate_tokens(encoding.get_tokens(), options.max_input_length), - options.max_decoding_length, - options.sampling_temperature, - ); - }); - while let Some(text) = receiver.recv().await { yield text; }