chore: mark thread safety [TAB-52] (#186)

* mark thread safety

* use shared_ptr to ensure thread safety

* fmt
switch-to-ctranslate2-subtree
Meng Zhang 2023-06-03 23:23:31 -07:00 committed by GitHub
parent 775576b53e
commit 6de61f45bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 26 additions and 18 deletions

5
Cargo.lock generated
View File

@ -490,6 +490,7 @@ dependencies = [
"derive_builder", "derive_builder",
"rust-cxx-cmake-bridge", "rust-cxx-cmake-bridge",
"tokenizers", "tokenizers",
"tokio",
] ]
[[package]] [[package]]
@ -2294,9 +2295,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.28.1" version = "1.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105" checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"bytes", "bytes",

View File

@ -16,3 +16,4 @@ homepage = "https://github.com/TabbyML/tabby"
lazy_static = "1.4.0" lazy_static = "1.4.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serdeconv = "0.4.1" serdeconv = "0.4.1"
tokio = "1.28"

View File

@ -7,6 +7,7 @@ edition = "2021"
cxx = "1.0" cxx = "1.0"
derive_builder = "0.12.0" derive_builder = "0.12.0"
tokenizers = "0.13.3" tokenizers = "0.13.3"
tokio = { workspace = true, features = ["rt"] }
[build-dependencies] [build-dependencies]
cxx-build = "1.0" cxx-build = "1.0"

View File

@ -16,7 +16,7 @@ class TextInferenceEngine {
) const = 0; ) const = 0;
}; };
std::unique_ptr<TextInferenceEngine> create_engine( std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path, rust::Str model_path,
rust::Str model_type, rust::Str model_type,
rust::Str device, rust::Str device,

View File

@ -77,7 +77,7 @@ class DecoderImpl: public TextInferenceEngine {
std::unique_ptr<ctranslate2::Generator> generator_; std::unique_ptr<ctranslate2::Generator> generator_;
}; };
std::unique_ptr<TextInferenceEngine> create_engine( std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path, rust::Str model_path,
rust::Str model_type, rust::Str model_type,
rust::Str device, rust::Str device,

View File

@ -16,7 +16,7 @@ mod ffi {
device: &str, device: &str,
device_indices: &[i32], device_indices: &[i32],
num_replicas_per_device: usize, num_replicas_per_device: usize,
) -> UniquePtr<TextInferenceEngine>; ) -> SharedPtr<TextInferenceEngine>;
fn inference( fn inference(
&self, &self,
@ -28,6 +28,9 @@ mod ffi {
} }
} }
unsafe impl Send for ffi::TextInferenceEngine {}
unsafe impl Sync for ffi::TextInferenceEngine {}
#[derive(Builder, Debug)] #[derive(Builder, Debug)]
pub struct TextInferenceEngineCreateOptions { pub struct TextInferenceEngineCreateOptions {
model_path: String, model_path: String,
@ -56,13 +59,10 @@ pub struct TextInferenceOptions {
} }
pub struct TextInferenceEngine { pub struct TextInferenceEngine {
engine: cxx::UniquePtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer, tokenizer: Tokenizer,
} }
unsafe impl Send for TextInferenceEngine {}
unsafe impl Sync for TextInferenceEngine {}
impl TextInferenceEngine { impl TextInferenceEngine {
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where { pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
let engine = ffi::create_engine( let engine = ffi::create_engine(
@ -78,14 +78,19 @@ impl TextInferenceEngine {
}; };
} }
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
let encoding = self.tokenizer.encode(prompt, true).unwrap(); let encoding = self.tokenizer.encode(prompt, true).unwrap();
let output_tokens = self.engine.inference( let engine = self.engine.clone();
let output_tokens = tokio::task::spawn_blocking(move || {
engine.inference(
encoding.get_tokens(), encoding.get_tokens(),
options.max_decoding_length, options.max_decoding_length,
options.sampling_temperature, options.sampling_temperature,
options.beam_size, options.beam_size,
); )
})
.await
.expect("Inference failed");
let output_ids: Vec<u32> = output_tokens let output_ids: Vec<u32> = output_tokens
.iter() .iter()
.filter_map(|x| match self.tokenizer.token_to_id(x) { .filter_map(|x| match self.tokenizer.token_to_id(x) {

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
axum = "0.6" axum = "0.6"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
tokio = { version = "1.17", features = ["full"] } tokio = { workspace = true, features = ["full"] }
tower = "0.4" tower = "0.4"
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] } utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
utoipa-swagger-ui = { version = "3.1", features = ["axum"] } utoipa-swagger-ui = { version = "3.1", features = ["axum"] }

View File

@ -80,7 +80,7 @@ pub async fn completion(
request.prompt.expect("No prompt is set") request.prompt.expect("No prompt is set")
}; };
let text = state.engine.inference(&prompt, options); let text = state.engine.inference(&prompt, options).await;
let language = request.language.unwrap_or("unknown".into()); let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text); let filtered_text = languages::remove_stop_words(&language, &text);