feat: add back streaming for ctranslate2

Meng Zhang 2023-10-09 18:10:25 -07:00
parent c8d9b4d9ce
commit 64e0f92837
1 changed file with 48 additions and 31 deletions

View File

@ -1,5 +1,6 @@
use std::sync::Arc;
use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
@ -8,6 +9,7 @@ use tabby_inference::{
helpers, TextGeneration, TextGenerationOptions,
};
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc::{channel, Sender};
use tokio_util::sync::CancellationToken;
#[cxx::bridge(namespace = "tabby")]
@ -70,13 +72,22 @@ pub struct CTranslate2EngineOptions {
}
/// Per-request state that is boxed and passed to `engine.inference` together
/// with `inference_callback`, which reads it back for every generated token.
pub struct InferenceContext {
/// Sending half of the channel; the callback pushes each newly decoded text
/// fragment through it with `blocking_send`.
sender: Sender<String>,
/// Incremental decoder; the callback feeds it token ids through `next_token`.
decoding: IncrementalDecoding,
/// Cancellation token the callback polls with `is_cancelled`.
cancel: CancellationToken,
}
impl InferenceContext {
// NOTE(review): diff view — the next two lines are the REMOVED two-argument
// constructor (hunk is -70,13 +72,22); the multi-line `fn new(` after them is
// the replacement introduced by this commit.
fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self {
InferenceContext { decoding, cancel }
/// Builds a context from the channel sender, the incremental decoder and the
/// cancellation token. The values are stored as-is; nothing else happens here.
fn new(
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
) -> Self {
InferenceContext {
sender,
decoding,
cancel,
}
}
}
@ -108,33 +119,8 @@ impl CTranslate2Engine {
#[async_trait]
impl TextGeneration for CTranslate2Engine {
/// Produces the whole completion for `prompt` as a single `String`.
///
/// After this commit the body consists only of the final two statements:
/// obtain the stream from `generate_stream` and turn it into the returned
/// `String` with `helpers::stream_to_string`.
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
// NOTE(review): diff view — from here down to
// `self.tokenizer.decode(&output_ids, true).unwrap()` is the REMOVED
// implementation (hunk is -108,33 +119,8), shown without `-` markers. It ran
// `engine.inference` inside `tokio::task::spawn_blocking` and decoded the
// returned token ids in one go.
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let engine = self.engine.clone();
let cancel = CancellationToken::new();
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words,
);
let context = InferenceContext::new(decoding, cancel_for_inference);
let output_ids = tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference(
context,
inference_callback,
truncate_tokens(encoding.get_tokens(), options.max_input_length),
options.max_decoding_length,
options.sampling_temperature,
)
})
.await
.expect("Inference failed");
self.tokenizer.decode(&output_ids, true).unwrap()
// ADDED lines: delegate to the streaming path and convert its output into
// the `String` this method returns.
let s = self.generate_stream(prompt, options).await;
helpers::stream_to_string(s).await
}
/// Returns a stream that yields the text fragments received from the inference
/// callback over a bounded channel.
///
/// Panics if tokenizing `prompt` fails (`encode(..).unwrap()`).
// NOTE(review): diff view — the `@ ... @` line below is a hunk header; the
// parameter line(s) between the signature start and `prompt` are not shown.
async fn generate_stream(
@ -142,7 +128,35 @@ impl TextGeneration for CTranslate2Engine {
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
// NOTE(review): diff view — the next line is the REMOVED body, which awaited
// the full `generate` result and passed it to `helpers::string_to_stream`.
helpers::string_to_stream(self.generate(prompt, options).await).await
let encoding = self.tokenizer.encode(prompt, true).unwrap();
// The decoder is created from the tokenizer, the prompt ids truncated to
// `max_input_length`, and the configured stop words.
let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words,
);
let cancel = CancellationToken::new();
let engine = self.engine.clone();
// Bounded channel (capacity 8): the sender goes into the inference context,
// the receiver is drained by the stream built at the end of this function.
let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel.clone());
// NOTE(review): `engine.inference` is called directly inside an async task
// created with `tokio::task::spawn`, whereas the removed `generate` body used
// `spawn_blocking`. If the call blocks until generation finishes it will tie
// up a runtime worker thread — confirm, and consider `spawn_blocking`.
// The task handle and the value returned by `engine.inference` are discarded.
tokio::task::spawn(async move {
let context = Box::new(context);
engine.inference(
context,
inference_callback,
truncate_tokens(encoding.get_tokens(), options.max_input_length),
options.max_decoding_length,
options.sampling_temperature,
);
});
let s = stream! {
// Drop guard on the cancellation token: dropping it cancels the token, which
// `inference_callback` observes through `is_cancelled`.
// NOTE(review): the guard is created inside the stream body, so presumably it
// only exists once the stream has been polled — confirm that dropping a
// never-polled stream still stops inference.
let _guard = cancel.drop_guard();
// Forward fragments until `recv` returns `None`.
while let Some(text) = receiver.recv().await {
yield text;
}
};
Box::pin(s)
}
}
@ -163,7 +177,10 @@ fn inference_callback(
) -> bool {
// NOTE(review): diff view — the parameter list is hidden by the hunk boundary;
// the body uses `context` and `token_id`. The return value presumably tells
// the engine whether to stop generating (`true`) — confirm against the bridge
// declaration.
if context.cancel.is_cancelled() {
// Cancellation requested: report `true`.
true
} else if let Some(new_text) = context.decoding.next_token(token_id) {
// ADDED branch: a new text fragment was decoded, so push it to the channel
// and report `false`.
// NOTE(review): the result of `blocking_send` is discarded, so a closed
// receiver does not by itself end generation here — consider returning `true`
// when the send fails.
let _ = context.sender.blocking_send(new_text);
false
} else {
// NOTE(review): diff view — the next line is the REMOVED expression (hunk is
// -163,7 +177,10); the literal `true` after it is its replacement, returned
// when `next_token` yields `None`.
!context.decoding.next_token(token_id).is_some()
true
}
}