2023-05-25 21:05:28 +00:00
|
|
|
use tokenizers::tokenizer::{Model, Tokenizer};
|
|
|
|
|
|
|
|
|
|
#[macro_use]
|
|
|
|
|
extern crate derive_builder;
|
|
|
|
|
|
|
|
|
|
#[cxx::bridge(namespace = "tabby")]
|
|
|
|
|
mod ffi {
|
|
|
|
|
unsafe extern "C++" {
|
|
|
|
|
include!("ctranslate2-bindings/include/ctranslate2.h");
|
|
|
|
|
|
|
|
|
|
type TextInferenceEngine;
|
|
|
|
|
|
2023-05-26 06:23:07 +00:00
|
|
|
fn create_engine(
|
|
|
|
|
model_path: &str,
|
|
|
|
|
device: &str,
|
|
|
|
|
device_indices: &[i32],
|
|
|
|
|
num_replicas_per_device: usize,
|
|
|
|
|
) -> UniquePtr<TextInferenceEngine>;
|
|
|
|
|
|
2023-05-25 21:05:28 +00:00
|
|
|
fn inference(
|
|
|
|
|
&self,
|
|
|
|
|
tokens: &[String],
|
|
|
|
|
max_decoding_length: usize,
|
|
|
|
|
sampling_temperature: f32,
|
|
|
|
|
beam_size: usize,
|
|
|
|
|
) -> Vec<String>;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-05-26 07:06:08 +00:00
|
|
|
#[derive(Builder, Debug)]
|
2023-05-26 06:23:07 +00:00
|
|
|
pub struct TextInferenceEngineCreateOptions {
|
|
|
|
|
model_path: String,
|
|
|
|
|
|
|
|
|
|
tokenizer_path: String,
|
|
|
|
|
|
|
|
|
|
device: String,
|
|
|
|
|
|
|
|
|
|
device_indices: Vec<i32>,
|
|
|
|
|
|
|
|
|
|
num_replicas_per_device: usize,
|
|
|
|
|
}
|
|
|
|
|
|
2023-05-25 21:05:28 +00:00
|
|
|
#[derive(Builder, Debug)]
|
|
|
|
|
pub struct TextInferenceOptions {
|
|
|
|
|
#[builder(default = "256")]
|
|
|
|
|
max_decoding_length: usize,
|
|
|
|
|
|
|
|
|
|
#[builder(default = "1.0")]
|
|
|
|
|
sampling_temperature: f32,
|
|
|
|
|
|
|
|
|
|
#[builder(default = "2")]
|
|
|
|
|
beam_size: usize,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct TextInferenceEngine {
|
2023-05-26 07:06:08 +00:00
|
|
|
engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
|
2023-05-25 21:05:28 +00:00
|
|
|
tokenizer: Tokenizer,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unsafe impl Send for TextInferenceEngine {}
|
|
|
|
|
unsafe impl Sync for TextInferenceEngine {}
|
|
|
|
|
|
|
|
|
|
impl TextInferenceEngine {
|
2023-05-26 06:23:07 +00:00
|
|
|
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
|
|
|
|
let engine = ffi::create_engine(
|
|
|
|
|
&options.model_path,
|
|
|
|
|
&options.device,
|
|
|
|
|
&options.device_indices,
|
|
|
|
|
options.num_replicas_per_device,
|
|
|
|
|
);
|
2023-05-25 21:05:28 +00:00
|
|
|
return TextInferenceEngine {
|
2023-05-26 07:06:08 +00:00
|
|
|
engine: engine,
|
2023-05-26 06:23:07 +00:00
|
|
|
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
2023-05-25 21:05:28 +00:00
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
|
|
|
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
2023-05-26 07:06:08 +00:00
|
|
|
let output_tokens = self.engine.inference(
|
2023-05-25 21:05:28 +00:00
|
|
|
encoding.get_tokens(),
|
|
|
|
|
options.max_decoding_length,
|
|
|
|
|
options.sampling_temperature,
|
|
|
|
|
options.beam_size,
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
let model = self.tokenizer.get_model();
|
|
|
|
|
let output_ids: Vec<u32> = output_tokens
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|x| model.token_to_id(x).unwrap())
|
|
|
|
|
.collect();
|
|
|
|
|
self.tokenizer.decode(output_ids, true).unwrap()
|
|
|
|
|
}
|
|
|
|
|
}
|