fix: deadlock between background job and requests (#720)
* fix: deadlock between background job and requests
* refactor: extract LlamaService

Branch: refactor-extract-code
Parent: b001816671
Commit: 1ad0d39903
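The root cause is visible in the old code below: background_job locked requests and then engine, while generate_stream locked engine and then requests. Two async tasks that acquire the same pair of mutexes in opposite orders can each end up holding one lock while awaiting the other, so neither makes progress. A minimal, self-contained sketch of that ABBA inversion (demo code with assumed names, not part of this commit):

    use std::{sync::Arc, time::Duration};
    use tokio::sync::Mutex;
    use tokio::time::{sleep, timeout};

    #[tokio::main]
    async fn main() {
        let requests = Arc::new(Mutex::new(()));
        let engine = Arc::new(Mutex::new(()));

        // Mirrors the old background_job: requests, then engine.
        let (r, e) = (requests.clone(), engine.clone());
        tokio::spawn(async move {
            loop {
                let _requests = r.lock().await;
                sleep(Duration::from_millis(1)).await; // hold across an await point
                let _engine = e.lock().await;
            }
        });

        // Mirrors the old generate_stream: engine, then requests.
        let task = tokio::spawn(async move {
            loop {
                let _engine = engine.lock().await;
                sleep(Duration::from_millis(1)).await;
                let _requests = requests.lock().await;
            }
        });

        // Soon each task holds one mutex while awaiting the other.
        if timeout(Duration::from_secs(1), task).await.is_err() {
            println!("deadlocked as expected");
        }
    }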
@@ -1,20 +1,14 @@
+mod llama;
 mod utils;

-use std::{collections::HashMap, sync::Arc};
-
 use async_stream::stream;
 use async_trait::async_trait;
-use cxx::UniquePtr;
 use derive_builder::Builder;
 use ffi::create_engine;
-use futures::{lock::Mutex, stream::BoxStream};
+use futures::stream::BoxStream;
+use llama::LlamaService;
 use tabby_inference::{
-    decoding::{StopCondition, StopConditionFactory},
-    helpers, TextGeneration, TextGenerationOptions,
-};
-use tokio::{
-    sync::mpsc::{channel, Sender},
-    task::yield_now,
+    decoding::StopConditionFactory, helpers, TextGeneration, TextGenerationOptions,
 };

 #[cxx::bridge(namespace = "llama")]
@@ -45,66 +39,36 @@ mod ffi {
 unsafe impl Send for ffi::TextInferenceEngine {}
 unsafe impl Sync for ffi::TextInferenceEngine {}

-struct InferenceRequest {
-    tx: Sender<String>,
-    stop_condition: StopCondition,
+#[derive(Builder, Debug)]
+pub struct LlamaTextGenerationOptions {
+    model_path: String,
+    use_gpu: bool,
 }

-struct AsyncTextInferenceEngine {
-    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
+pub struct LlamaTextGeneration {
+    service: LlamaService,
     stop_condition_factory: StopConditionFactory,
-    requests: Mutex<HashMap<u32, InferenceRequest>>,
-    next_request_id: Mutex<u32>,
 }

-impl AsyncTextInferenceEngine {
-    fn create(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
+impl LlamaTextGeneration {
+    pub fn new(options: LlamaTextGenerationOptions) -> Self {
+        let engine = create_engine(options.use_gpu, &options.model_path);
+        if engine.is_null() {
+            fatal!("Unable to load model: {}", options.model_path);
+        }
+
         Self {
-            engine: Mutex::new(engine),
+            service: LlamaService::new(engine),
             stop_condition_factory: StopConditionFactory::default(),
-            requests: Mutex::new(HashMap::new()),
-            next_request_id: Mutex::new(0),
         }
     }
+}

-    async fn background_job(&self) {
-        let mut requests = self.requests.lock().await;
-        if requests.len() == 0 {
-            return;
-        }
-
-        let mut engine = self.engine.lock().await;
-
-        let result = match engine.as_mut().unwrap().step() {
-            Ok(result) => result,
-            Err(err) => {
-                fatal!("Failed to step: {}", err)
-            }
-        };
-
-        for ffi::StepOutput { request_id, text } in result {
-            let mut stopped = false;
-            let InferenceRequest { tx, stop_condition } = requests.get_mut(&request_id).unwrap();
-
-            if tx.is_closed() || text.is_empty() {
-                // Cancelled by client side or hit eos.
-                stopped = true;
-            } else if !stop_condition.should_stop(&text) {
-                match tx.send(text).await {
-                    Ok(_) => (),
-                    Err(_) => stopped = true,
-                }
-            } else {
-                // Stoop words stopped
-                stopped = true;
-            }
-
-            if stopped {
-                requests.remove(&request_id);
-                engine.as_mut().unwrap().stop_request(request_id);
-            }
-        }
-    }
+#[async_trait]
+impl TextGeneration for LlamaTextGeneration {
+    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
+        let s = self.generate_stream(prompt, options).await;
+        helpers::stream_to_string(s).await
+    }

     async fn generate_stream(
@@ -114,23 +78,10 @@ impl AsyncTextInferenceEngine {
     ) -> BoxStream<String> {
         let stop_condition = self.stop_condition_factory.create(prompt, options.language);

-        let (tx, mut rx) = channel::<String>(4);
-        {
-            let mut engine = self.engine.lock().await;
-
-            let mut request_id = self.next_request_id.lock().await;
-            self.requests
-                .lock()
-                .await
-                .insert(*request_id, InferenceRequest { tx, stop_condition });
-            engine
-                .as_mut()
-                .unwrap()
-                .add_request(*request_id, prompt, options.max_input_length);
-
-            // 2048 should be large enough to avoid collision.
-            *request_id = (*request_id + 1) % 2048;
-        }
+        let mut rx = self
+            .service
+            .add_request(prompt, options.max_input_length, stop_condition)
+            .await;

         let s = stream! {
             let mut length = 0;
@@ -148,53 +99,3 @@ impl AsyncTextInferenceEngine {
         Box::pin(s)
     }
 }
-
-#[derive(Builder, Debug)]
-pub struct LlamaTextGenerationOptions {
-    model_path: String,
-    use_gpu: bool,
-}
-
-pub struct LlamaTextGeneration {
-    engine: Arc<AsyncTextInferenceEngine>,
-}
-
-impl LlamaTextGeneration {
-    pub fn create(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path);
-        if engine.is_null() {
-            fatal!("Unable to load model: {}", options.model_path);
-        }
-        let ret = LlamaTextGeneration {
-            engine: Arc::new(AsyncTextInferenceEngine::create(engine)),
-        };
-        ret.start_background_job();
-        ret
-    }
-
-    pub fn start_background_job(&self) {
-        let engine = self.engine.clone();
-        tokio::spawn(async move {
-            loop {
-                engine.background_job().await;
-                yield_now().await;
-            }
-        });
-    }
-}
-
-#[async_trait]
-impl TextGeneration for LlamaTextGeneration {
-    async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let s = self.generate_stream(prompt, options).await;
-        helpers::stream_to_string(s).await
-    }
-
-    async fn generate_stream(
-        &self,
-        prompt: &str,
-        options: TextGenerationOptions,
-    ) -> BoxStream<String> {
-        self.engine.generate_stream(prompt, options).await
-    }
-}
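With the Arc<AsyncTextInferenceEngine> wrapper gone, generate_stream no longer touches a lock at all: it hands the request to the new LlamaService and reads generated text from a plain Receiver<String>. All mutable state now lives on a single dedicated thread, so there is no lock ordering left to invert. A self-contained sketch of the same message-passing pattern (demo names, not the commit's code):

    use tokio::sync::mpsc::{channel, Sender};

    // Demo request type: a prompt plus a per-request reply channel.
    struct Request {
        prompt: String,
        reply: Sender<String>,
    }

    // The service task is the sole owner of its state, so it needs no Mutex.
    fn spawn_service() -> Sender<Request> {
        let (tx, mut rx) = channel::<Request>(16);
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                // Stand-in for stepping a real engine and streaming pieces back.
                // A send error just means the caller hung up; drop the request.
                let _ = req.reply.send(format!("echo: {}", req.prompt)).await;
            }
        });
        tx
    }

    #[tokio::main]
    async fn main() {
        let service = spawn_service();
        let (reply_tx, mut reply_rx) = channel(8);
        service
            .send(Request { prompt: "hello".into(), reply: reply_tx })
            .await
            .unwrap();
        // The reply channel closing signals that the request has finished.
        while let Some(piece) = reply_rx.recv().await {
            println!("{piece}");
        }
    }

The commit implements this pattern in a new module, added below (presumably src/llama.rs, matching the mod llama; declaration above).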
@@ -0,0 +1,155 @@
+use std::{collections::HashMap, thread::JoinHandle};
+
+use cxx::UniquePtr;
+use tabby_inference::decoding::StopCondition;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
+
+use crate::ffi;
+
+struct LlamaInitRequest {
+    prompt: String,
+    max_input_length: usize,
+
+    tx: Sender<String>,
+    stop_condition: StopCondition,
+}
+
+struct LlamaRunningRequest {
+    tx: Sender<String>,
+    stop_condition: StopCondition,
+}
+
+struct LlamaServiceImpl {
+    next_request_id: u32,
+    engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
+    rx: Receiver<LlamaInitRequest>,
+    requests: HashMap<u32, LlamaRunningRequest>,
+}
+
+impl LlamaServiceImpl {
+    fn new(engine: UniquePtr<ffi::TextInferenceEngine>, rx: Receiver<LlamaInitRequest>) -> Self {
+        Self {
+            next_request_id: 0,
+            engine,
+            rx,
+            requests: HashMap::new(),
+        }
+    }
+
+    fn alloc_request_id(&mut self) -> u32 {
+        let ret = self.next_request_id;
+        self.next_request_id += 1;
+        ret
+    }
+
+    async fn next_request(&mut self) -> Option<LlamaInitRequest> {
+        if self.requests.is_empty() {
+            self.rx.recv().await
+        } else {
+            self.rx.try_recv().ok()
+        }
+    }
+
+    async fn background_job(&mut self) {
+        while let Some(LlamaInitRequest {
+            prompt,
+            tx,
+            max_input_length,
+            stop_condition,
+        }) = self.next_request().await
+        {
+            let request_id = self.alloc_request_id();
+            self.requests
+                .insert(request_id, LlamaRunningRequest { tx, stop_condition });
+            self.engine
+                .as_mut()
+                .unwrap()
+                .add_request(request_id, &prompt, max_input_length);
+        }
+
+        let result = match self.engine.as_mut().unwrap().step() {
+            Ok(result) => result,
+            Err(err) => {
+                crate::fatal!("Failed to step: {}", err)
+            }
+        };
+
+        for ffi::StepOutput { request_id, text } in result {
+            let mut stopped = false;
+            let LlamaRunningRequest { tx, stop_condition } =
+                self.requests.get_mut(&request_id).unwrap();
+
+            if tx.is_closed() || text.is_empty() {
+                // Cancelled by client side or hit eos.
+                stopped = true;
+            } else if !stop_condition.should_stop(&text) {
+                match tx.send(text).await {
+                    Ok(_) => (),
+                    Err(_) => stopped = true,
+                }
+            } else {
+                // Stopped by a stop word.
+                stopped = true;
+            }
+
+            if stopped {
+                self.requests.remove(&request_id);
+                self.engine.as_mut().unwrap().stop_request(request_id);
+            }
+        }
+    }
+}
+
+fn start_llama_service_impl(
+    engine: UniquePtr<ffi::TextInferenceEngine>,
+    rx: Receiver<LlamaInitRequest>,
+) -> JoinHandle<()> {
+    let mut service = LlamaServiceImpl::new(engine, rx);
+    let rt = tokio::runtime::Builder::new_current_thread()
+        .enable_all()
+        .build()
+        .unwrap();
+
+    std::thread::spawn(move || {
+        let local = tokio::task::LocalSet::new();
+        local.spawn_local(async move {
+            loop {
+                service.background_job().await;
+            }
+        });
+
+        rt.block_on(local);
+    })
+}
+
+pub struct LlamaService {
+    tx: Sender<LlamaInitRequest>,
+}
+
+impl LlamaService {
+    pub fn new(engine: UniquePtr<ffi::TextInferenceEngine>) -> Self {
+        let (tx, rx) = channel(20);
+        start_llama_service_impl(engine, rx);
+        Self { tx }
+    }
+
+    pub async fn add_request(
+        &self,
+        prompt: &str,
+        max_input_length: usize,
+        stop_condition: StopCondition,
+    ) -> Receiver<String> {
+        let (tx, rx) = channel(8);
+        self.tx
+            .send(LlamaInitRequest {
+                prompt: prompt.to_owned(),
+                tx,
+                max_input_length,
+                stop_condition,
+            })
+            .await
+            .expect("Failed to add request");
+
+        rx
+    }
+}
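Two details of LlamaServiceImpl are worth noting. First, the loop runs on its own OS thread with a current-thread runtime and a LocalSet, so the cxx engine is owned by exactly one thread and long step() calls stay off the server's main runtime; this replaces the old tokio::spawn + yield_now() polling loop. Second, next_request adapts to load: when nothing is in flight it parks on recv(), so the thread sleeps instead of spinning, and once requests are running it uses try_recv(), so decoding is never delayed waiting for new arrivals. The same choice in isolation (demo types, not the commit's code):

    use std::collections::HashMap;
    use tokio::sync::mpsc::Receiver;

    async fn next_job(
        active: &HashMap<u32, ()>, // stand-in for the running-request table
        rx: &mut Receiver<String>,
    ) -> Option<String> {
        if active.is_empty() {
            // Idle: park until a request arrives (None once the channel closes).
            rx.recv().await
        } else {
            // Busy: take a request only if one is already queued; otherwise
            // return immediately so the caller can keep stepping the engine.
            rx.try_recv().ok()
        }
    }

Note also that the while let in background_job drains every queued request before each step(), so concurrent requests batch naturally.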
@@ -64,5 +64,5 @@ fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextG
         .build()
         .unwrap();

-    Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
+    Box::new(llama_cpp_bindings::LlamaTextGeneration::new(options))
 }