feat: support cancellation in llama backend

parent 3573d4378e
commit 3f7aa99b0d
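Editor's note: this commit replaces the inline decode loop in `TextGeneration::generate` (previously marked `// FIXME(meng): supports cancellation.`) with a `tokio::task::spawn_blocking` task, and threads a `tokio_util::sync::CancellationToken` through it so that dropping the `generate` future cancels in-flight decoding. To make that possible, the engine handle becomes `Arc<Mutex<...>>` so it can be moved into the blocking task.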
@@ -1565,6 +1565,7 @@ dependencies = [
  "tabby-inference",
  "tokenizers",
  "tokio",
+ "tokio-util",
 ]

 [[package]]
@@ -15,3 +15,4 @@ tabby-inference = { path = "../tabby-inference" }
 derive_builder = { workspace = true }
 tokenizers = { workspace = true }
 stop-words = { version = "0.1.0", path = "../stop-words" }
+tokio-util = { workspace = true }
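The new dependency, `tokio-util`, supplies the `CancellationToken` used below. A minimal sketch of the parts of its API this commit relies on (the names are from the real `tokio_util::sync` module; the snippet itself is illustrative):

```rust
use tokio_util::sync::CancellationToken;

fn token_basics() {
    let token = CancellationToken::new();
    // Clones observe the same shared cancellation state.
    let observer = token.clone();
    // `drop_guard()` consumes the token, returning a guard that
    // cancels the shared state when it is dropped.
    let guard = token.drop_guard();
    assert!(!observer.is_cancelled());
    drop(guard); // guard dropped => token cancelled
    assert!(observer.is_cancelled());
}
```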
@@ -6,6 +6,7 @@ use ffi::create_engine;
 use stop_words::StopWords;
 use tabby_inference::{TextGeneration, TextGenerationOptions};
 use tokenizers::tokenizer::Tokenizer;
+use tokio_util::sync::CancellationToken;

 #[cxx::bridge(namespace = "llama")]
 mod ffi {
@@ -31,7 +32,7 @@ pub struct LlamaEngineOptions {
 }

 pub struct LlamaEngine {
-    engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
+    engine: Arc<Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>>,
     tokenizer: Arc<Tokenizer>,
     stop_words: StopWords,
 }
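The field changes from `Mutex<...>` to `Arc<Mutex<...>>` because the decode loop below moves into `tokio::task::spawn_blocking`, whose closure must be `Send + 'static`; a clone of the `Arc` gives the closure owned access where a borrow of `&self` would not. A hedged illustration of the constraint, with `SharedEngine` as a hypothetical stand-in for the `cxx::SharedPtr<ffi::TextInferenceEngine>` field:

```rust
use std::sync::{Arc, Mutex};

struct SharedEngine; // hypothetical stand-in for the FFI engine handle

async fn run(engine: &Arc<Mutex<SharedEngine>>) {
    // Cloning the Arc yields an owned, 'static handle that the
    // blocking closure can capture by move; a plain &Mutex<_>
    // borrowed from &self would not satisfy the 'static bound.
    let engine = engine.clone();
    tokio::task::spawn_blocking(move || {
        let _engine = engine.lock().unwrap();
        // ... run the blocking inference loop here ...
    })
    .await
    .expect("blocking task panicked");
}
```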
@@ -39,7 +40,7 @@ pub struct LlamaEngine {
 impl LlamaEngine {
     pub fn create(options: LlamaEngineOptions) -> Self {
         LlamaEngine {
-            engine: Mutex::new(create_engine(&options.model_path)),
+            engine: Arc::new(Mutex::new(create_engine(&options.model_path))),
             tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
             stop_words: StopWords::default(),
         }
@@ -49,24 +50,40 @@ impl LlamaEngine {
 #[async_trait]
 impl TextGeneration for LlamaEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let engine = self.engine.lock().unwrap();
-        let mut next_token_id = engine.start(prompt);
-        let mut n_remains = options.max_decoding_length - 1;
-        let mut output_ids = vec![next_token_id];
+        let cancel = CancellationToken::new();
+        let cancel_for_inference = cancel.clone();
+        let _guard = cancel.drop_guard();
+
+        let prompt = prompt.to_owned();
+        let engine = self.engine.clone();
         let mut stop_condition = self
             .stop_words
             .create_condition(self.tokenizer.clone(), options.stop_words);

-        // FIXME(meng): supports cancellation.
-        while n_remains > 0 {
-            next_token_id = engine.step(next_token_id);
-            if stop_condition.next_token(next_token_id) {
-                break;
-            }
-            output_ids.push(next_token_id);
-            n_remains -= 1;
-        }
+        let output_ids = tokio::task::spawn_blocking(move || {
+            let engine = engine.lock().unwrap();
+            let mut next_token_id = engine.start(&prompt);
+            let mut n_remains = options.max_decoding_length - 1;
+            let mut output_ids = vec![next_token_id];
+
+            while n_remains > 0 {
+                if cancel_for_inference.is_cancelled() {
+                    // The token was cancelled
+                    break;
+                }
+
+                next_token_id = engine.step(next_token_id);
+                if stop_condition.next_token(next_token_id) {
+                    break;
+                }
+                output_ids.push(next_token_id);
+                n_remains -= 1;
+            }
+
+            output_ids
+        })
+        .await
+        .expect("Inference failed");

         self.tokenizer.decode(&output_ids, true).unwrap()
     }
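Taken together, the pattern ties cancellation to the lifetime of the `generate` future: `_guard` cancels the token when the future is dropped (for example, when the client disconnects), and the blocking decode loop checks `is_cancelled()` once per token. A self-contained sketch of the same mechanism, with a sleep standing in for `engine.step()` and the names here (`generate_with_cancellation`, the token range) purely illustrative:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

async fn generate_with_cancellation() -> Vec<u32> {
    let cancel = CancellationToken::new();
    let cancel_for_inference = cancel.clone();
    // If this future is dropped mid-flight, the guard cancels the
    // token, and the blocking loop below exits at its next check.
    let _guard = cancel.drop_guard();

    tokio::task::spawn_blocking(move || {
        let mut output_ids = Vec::new();
        for token_id in 0u32..1024 {
            if cancel_for_inference.is_cancelled() {
                break; // caller went away; stop decoding early
            }
            // Stand-in for one engine.step() decoding iteration.
            std::thread::sleep(Duration::from_millis(10));
            output_ids.push(token_id);
        }
        output_ids
    })
    .await
    .expect("inference thread panicked")
}
```

One caveat worth noting: `spawn_blocking` threads cannot be forcibly aborted, so cancellation stays cooperative; the loop only stops at its next `is_cancelled()` check, which is why the check sits at the top of every iteration.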