2023-09-28 17:20:50 +00:00
|
|
|
use std::sync::Arc;
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-28 17:20:50 +00:00
|
|
|
use async_stream::stream;
|
2023-09-03 01:59:07 +00:00
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use derive_builder::Builder;
|
|
|
|
|
use ffi::create_engine;
|
2023-09-28 17:20:50 +00:00
|
|
|
use futures::{lock::Mutex, stream::BoxStream};
|
2023-09-29 13:06:47 +00:00
|
|
|
use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
|
2023-09-03 01:59:07 +00:00
|
|
|
use tokenizers::tokenizer::Tokenizer;
|
|
|
|
|
|
|
|
|
|
#[cxx::bridge(namespace = "llama")]
|
|
|
|
|
mod ffi {
|
|
|
|
|
unsafe extern "C++" {
|
|
|
|
|
include!("llama-cpp-bindings/include/engine.h");
|
|
|
|
|
|
|
|
|
|
type TextInferenceEngine;
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
|
|
|
|
|
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
|
|
|
|
|
fn end(self: Pin<&mut TextInferenceEngine>);
|
2023-09-05 02:14:29 +00:00
|
|
|
|
|
|
|
|
fn eos_token(&self) -> u32;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SAFETY: in this file the engine is only ever reached through the `Mutex`
// stored in `LlamaEngine::engine`, so at most one thread drives it at a time
// even though the owning future may migrate between executor threads.
// NOTE(review): this also assumes the C++ engine keeps no thread-local or
// thread-affine state (e.g. per-thread contexts) — confirm against engine.h/.cc.
unsafe impl Send for ffi::TextInferenceEngine {}
|
|
|
|
|
// SAFETY: in this file every access to the engine, including the `&self`
// method `eos_token`, happens while holding the `Mutex` in
// `LlamaEngine::engine`, so shared references are never used concurrently here.
// NOTE(review): `Mutex<T>` is already `Sync` when `T: Send`, so this impl looks
// unnecessary for `LlamaEngine` itself; it only matters if a
// `&TextInferenceEngine` is shared outside the mutex elsewhere — confirm before
// relying on it.
unsafe impl Sync for ffi::TextInferenceEngine {}
|
|
|
|
|
|
|
|
|
|
#[derive(Builder, Debug)]
|
|
|
|
|
pub struct LlamaEngineOptions {
|
|
|
|
|
model_path: String,
|
|
|
|
|
tokenizer_path: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct LlamaEngine {
|
2023-09-30 15:37:36 +00:00
|
|
|
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
2023-09-03 01:59:07 +00:00
|
|
|
tokenizer: Arc<Tokenizer>,
|
2023-09-29 13:06:47 +00:00
|
|
|
decoding_factory: DecodingFactory,
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl LlamaEngine {
|
|
|
|
|
pub fn create(options: LlamaEngineOptions) -> Self {
|
2023-10-02 05:25:25 +00:00
|
|
|
let engine = create_engine(&options.model_path);
|
|
|
|
|
if engine.is_null() {
|
|
|
|
|
panic!("Unable to load model: {}", options.model_path);
|
|
|
|
|
}
|
2023-09-03 01:59:07 +00:00
|
|
|
LlamaEngine {
|
2023-10-02 05:25:25 +00:00
|
|
|
engine: Mutex::new(engine),
|
2023-09-03 01:59:07 +00:00
|
|
|
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
2023-09-29 13:06:47 +00:00
|
|
|
decoding_factory: DecodingFactory::default(),
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
impl TextGeneration for LlamaEngine {
|
|
|
|
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
2023-09-28 17:20:50 +00:00
|
|
|
let s = self.generate_stream(prompt, options).await;
|
|
|
|
|
helpers::stream_to_string(s).await
|
|
|
|
|
}
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-28 17:20:50 +00:00
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
options: TextGenerationOptions,
|
|
|
|
|
) -> BoxStream<String> {
|
2023-09-29 13:06:47 +00:00
|
|
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-28 17:20:50 +00:00
|
|
|
let s = stream! {
|
2023-09-30 15:37:36 +00:00
|
|
|
let mut engine = self.engine.lock().await;
|
|
|
|
|
let mut engine = engine.as_mut().unwrap();
|
2023-09-05 02:14:29 +00:00
|
|
|
let eos_token = engine.eos_token();
|
|
|
|
|
|
2023-09-29 13:06:47 +00:00
|
|
|
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
2023-09-30 15:37:36 +00:00
|
|
|
engine.as_mut().start(input_token_ids);
|
2023-09-29 18:21:57 +00:00
|
|
|
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
2023-09-29 13:06:47 +00:00
|
|
|
let mut n_remains = options.max_decoding_length ;
|
|
|
|
|
while n_remains > 0 {
|
2023-09-30 15:37:36 +00:00
|
|
|
let next_token_id = engine.as_mut().step();
|
2023-09-29 13:06:47 +00:00
|
|
|
if next_token_id == eos_token {
|
|
|
|
|
break;
|
2023-09-03 02:15:54 +00:00
|
|
|
}
|
2023-09-29 13:06:47 +00:00
|
|
|
|
|
|
|
|
if let Some(new_text) = decoding.next_token(next_token_id) {
|
|
|
|
|
yield new_text;
|
|
|
|
|
} else {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
n_remains -= 1;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
2023-09-03 02:15:54 +00:00
|
|
|
|
2023-09-05 02:14:29 +00:00
|
|
|
engine.end();
|
2023-09-28 17:20:50 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Box::pin(s)
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
2023-09-29 13:06:47 +00:00
|
|
|
|
|
|
|
|
/// Keeps at most the last `max_length` entries of `tokens`.
///
/// Excess tokens are dropped from the front, so the tokens nearest the
/// generated continuation survive. A slice that is already short enough is
/// returned whole; `max_length == 0` yields an empty slice.
fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
    let excess = tokens.len().saturating_sub(max_length);
    &tokens[excess..]
}
|