fix: switch ctranslate2 to synchronous implementation
parent
2d5b3e4ff5
commit
3c7af24047
|
|
@ -1,15 +1,10 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
use futures::stream::BoxStream;
|
||||
use tabby_inference::{
|
||||
decoding::{DecodingFactory, IncrementalDecoding},
|
||||
helpers, TextGeneration, TextGenerationOptions,
|
||||
};
|
||||
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions, decoding::{DecodingFactory, IncrementalDecoding}};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc::{channel, Sender};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[cxx::bridge(namespace = "tabby")]
|
||||
|
|
@ -72,19 +67,13 @@ pub struct CTranslate2EngineOptions {
|
|||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
sender: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(
|
||||
sender: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
cancel: CancellationToken,
|
||||
) -> Self {
|
||||
fn new(decoding: IncrementalDecoding, cancel: CancellationToken) -> Self {
|
||||
InferenceContext {
|
||||
sender,
|
||||
decoding,
|
||||
cancel,
|
||||
}
|
||||
|
|
@ -119,8 +108,31 @@ impl CTranslate2Engine {
|
|||
#[async_trait]
|
||||
impl TextGeneration for CTranslate2Engine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let s = self.generate_stream(prompt, options).await;
|
||||
helpers::stream_to_string(s).await
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let engine = self.engine.clone();
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let decoding = self
|
||||
.decoding_factory
|
||||
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
||||
|
||||
let context = InferenceContext::new(decoding, cancel_for_inference);
|
||||
let output_ids = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
inference_callback,
|
||||
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.expect("Inference failed");
|
||||
self.tokenizer.decode(&output_ids, true).unwrap()
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
|
|
@ -128,35 +140,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
prompt: &str,
|
||||
options: TextGenerationOptions,
|
||||
) -> BoxStream<String> {
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let engine = self.engine.clone();
|
||||
let s = stream! {
|
||||
let cancel = CancellationToken::new();
|
||||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let decoding = self
|
||||
.decoding_factory
|
||||
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
||||
|
||||
let (sender, mut receiver) = channel::<String>(8);
|
||||
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
||||
tokio::task::spawn(async move {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
inference_callback,
|
||||
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
);
|
||||
});
|
||||
|
||||
while let Some(text) = receiver.recv().await {
|
||||
yield text;
|
||||
}
|
||||
};
|
||||
Box::pin(s)
|
||||
helpers::string_to_stream(self.generate(prompt, options).await).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -177,8 +161,7 @@ fn inference_callback(
|
|||
) -> bool {
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(new_text) = context.decoding.next_token(token_id) {
|
||||
let _ = context.sender.blocking_send(new_text);
|
||||
} else if let Some(_) = context.decoding.next_token(token_id) {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
|
|
|
|||
Loading…
Reference in New Issue