2023-10-29 06:37:05 +00:00
|
|
|
use std::{collections::HashMap, sync::Arc};
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-28 17:20:50 +00:00
|
|
|
use async_stream::stream;
|
2023-09-03 01:59:07 +00:00
|
|
|
use async_trait::async_trait;
|
2023-10-29 06:37:05 +00:00
|
|
|
use cxx::UniquePtr;
|
2023-09-03 01:59:07 +00:00
|
|
|
use derive_builder::Builder;
|
|
|
|
|
use ffi::create_engine;
|
2023-09-28 17:20:50 +00:00
|
|
|
use futures::{lock::Mutex, stream::BoxStream};
|
2023-10-29 06:37:05 +00:00
|
|
|
use tabby_inference::{
|
2023-10-31 22:16:09 +00:00
|
|
|
decoding::{StopCondition, StopConditionFactory},
|
2023-10-29 06:37:05 +00:00
|
|
|
helpers, TextGeneration, TextGenerationOptions,
|
|
|
|
|
};
|
|
|
|
|
use tokio::{
|
|
|
|
|
sync::mpsc::{channel, Sender},
|
|
|
|
|
task::yield_now,
|
|
|
|
|
};
|
2023-09-03 01:59:07 +00:00
|
|
|
|
|
|
|
|
#[cxx::bridge(namespace = "llama")]
|
|
|
|
|
mod ffi {
|
2023-10-31 22:16:09 +00:00
|
|
|
struct StepOutput {
|
|
|
|
|
request_id: u32,
|
|
|
|
|
text: String,
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-03 01:59:07 +00:00
|
|
|
unsafe extern "C++" {
|
|
|
|
|
include!("llama-cpp-bindings/include/engine.h");
|
|
|
|
|
|
|
|
|
|
type TextInferenceEngine;
|
|
|
|
|
|
2023-10-25 22:40:11 +00:00
|
|
|
fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
fn add_request(
|
|
|
|
|
self: Pin<&mut TextInferenceEngine>,
|
|
|
|
|
request_id: u32,
|
2023-10-31 22:16:09 +00:00
|
|
|
prompt: &str,
|
|
|
|
|
max_input_length: usize,
|
2023-10-29 06:37:05 +00:00
|
|
|
);
|
|
|
|
|
fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
|
2023-10-31 22:16:09 +00:00
|
|
|
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<StepOutput>>;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unsafe impl Send for ffi::TextInferenceEngine {}
|
|
|
|
|
unsafe impl Sync for ffi::TextInferenceEngine {}
|
|
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
struct InferenceRequest {
|
|
|
|
|
tx: Sender<String>,
|
2023-10-31 22:16:09 +00:00
|
|
|
stop_condition: StopCondition,
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
struct AsyncTextInferenceEngine {
|
2023-09-30 15:37:36 +00:00
|
|
|
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
2023-10-31 22:16:09 +00:00
|
|
|
stop_condition_factory: StopConditionFactory,
|
2023-10-29 06:37:05 +00:00
|
|
|
requests: Mutex<HashMap<u32, InferenceRequest>>,
|
|
|
|
|
|
|
|
|
|
next_request_id: Mutex<u32>,
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
impl AsyncTextInferenceEngine {
|
2023-10-31 22:16:09 +00:00
|
|
|
fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
|
2023-10-29 06:37:05 +00:00
|
|
|
Self {
|
2023-10-02 05:25:25 +00:00
|
|
|
engine: Mutex::new(engine),
|
2023-10-31 22:16:09 +00:00
|
|
|
stop_condition_factory: StopConditionFactory::default(),
|
2023-10-29 06:37:05 +00:00
|
|
|
requests: Mutex::new(HashMap::new()),
|
|
|
|
|
next_request_id: Mutex::new(0),
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
async fn background_job(&self) {
|
|
|
|
|
let mut requests = self.requests.lock().await;
|
|
|
|
|
if requests.len() == 0 {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut engine = self.engine.lock().await;
|
|
|
|
|
|
2023-11-02 23:15:06 +00:00
|
|
|
let result = match engine.as_mut().unwrap().step() {
|
|
|
|
|
Ok(result) => result,
|
|
|
|
|
Err(err) => panic!("Failed to step: {}", err),
|
2023-10-29 06:37:05 +00:00
|
|
|
};
|
|
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
for ffi::StepOutput { request_id, text } in result {
|
2023-10-29 06:37:05 +00:00
|
|
|
let mut stopped = false;
|
2023-10-31 22:16:09 +00:00
|
|
|
let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();
|
2023-10-29 06:37:05 +00:00
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
if tx.is_closed() || text.is_empty() {
|
2023-10-29 06:37:05 +00:00
|
|
|
// Cancelled by client side or hit eos.
|
|
|
|
|
stopped = true;
|
2023-10-31 22:16:09 +00:00
|
|
|
} else if !stop_condition.should_stop(&text) {
|
|
|
|
|
match tx.send(text).await {
|
2023-10-30 06:27:09 +00:00
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(_) => stopped = true,
|
|
|
|
|
}
|
2023-10-29 06:37:05 +00:00
|
|
|
} else {
|
|
|
|
|
// Stoop words stopped
|
|
|
|
|
stopped = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if stopped {
|
|
|
|
|
requests.remove(&request_id);
|
|
|
|
|
engine.as_mut().unwrap().stop_request(request_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-09-28 17:20:50 +00:00
|
|
|
}
|
2023-09-03 01:59:07 +00:00
|
|
|
|
2023-09-28 17:20:50 +00:00
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
options: TextGenerationOptions,
|
|
|
|
|
) -> BoxStream<String> {
|
2023-10-31 22:16:09 +00:00
|
|
|
let stop_condition = self.stop_condition_factory.create(prompt, options.language);
|
2023-10-29 06:37:05 +00:00
|
|
|
|
|
|
|
|
let (tx, mut rx) = channel::<String>(4);
|
|
|
|
|
{
|
2023-09-30 15:37:36 +00:00
|
|
|
let mut engine = self.engine.lock().await;
|
2023-10-29 06:37:05 +00:00
|
|
|
|
|
|
|
|
let mut request_id = self.next_request_id.lock().await;
|
|
|
|
|
self.requests
|
|
|
|
|
.lock()
|
|
|
|
|
.await
|
2023-10-31 22:16:09 +00:00
|
|
|
.insert(*request_id, InferenceRequest { tx, stop_condition });
|
|
|
|
|
engine
|
|
|
|
|
.as_mut()
|
|
|
|
|
.unwrap()
|
|
|
|
|
.add_request(*request_id, prompt, options.max_input_length);
|
2023-10-29 06:37:05 +00:00
|
|
|
|
|
|
|
|
// 2048 should be large enough to avoid collision.
|
|
|
|
|
*request_id = (*request_id + 1) % 2048;
|
|
|
|
|
}
|
2023-09-29 13:06:47 +00:00
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
let s = stream! {
|
|
|
|
|
let mut length = 0;
|
|
|
|
|
while let Some(new_text) = rx.recv().await {
|
|
|
|
|
yield new_text;
|
|
|
|
|
length += 1;
|
|
|
|
|
if length >= options.max_decoding_length {
|
2023-09-29 13:06:47 +00:00
|
|
|
break;
|
|
|
|
|
}
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
2023-09-03 02:15:54 +00:00
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
rx.close();
|
2023-09-28 17:20:50 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Box::pin(s)
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
}
|
2023-09-29 13:06:47 +00:00
|
|
|
|
2023-10-29 06:37:05 +00:00
|
|
|
#[derive(Builder, Debug)]
|
|
|
|
|
pub struct LlamaTextGenerationOptions {
|
|
|
|
|
model_path: String,
|
|
|
|
|
use_gpu: bool,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct LlamaTextGeneration {
|
|
|
|
|
engine: Arc<AsyncTextInferenceEngine>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl LlamaTextGeneration {
|
|
|
|
|
pub fn create(options: LlamaTextGenerationOptions) -> Self {
|
|
|
|
|
let engine = create_engine(options.use_gpu, &options.model_path);
|
|
|
|
|
if engine.is_null() {
|
|
|
|
|
panic!("Unable to load model: {}", options.model_path);
|
|
|
|
|
}
|
|
|
|
|
let ret = LlamaTextGeneration {
|
2023-10-31 22:16:09 +00:00
|
|
|
engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
|
2023-10-29 06:37:05 +00:00
|
|
|
};
|
|
|
|
|
ret.start_background_job();
|
|
|
|
|
ret
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn start_background_job(&self) {
|
|
|
|
|
let engine = self.engine.clone();
|
|
|
|
|
tokio::spawn(async move {
|
|
|
|
|
loop {
|
|
|
|
|
engine.background_job().await;
|
|
|
|
|
yield_now().await;
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
|
impl TextGeneration for LlamaTextGeneration {
|
|
|
|
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
|
|
|
|
let s = self.generate_stream(prompt, options).await;
|
|
|
|
|
helpers::stream_to_string(s).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
options: TextGenerationOptions,
|
|
|
|
|
) -> BoxStream<String> {
|
|
|
|
|
self.engine.generate_stream(prompt, options).await
|
|
|
|
|
}
|
|
|
|
|
}
|