2023-09-03 01:59:07 +00:00
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
2023-09-28 17:20:50 +00:00
|
|
|
use async_stream::stream;
|
2023-08-02 06:12:51 +00:00
|
|
|
use async_trait::async_trait;
|
|
|
|
|
use derive_builder::Builder;
|
2023-09-28 17:20:50 +00:00
|
|
|
use futures::stream::BoxStream;
|
2023-09-03 01:59:07 +00:00
|
|
|
use stop_words::{StopWords, StopWordsCondition};
|
2023-09-28 17:20:50 +00:00
|
|
|
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
2023-05-27 23:20:17 +00:00
|
|
|
use tokenizers::tokenizer::Tokenizer;
|
2023-09-28 17:20:50 +00:00
|
|
|
use tokio::sync::mpsc::{channel, Sender};
|
2023-06-04 22:28:39 +00:00
|
|
|
use tokio_util::sync::CancellationToken;
|
2023-05-25 21:05:28 +00:00
|
|
|
|
|
|
|
|
// FFI bridge between Rust and the CTranslate2 C++ bindings; the `cxx`
// macro generates the glue code for the declarations below.
#[cxx::bridge(namespace = "tabby")]
mod ffi {
    // Rust types exposed opaquely to the C++ side.
    extern "Rust" {
        // Per-request state; see the `InferenceContext` definition below.
        type InferenceContext;
    }

    // C++ types and functions exposed to Rust.
    unsafe extern "C++" {
        include!("ctranslate2-bindings/include/ctranslate2.h");

        // Opaque handle to the C++ text inference engine.
        type TextInferenceEngine;

        // Constructs the engine; parameters mirror `CTranslate2EngineOptions`.
        fn create_engine(
            model_path: &str,
            model_type: &str,
            device: &str,
            compute_type: &str,
            device_indices: &[i32],
            num_replicas_per_device: usize,
        ) -> SharedPtr<TextInferenceEngine>;

        // Runs decoding over `tokens`, returning the generated token ids.
        // `callback` is invoked once per generated token; returning `true`
        // tells the engine to stop decoding (see `inference_callback`).
        // The call returns only once decoding has finished, i.e. it blocks
        // the calling thread.
        fn inference(
            &self,
            context: Box<InferenceContext>,
            callback: fn(
                &mut InferenceContext,
                // step
                usize,
                // token_id
                u32,
                // token
                String,
            ) -> bool,
            tokens: &[String],
            max_decoding_length: usize,
            sampling_temperature: f32,
        ) -> Vec<u32>;
    }
}
|
|
|
|
|
|
2023-06-04 06:23:31 +00:00
|
|
|
// SAFETY: `ffi::TextInferenceEngine` is only ever handled through a
// `cxx::SharedPtr` and its sole method (`inference`) takes `&self`.
// These impls assert that the C++ engine is safe to use from multiple
// threads — NOTE(review): this cannot be verified from the Rust side;
// confirm against the C++ implementation in ctranslate2.h.
unsafe impl Send for ffi::TextInferenceEngine {}
unsafe impl Sync for ffi::TextInferenceEngine {}
|
|
|
|
|
|
2023-05-26 07:06:08 +00:00
|
|
|
/// Configuration for constructing a [`CTranslate2Engine`].
///
/// All fields except `tokenizer_path` are forwarded verbatim to
/// `ffi::create_engine`; `tokenizer_path` is loaded on the Rust side.
#[derive(Builder, Debug)]
pub struct CTranslate2EngineOptions {
    // Path to the CTranslate2 model.
    model_path: String,

    // Model architecture identifier, passed through to the C++ engine.
    model_type: String,

    // Path to the HuggingFace tokenizer file loaded via `Tokenizer::from_file`.
    tokenizer_path: String,

    // Compute device selector (passed through to the C++ engine).
    device: String,

    // Indices of the devices to run on.
    device_indices: Vec<i32>,

    // Number of engine replicas created per device.
    num_replicas_per_device: usize,

