feat: add back streaming for ctranslate2
parent
c8d9b4d9ce
commit
64e0f92837
|
|
@ -1,5 +1,6 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
|
|
@ -8,6 +9,7 @@ use tabby_inference::{
|
||||||
helpers, TextGeneration, TextGenerationOptions,
|
helpers, TextGeneration, TextGenerationOptions,
|
||||||
};
|
};
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokio::sync::mpsc::{channel, Sender};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "tabby")]
|
#[cxx::bridge(namespace = "tabby")]
|
||||||
|
|
@ -70,13 +72,22 @@ pub struct CTranslate2EngineOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
|
sender: Sender<String>,
|
||||||
decoding: IncrementalDecoding,
|
decoding: IncrementalDecoding,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferenceContext {
|
impl InferenceContext {
|
||||||
fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self {
|
fn new(
|
||||||
InferenceContext { decoding, cancel }
|
sender: Sender<String>,
|
||||||
|
decoding: IncrementalDecoding,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> Self {
|
||||||
|
InferenceContext {
|
||||||
|
sender,
|
||||||
|
decoding,
|
||||||
|
cancel,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -108,33 +119,8 @@ impl CTranslate2Engine {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for CTranslate2Engine {
|
impl TextGeneration for CTranslate2Engine {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
let s = self.generate_stream(prompt, options).await;
|
||||||
let engine = self.engine.clone();
|
helpers::stream_to_string(s).await
|
||||||
|
|
||||||
let cancel = CancellationToken::new();
|
|
||||||
let cancel_for_inference = cancel.clone();
|
|
||||||
let _guard = cancel.drop_guard();
|
|
||||||
|
|
||||||
let decoding = self.decoding_factory.create_incremental_decoding(
|
|
||||||
self.tokenizer.clone(),
|
|
||||||
truncate_tokens(encoding.get_ids(), options.max_input_length),
|
|
||||||
options.stop_words,
|
|
||||||
);
|
|
||||||
|
|
||||||
let context = InferenceContext::new(decoding, cancel_for_inference);
|
|
||||||
let output_ids = tokio::task::spawn_blocking(move || {
|
|
||||||
let context = Box::new(context);
|
|
||||||
engine.inference(
|
|
||||||
context,
|
|
||||||
inference_callback,
|
|
||||||
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
|
||||||
options.max_decoding_length,
|
|
||||||
options.sampling_temperature,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.expect("Inference failed");
|
|
||||||
self.tokenizer.decode(&output_ids, true).unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
|
|
@ -142,7 +128,35 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
options: TextGenerationOptions,
|
options: TextGenerationOptions,
|
||||||
) -> BoxStream<String> {
|
) -> BoxStream<String> {
|
||||||
helpers::string_to_stream(self.generate(prompt, options).await).await
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
|
let decoding = self.decoding_factory.create_incremental_decoding(
|
||||||
|
self.tokenizer.clone(),
|
||||||
|
truncate_tokens(encoding.get_ids(), options.max_input_length),
|
||||||
|
options.stop_words,
|
||||||
|
);
|
||||||
|
|
||||||
|
let cancel = CancellationToken::new();
|
||||||
|
let engine = self.engine.clone();
|
||||||
|
let (sender, mut receiver) = channel::<String>(8);
|
||||||
|
let context = InferenceContext::new(sender, decoding, cancel.clone());
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let context = Box::new(context);
|
||||||
|
engine.inference(
|
||||||
|
context,
|
||||||
|
inference_callback,
|
||||||
|
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
||||||
|
options.max_decoding_length,
|
||||||
|
options.sampling_temperature,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
let s = stream! {
|
||||||
|
let _guard = cancel.drop_guard();
|
||||||
|
while let Some(text) = receiver.recv().await {
|
||||||
|
yield text;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Box::pin(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -163,7 +177,10 @@ fn inference_callback(
|
||||||
) -> bool {
|
) -> bool {
|
||||||
if context.cancel.is_cancelled() {
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
|
} else if let Some(new_text) = context.decoding.next_token(token_id) {
|
||||||
|
let _ = context.sender.blocking_send(new_text);
|
||||||
|
false
|
||||||
} else {
|
} else {
|
||||||
!context.decoding.next_token(token_id).is_some()
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue