From 64e0f9283717f2cb75eefb3cc37535db50170148 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Mon, 9 Oct 2023 18:10:25 -0700 Subject: [PATCH] feat: add back streaming for ctranslate2 --- crates/ctranslate2-bindings/src/lib.rs | 79 ++++++++++++++++---------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 0aafd8e..2cea705 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use async_stream::stream; use async_trait::async_trait; use derive_builder::Builder; use futures::stream::BoxStream; @@ -8,6 +9,7 @@ use tabby_inference::{ helpers, TextGeneration, TextGenerationOptions, }; use tokenizers::tokenizer::Tokenizer; +use tokio::sync::mpsc::{channel, Sender}; use tokio_util::sync::CancellationToken; #[cxx::bridge(namespace = "tabby")] @@ -70,13 +72,22 @@ pub struct CTranslate2EngineOptions { } pub struct InferenceContext { + sender: Sender, decoding: IncrementalDecoding, cancel: CancellationToken, } impl InferenceContext { - fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self { - InferenceContext { decoding, cancel } + fn new( + sender: Sender, + decoding: IncrementalDecoding, + cancel: CancellationToken, + ) -> Self { + InferenceContext { + sender, + decoding, + cancel, + } } } @@ -108,33 +119,8 @@ impl CTranslate2Engine { #[async_trait] impl TextGeneration for CTranslate2Engine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { - let encoding = self.tokenizer.encode(prompt, true).unwrap(); - let engine = self.engine.clone(); - - let cancel = CancellationToken::new(); - let cancel_for_inference = cancel.clone(); - let _guard = cancel.drop_guard(); - - let decoding = self.decoding_factory.create_incremental_decoding( - self.tokenizer.clone(), - truncate_tokens(encoding.get_ids(), options.max_input_length), - options.stop_words, - ); - - let context = InferenceContext::new(decoding, cancel_for_inference); - let output_ids = tokio::task::spawn_blocking(move || { - let context = Box::new(context); - engine.inference( - context, - inference_callback, - truncate_tokens(encoding.get_tokens(), options.max_input_length), - options.max_decoding_length, - options.sampling_temperature, - ) - }) - .await - .expect("Inference failed"); - self.tokenizer.decode(&output_ids, true).unwrap() + let s = self.generate_stream(prompt, options).await; + helpers::stream_to_string(s).await } async fn generate_stream( @@ -142,7 +128,35 @@ impl TextGeneration for CTranslate2Engine { prompt: &str, options: TextGenerationOptions, ) -> BoxStream { - helpers::string_to_stream(self.generate(prompt, options).await).await + let encoding = self.tokenizer.encode(prompt, true).unwrap(); + let decoding = self.decoding_factory.create_incremental_decoding( + self.tokenizer.clone(), + truncate_tokens(encoding.get_ids(), options.max_input_length), + options.stop_words, + ); + + let cancel = CancellationToken::new(); + let engine = self.engine.clone(); + let (sender, mut receiver) = channel::(8); + let context = InferenceContext::new(sender, decoding, cancel.clone()); + tokio::task::spawn(async move { + let context = Box::new(context); + engine.inference( + context, + inference_callback, + truncate_tokens(encoding.get_tokens(), options.max_input_length), + options.max_decoding_length, + options.sampling_temperature, + ); + }); + + let s = stream! { + let _guard = cancel.drop_guard(); + while let Some(text) = receiver.recv().await { + yield text; + } + }; + Box::pin(s) } } @@ -163,7 +177,10 @@ fn inference_callback( ) -> bool { if context.cancel.is_cancelled() { true + } else if let Some(new_text) = context.decoding.next_token(token_id) { + let _ = context.sender.blocking_send(new_text); + false } else { - !context.decoding.next_token(token_id).is_some() + true } }