2023-09-03 01:59:07 +00:00
|
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
|
|
|
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use derive_builder::Builder;
|
|
|
|
|
use ffi::create_engine;
|
|
|
|
|
use stop_words::StopWords;
|
|
|
|
|
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
|
|
|
|
use tokenizers::tokenizer::Tokenizer;
|
2023-09-03 02:15:54 +00:00
|
|
|
use tokio_util::sync::CancellationToken;
|
2023-09-03 01:59:07 +00:00
|
|
|
|
|
|
|
|
#[cxx::bridge(namespace = "llama")]
|
|
|
|
|
mod ffi {
|
|
|
|
|
unsafe extern "C++" {
|
|
|
|
|
include!("llama-cpp-bindings/include/engine.h");
|
|
|
|
|
|
|
|
|
|
type TextInferenceEngine;
|
|
|
|
|
|
|
|
|
|
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
|
|
|
|
|
|
|
|
|
fn start(&self, prompt: &str) -> u32;
|
|
|
|
|
fn step(&self, next_token_id: u32) -> u32;
|
2023-09-05 02:14:29 +00:00
|
|
|
fn end(&self);
|
|
|
|
|
|
|
|
|
|
fn eos_token(&self) -> u32;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SAFETY: asserted, not checked by the compiler. Within this file the engine is
// only reached through the `Mutex` inside `LlamaEngine`, so calls into it are
// serialized. NOTE(review): this assumes the C++ `TextInferenceEngine` has no
// thread affinity (it may be driven from any `spawn_blocking` worker) — confirm
// against engine.h / engine.cc.
unsafe impl Send for ffi::TextInferenceEngine {}
|
|
|
|
|
// SAFETY: asserted, not checked by the compiler. Presumably required so that
// `SharedPtr<TextInferenceEngine>` can be moved across threads (cxx's `SharedPtr<T>`
// is `Send` only when `T: Send + Sync` — verify against the cxx version in use).
// Concurrent `&self` calls into the C++ engine are NOT known to be safe; in this
// file every call happens while holding `LlamaEngine`'s `Mutex`.
// NOTE(review): do not call the engine outside that mutex without auditing the C++ side.
unsafe impl Sync for ffi::TextInferenceEngine {}
|
|
|
|
|
|
|
|
|
|
#[derive(Builder, Debug)]
|
|
|
|
|
pub struct LlamaEngineOptions {
|
|
|
|
|
model_path: String,
|
|
|
|
|
tokenizer_path: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct LlamaEngine {
|
2023-09-03 02:15:54 +00:00
|
|
|
engine: Arc<Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>>,
|
2023-09-03 01:59:07 +00:00
|
|
|
tokenizer: Arc<Tokenizer>,
|
|
|
|
|
stop_words: StopWords,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl LlamaEngine {
|
|
|
|
|
pub fn create(options: LlamaEngineOptions) -> Self {
|
|
|
|
|
LlamaEngine {
|
2023-09-03 02:15:54 +00:00
|
|
|
engine: Arc::new(Mutex::new(create_engine(&options.model_path))),
|
2023-09-03 01:59:07 +00:00
|
|
|
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
|
|
|
|
stop_words: StopWords::default(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
impl TextGeneration for LlamaEngine {
|
|
|
|
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
2023-09-03 02:15:54 +00:00
|
|
|
let cancel = CancellationToken::new();
|
|
|
|
|
let cancel_for_inference = cancel.clone();
|
|
|
|
|
let _guard = cancel.drop_guard();
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-03 02:15:54 +00:00
|
|
|
let prompt = prompt.to_owned();
|
|
|
|
|
let engine = self.engine.clone();
|
2023-09-03 01:59:07 +00:00
|
|
|
let mut stop_condition = self
|
|
|
|
|
.stop_words
|
|
|
|
|
.create_condition(self.tokenizer.clone(), options.stop_words);
|
|
|
|
|
|
2023-09-03 02:15:54 +00:00
|
|
|
let output_ids = tokio::task::spawn_blocking(move || {
|
|
|
|
|
let engine = engine.lock().unwrap();
|
2023-09-05 02:14:29 +00:00
|
|
|
let eos_token = engine.eos_token();
|
|
|
|
|
|
2023-09-03 02:15:54 +00:00
|
|
|
let mut next_token_id = engine.start(&prompt);
|
2023-09-05 02:14:29 +00:00
|
|
|
if next_token_id == eos_token {
|
|
|
|
|
return Vec::new();
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-03 02:15:54 +00:00
|
|
|
let mut n_remains = options.max_decoding_length - 1;
|
|
|
|
|
let mut output_ids = vec![next_token_id];
|
|
|
|
|
|
|
|
|
|
while n_remains > 0 {
|
|
|
|
|
if cancel_for_inference.is_cancelled() {
|
|
|
|
|
// The token was cancelled
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
next_token_id = engine.step(next_token_id);
|
2023-09-05 02:14:29 +00:00
|
|
|
if next_token_id == eos_token {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-03 02:15:54 +00:00
|
|
|
if stop_condition.next_token(next_token_id) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
output_ids.push(next_token_id);
|
|
|
|
|
n_remains -= 1;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
2023-09-03 02:15:54 +00:00
|
|
|
|
2023-09-05 02:14:29 +00:00
|
|
|
engine.end();
|
2023-09-03 02:15:54 +00:00
|
|
|
output_ids
|
|
|
|
|
})
|
|
|
|
|
.await
|
|
|
|
|
.expect("Inference failed");
|
2023-09-03 01:59:07 +00:00
|
|
|
self.tokenizer.decode(&output_ids, true).unwrap()
|
|
|
|
|
}
|
|
|
|
|
}
|