    // Numeric precision / compute type selector (passed through to C++).
    compute_type: String,
}
|
|
|
|
|
|
2023-06-06 23:28:58 +00:00
|
|
|
/// Per-request state passed (boxed) across the FFI boundary and handed back
/// to `inference_callback` for every generated token.
pub struct InferenceContext {
    // Streams generated token ids back to the async consumer in
    // `generate_stream`.
    sender: Sender<u32>,
    // Tracks emitted tokens and reports when a stop word has been produced.
    stop_condition: StopWordsCondition,
    // Cancelled (via a drop guard) when the consuming stream is dropped,
    // letting the callback abort decoding early.
    cancel: CancellationToken,
}
|
|
|
|
|
|
|
|
|
|
impl InferenceContext {
|
2023-09-28 17:20:50 +00:00
|
|
|
fn new(
|
|
|
|
|
sender: Sender<u32>,
|
|
|
|
|
stop_condition: StopWordsCondition,
|
|
|
|
|
cancel: CancellationToken,
|
|
|
|
|
) -> Self {
|
2023-06-06 23:28:58 +00:00
|
|
|
InferenceContext {
|
2023-09-28 17:20:50 +00:00
|
|
|
sender,
|
2023-09-03 01:59:07 +00:00
|
|
|
stop_condition,
|
2023-06-06 23:28:58 +00:00
|
|
|
cancel,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-06-04 22:28:39 +00:00
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
/// Text generation engine backed by the CTranslate2 C++ runtime.
pub struct CTranslate2Engine {
    // Shared handle to the underlying C++ inference engine.
    engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
    // Factory for per-request stop-word conditions.
    stop_words: StopWords,
    // Tokenizer shared between prompt encoding, token decoding, and the
    // per-request stop-word conditions.
    tokenizer: Arc<Tokenizer>,
}
|
|
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
impl CTranslate2Engine {
|
|
|
|
|
pub fn create(options: CTranslate2EngineOptions) -> Self where {
|
2023-05-26 06:23:07 +00:00
|
|
|
let engine = ffi::create_engine(
|
|
|
|
|
&options.model_path,
|
2023-05-27 08:26:33 +00:00
|
|
|
&options.model_type,
|
2023-05-26 06:23:07 +00:00
|
|
|
&options.device,
|
2023-06-13 19:04:07 +00:00
|
|
|
&options.compute_type,
|
2023-05-26 06:23:07 +00:00
|
|
|
&options.device_indices,
|
|
|
|
|
options.num_replicas_per_device,
|
|
|
|
|
);
|
2023-08-02 06:12:51 +00:00
|
|
|
|
|
|
|
|
return Self {
|
2023-05-27 23:20:17 +00:00
|
|
|
engine,
|
2023-09-03 01:59:07 +00:00
|
|
|
stop_words: StopWords::default(),
|
|
|
|
|
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
2023-05-25 21:05:28 +00:00
|
|
|
};
|
|
|
|
|
}
|
2023-08-02 06:12:51 +00:00
|
|
|
}
|
2023-05-25 21:05:28 +00:00
|
|
|
|
2023-08-02 06:12:51 +00:00
|
|
|
#[async_trait]
|
|
|
|
|
impl TextGeneration for CTranslate2Engine {
|
|
|
|
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
2023-09-28 17:20:50 +00:00
|
|
|
let s = self.generate_stream(prompt, options).await;
|
|
|
|
|
helpers::stream_to_string(s).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn generate_stream(
|
|
|
|
|
&self,
|
|
|
|
|
prompt: &str,
|
|
|
|
|
options: TextGenerationOptions,
|
|
|
|
|
) -> BoxStream<String> {
|
2023-05-25 21:05:28 +00:00
|
|
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
2023-06-04 06:23:31 +00:00
|
|
|
let engine = self.engine.clone();
|
2023-09-28 17:20:50 +00:00
|
|
|
let s = stream! {
|
|
|
|
|
let cancel = CancellationToken::new();
|
|
|
|
|
let cancel_for_inference = cancel.clone();
|
|
|
|
|
let _guard = cancel.drop_guard();
|
|
|
|
|
|
|
|
|
|
let stop_condition = self
|
|
|
|
|
.stop_words
|
|
|
|
|
.create_condition(self.tokenizer.clone(), options.stop_words);
|
|
|
|
|
|
|
|
|
|
let (sender, mut receiver) = channel::<u32>(8);
|
|
|
|
|
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
|
|
|
|
|
tokio::task::spawn(async move {
|
|
|
|
|
let context = Box::new(context);
|
|
|
|
|
engine.inference(
|
|
|
|
|
context,
|
|
|
|
|
inference_callback,
|
|
|
|
|
truncate_tokens(encoding.get_tokens(), options.max_input_length),
|
|
|
|
|
options.max_decoding_length,
|
|
|
|
|
options.sampling_temperature,
|
|
|
|
|
);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
while let Some(next_token_id) = receiver.recv().await {
|
|
|
|
|
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
|
|
|
|
yield text;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
Box::pin(s)
|
2023-05-25 21:05:28 +00:00
|
|
|
}
|
|
|
|
|
}
|
2023-06-06 23:28:58 +00:00
|
|
|
|
2023-09-08 10:01:03 +00:00
|
|
|
/// Returns at most the last `max_length` tokens of `tokens`.
///
/// When the input already fits, the whole slice is returned unchanged;
/// otherwise the oldest tokens (at the front) are dropped so the most
/// recent context is kept.
fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
    // saturating_sub yields 0 when max_length >= len, i.e. the full slice.
    let start = tokens.len().saturating_sub(max_length);
    &tokens[start..]
}
|
|
|
|
|
|
2023-06-06 23:28:58 +00:00
|
|
|
fn inference_callback(
|
|
|
|
|
context: &mut InferenceContext,
|
|
|
|
|
_step: usize,
|
2023-09-03 01:59:07 +00:00
|
|
|
token_id: u32,
|
|
|
|
|
_token: String,
|
2023-06-06 23:28:58 +00:00
|
|
|
) -> bool {
|
2023-09-28 17:20:50 +00:00
|
|
|
let _ = context.sender.blocking_send(token_id);
|
2023-06-06 23:28:58 +00:00
|
|
|
if context.cancel.is_cancelled() {
|
|
|
|
|
true
|
2023-08-28 06:07:01 +00:00
|
|
|
} else {
|
2023-09-03 01:59:07 +00:00
|
|
|
context.stop_condition.next_token(token_id)
|
2023-08-28 06:07:01 +00:00
|
|
|
}
|
2023-06-06 23:28:58 +00:00
|
|
|
}
|