From 3c7af24047133a35da90f12543f1b9fa0237af29 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Mon, 9 Oct 2023 16:30:46 -0700
Subject: [PATCH] fix: switch ctranslate2 to synchronous implementation

---
 crates/ctranslate2-bindings/src/lib.rs | 75 ++++++++++----------------
 1 file changed, 29 insertions(+), 46 deletions(-)

diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index 25ce843..54a89bd 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -1,15 +1,10 @@
 use std::sync::Arc;
 
-use async_stream::stream;
 use async_trait::async_trait;
 use derive_builder::Builder;
 use futures::stream::BoxStream;
-use tabby_inference::{
-    decoding::{DecodingFactory, IncrementalDecoding},
-    helpers, TextGeneration, TextGenerationOptions,
-};
+use tabby_inference::{helpers, TextGeneration, TextGenerationOptions, decoding::{DecodingFactory, IncrementalDecoding}};
 use tokenizers::tokenizer::Tokenizer;
-use tokio::sync::mpsc::{channel, Sender};
 use tokio_util::sync::CancellationToken;
 
 #[cxx::bridge(namespace = "tabby")]
@@ -72,19 +67,13 @@ pub struct CTranslate2EngineOptions {
 }
 
 pub struct InferenceContext {
-    sender: Sender<String>,
     decoding: IncrementalDecoding,
     cancel: CancellationToken,
 }
 
 impl InferenceContext {
-    fn new(
-        sender: Sender<String>,
-        decoding: IncrementalDecoding,
-        cancel: CancellationToken,
-    ) -> Self {
+    fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self {
         InferenceContext {
-            sender,
             decoding,
             cancel,
         }
@@ -119,8 +108,31 @@ impl CTranslate2Engine {
 #[async_trait]
 impl TextGeneration for CTranslate2Engine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let s = self.generate_stream(prompt, options).await;
-        helpers::stream_to_string(s).await
+        let encoding = self.tokenizer.encode(prompt, true).unwrap();
+        let engine = self.engine.clone();
+
+        let cancel = CancellationToken::new();
+        let cancel_for_inference = cancel.clone();
+        let _guard = cancel.drop_guard();
+
+        let decoding = self
+            .decoding_factory
+            .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
+
+        let context = InferenceContext::new(decoding, cancel_for_inference);
+        let output_ids = tokio::task::spawn_blocking(move || {
+            let context = Box::new(context);
+            engine.inference(
+                context,
+                inference_callback,
+                truncate_tokens(encoding.get_tokens(), options.max_input_length),
+                options.max_decoding_length,
+                options.sampling_temperature,
+            )
+        })
+        .await
+        .expect("Inference failed");
+        self.tokenizer.decode(&output_ids, true).unwrap()
     }
 
     async fn generate_stream(
@@ -128,35 +140,7 @@ impl TextGeneration for CTranslate2Engine {
         prompt: &str,
         options: TextGenerationOptions,
     ) -> BoxStream<String> {
-        let encoding = self.tokenizer.encode(prompt, true).unwrap();
-        let engine = self.engine.clone();
-        let s = stream! {
-            let cancel = CancellationToken::new();
-            let cancel_for_inference = cancel.clone();
-            let _guard = cancel.drop_guard();
-
-            let decoding = self
-                .decoding_factory
-                .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
-
-            let (sender, mut receiver) = channel::<String>(8);
-            let context = InferenceContext::new(sender, decoding, cancel_for_inference);
-            tokio::task::spawn(async move {
-                let context = Box::new(context);
-                engine.inference(
-                    context,
-                    inference_callback,
-                    truncate_tokens(encoding.get_tokens(), options.max_input_length),
-                    options.max_decoding_length,
-                    options.sampling_temperature,
-                );
-            });
-
-            while let Some(text) = receiver.recv().await {
-                yield text;
-            }
-        };
-        Box::pin(s)
+        helpers::string_to_stream(self.generate(prompt, options).await).await
     }
 }
 
@@ -177,8 +161,7 @@ fn inference_callback(
 ) -> bool {
     if context.cancel.is_cancelled() {
         true
-    } else if let Some(new_text) = context.decoding.next_token(token_id) {
-        let _ = context.sender.blocking_send(new_text);
+    } else if let Some(_) = context.decoding.next_token(token_id) {
        false
    } else {
        true
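
The change above replaces the channel-plus-stream plumbing with a single blocking call: `generate` runs `engine.inference` inside `tokio::task::spawn_blocking` and decodes the returned ids, while `generate_stream` now just wraps the finished string. Below is a minimal, self-contained sketch of that shape, not the crate's actual code: `fake_blocking_inference` stands in for the C++ `engine.inference` FFI call, and `string_to_stream` is written out under the assumption that `tabby_inference::helpers::string_to_stream` simply turns a completed `String` into a one-item `BoxStream`.

```rust
use futures::stream::{self, BoxStream, StreamExt};

// Stand-in for the blocking CTranslate2 FFI call driven by inference_callback.
fn fake_blocking_inference(prompt: String) -> Vec<u32> {
    // Pretend each byte of the prompt is a generated token id.
    prompt.bytes().map(u32::from).collect()
}

// Assumed shape of helpers::string_to_stream: a completed string becomes a
// stream that yields exactly once.
async fn string_to_stream(s: String) -> BoxStream<'static, String> {
    stream::once(async move { s }).boxed()
}

async fn generate(prompt: &str) -> String {
    let prompt = prompt.to_owned();
    // Run the blocking call off the async runtime, as the patch does with
    // engine.inference() inside tokio::task::spawn_blocking.
    let output_ids = tokio::task::spawn_blocking(move || fake_blocking_inference(prompt))
        .await
        .expect("Inference failed");
    // The real code decodes token ids with the tokenizer; here we just report them.
    format!("{} tokens", output_ids.len())
}

async fn generate_stream(prompt: &str) -> BoxStream<'static, String> {
    // After the patch, streaming is just the synchronous result wrapped once.
    string_to_stream(generate(prompt).await).await
}

#[tokio::main]
async fn main() {
    let mut s = generate_stream("hello").await;
    while let Some(chunk) = s.next().await {
        println!("{chunk}");
    }
}
```

The trade-off is visible in the sketch: callers of `generate_stream` no longer see tokens as they are produced; the stream yields once, after the whole completion has been generated.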
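The new `generate` keeps the `CancellationToken`/`drop_guard` pair even though the channel is gone: if the future returned by `generate` is dropped (for example the client disconnects), `_guard` is dropped, the token is cancelled, and `inference_callback` returns `true` to stop decoding. A rough sketch of that pattern in isolation, with the hypothetical `blocking_decode_loop` standing in for the callback-driven decode loop:

```rust
use std::time::Duration;

use tokio_util::sync::CancellationToken;

// Stand-in for the blocking decode loop that polls the token each step,
// the way inference_callback checks context.cancel.is_cancelled().
fn blocking_decode_loop(cancel: CancellationToken) -> usize {
    let mut steps = 0;
    while !cancel.is_cancelled() && steps < 1_000 {
        std::thread::sleep(Duration::from_millis(10)); // pretend to decode one token
        steps += 1;
    }
    steps
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    let cancel_for_inference = cancel.clone();
    // Cancels the token when dropped, like `_guard` in generate().
    let _guard = cancel.drop_guard();

    let handle = tokio::task::spawn_blocking(move || blocking_decode_loop(cancel_for_inference));

    // Simulate the caller's future being dropped: dropping the guard cancels
    // the token, and the blocking loop exits on its next check.
    drop(_guard);

    let steps = handle.await.expect("decode task panicked");
    println!("decoded {steps} step(s) before cancellation");
}
```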