chore: mark thread safety [TAB-52] (#186)
* mark thread safety * use shared_ptr to ensure thread safety * fmtswitch-to-ctranslate2-subtree
parent
775576b53e
commit
6de61f45bb
|
|
@ -490,6 +490,7 @@ dependencies = [
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
"rust-cxx-cmake-bridge",
|
"rust-cxx-cmake-bridge",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -2294,9 +2295,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio"
|
name = "tokio"
|
||||||
version = "1.28.1"
|
version = "1.28.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105"
|
checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
|
|
||||||
|
|
@ -16,3 +16,4 @@ homepage = "https://github.com/TabbyML/tabby"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serdeconv = "0.4.1"
|
serdeconv = "0.4.1"
|
||||||
|
tokio = "1.28"
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ edition = "2021"
|
||||||
cxx = "1.0"
|
cxx = "1.0"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
tokenizers = "0.13.3"
|
tokenizers = "0.13.3"
|
||||||
|
tokio = { workspace = true, features = ["rt"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
cxx-build = "1.0"
|
cxx-build = "1.0"
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ class TextInferenceEngine {
|
||||||
) const = 0;
|
) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
rust::Str model_type,
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ class DecoderImpl: public TextInferenceEngine {
|
||||||
std::unique_ptr<ctranslate2::Generator> generator_;
|
std::unique_ptr<ctranslate2::Generator> generator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
rust::Str model_type,
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ mod ffi {
|
||||||
device: &str,
|
device: &str,
|
||||||
device_indices: &[i32],
|
device_indices: &[i32],
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: usize,
|
||||||
) -> UniquePtr<TextInferenceEngine>;
|
) -> SharedPtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn inference(
|
fn inference(
|
||||||
&self,
|
&self,
|
||||||
|
|
@ -28,6 +28,9 @@ mod ffi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for ffi::TextInferenceEngine {}
|
||||||
|
unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||||
|
|
||||||
#[derive(Builder, Debug)]
|
#[derive(Builder, Debug)]
|
||||||
pub struct TextInferenceEngineCreateOptions {
|
pub struct TextInferenceEngineCreateOptions {
|
||||||
model_path: String,
|
model_path: String,
|
||||||
|
|
@ -56,13 +59,10 @@ pub struct TextInferenceOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TextInferenceEngine {
|
pub struct TextInferenceEngine {
|
||||||
engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
|
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for TextInferenceEngine {}
|
|
||||||
unsafe impl Sync for TextInferenceEngine {}
|
|
||||||
|
|
||||||
impl TextInferenceEngine {
|
impl TextInferenceEngine {
|
||||||
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||||
let engine = ffi::create_engine(
|
let engine = ffi::create_engine(
|
||||||
|
|
@ -78,14 +78,19 @@ impl TextInferenceEngine {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
let output_tokens = self.engine.inference(
|
let engine = self.engine.clone();
|
||||||
encoding.get_tokens(),
|
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||||
options.max_decoding_length,
|
engine.inference(
|
||||||
options.sampling_temperature,
|
encoding.get_tokens(),
|
||||||
options.beam_size,
|
options.max_decoding_length,
|
||||||
);
|
options.sampling_temperature,
|
||||||
|
options.beam_size,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("Inference failed");
|
||||||
let output_ids: Vec<u32> = output_tokens
|
let output_ids: Vec<u32> = output_tokens
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|x| match self.tokenizer.token_to_id(x) {
|
.filter_map(|x| match self.tokenizer.token_to_id(x) {
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = "0.6"
|
axum = "0.6"
|
||||||
hyper = { version = "0.14", features = ["full"] }
|
hyper = { version = "0.14", features = ["full"] }
|
||||||
tokio = { version = "1.17", features = ["full"] }
|
tokio = { workspace = true, features = ["full"] }
|
||||||
tower = "0.4"
|
tower = "0.4"
|
||||||
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
||||||
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ pub async fn completion(
|
||||||
request.prompt.expect("No prompt is set")
|
request.prompt.expect("No prompt is set")
|
||||||
};
|
};
|
||||||
|
|
||||||
let text = state.engine.inference(&prompt, options);
|
let text = state.engine.inference(&prompt, options).await;
|
||||||
let language = request.language.unwrap_or("unknown".into());
|
let language = request.language.unwrap_or("unknown".into());
|
||||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue