fix: switch ctranslate2 to synchronous implementation

Meng Zhang 2023-10-09 16:30:46 -07:00
parent 2d5b3e4ff5
commit 3c7af24047
1 changed file with 29 additions and 46 deletions


@@ -1,15 +1,10 @@
 use std::sync::Arc;
 
-use async_stream::stream;
 use async_trait::async_trait;
 use derive_builder::Builder;
 use futures::stream::BoxStream;
-use tabby_inference::{
-    decoding::{DecodingFactory, IncrementalDecoding},
-    helpers, TextGeneration, TextGenerationOptions,
-};
+use tabby_inference::{helpers, TextGeneration, TextGenerationOptions, decoding::{DecodingFactory, IncrementalDecoding}};
 use tokenizers::tokenizer::Tokenizer;
-use tokio::sync::mpsc::{channel, Sender};
 use tokio_util::sync::CancellationToken;
 
 #[cxx::bridge(namespace = "tabby")]
@@ -72,19 +67,13 @@ pub struct CTranslate2EngineOptions {
 }
 
 pub struct InferenceContext {
-    sender: Sender<String>,
     decoding: IncrementalDecoding,
     cancel: CancellationToken,
 }
 
 impl InferenceContext {
-    fn new(
-        sender: Sender<String>,
-        decoding: IncrementalDecoding,
-        cancel: CancellationToken,
-    ) -> Self {
+    fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self {
         InferenceContext {
-            sender,
             decoding,
             cancel,
         }
@@ -119,8 +108,31 @@ impl CTranslate2Engine {
 #[async_trait]
 impl TextGeneration for CTranslate2Engine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let s = self.generate_stream(prompt, options).await;
-        helpers::stream_to_string(s).await
+        let encoding = self.tokenizer.encode(prompt, true).unwrap();
+        let engine = self.engine.clone();
+
+        let cancel = CancellationToken::new();
+        let cancel_for_inference = cancel.clone();
+        let _guard = cancel.drop_guard();
+
+        let decoding = self
+            .decoding_factory
+            .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
+        let context = InferenceContext::new(decoding, cancel_for_inference);
+        let output_ids = tokio::task::spawn_blocking(move || {
+            let context = Box::new(context);
+            engine.inference(
+                context,
+                inference_callback,
+                truncate_tokens(encoding.get_tokens(), options.max_input_length),
+                options.max_decoding_length,
+                options.sampling_temperature,
+            )
+        })
+        .await
+        .expect("Inference failed");
+
+        self.tokenizer.decode(&output_ids, true).unwrap()
     }
 
     async fn generate_stream(
@@ -128,35 +140,7 @@ impl TextGeneration for CTranslate2Engine {
         prompt: &str,
         options: TextGenerationOptions,
     ) -> BoxStream<String> {
-        let encoding = self.tokenizer.encode(prompt, true).unwrap();
-        let engine = self.engine.clone();
-
-        let s = stream! {
-            let cancel = CancellationToken::new();
-            let cancel_for_inference = cancel.clone();
-            let _guard = cancel.drop_guard();
-
-            let decoding = self
-                .decoding_factory
-                .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
-            let (sender, mut receiver) = channel::<String>(8);
-            let context = InferenceContext::new(sender, decoding, cancel_for_inference);
-            tokio::task::spawn(async move {
-                let context = Box::new(context);
-                engine.inference(
-                    context,
-                    inference_callback,
-                    truncate_tokens(encoding.get_tokens(), options.max_input_length),
-                    options.max_decoding_length,
-                    options.sampling_temperature,
-                );
-            });
-
-            while let Some(text) = receiver.recv().await {
-                yield text;
-            }
-        };
-        Box::pin(s)
+        helpers::string_to_stream(self.generate(prompt, options).await).await
     }
 }
@@ -177,8 +161,7 @@ fn inference_callback(
 ) -> bool {
     if context.cancel.is_cancelled() {
         true
-    } else if let Some(new_text) = context.decoding.next_token(token_id) {
-        let _ = context.sender.blocking_send(new_text);
+    } else if let Some(_) = context.decoding.next_token(token_id) {
         false
     } else {
         true