diff --git a/Cargo.lock b/Cargo.lock index 8ec5e7a..149a2d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -578,6 +578,7 @@ dependencies = [ "cmake", "cxx", "cxx-build", + "dashmap", "derive_builder", "regex", "rust-cxx-cmake-bridge", @@ -665,6 +666,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "derive_builder" version = "0.12.0" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 329ab12..b5a45fb 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] cxx = "1.0" +dashmap = "5.4.0" derive_builder = "0.12.0" regex = "1.8.4" tokenizers = "0.13.3" diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 12fbe68..2a61337 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,3 +1,4 @@ +use dashmap::DashMap; use regex::Regex; use tokenizers::tokenizer::Tokenizer; use tokio_util::sync::CancellationToken; @@ -91,6 +92,7 @@ impl InferenceContext { pub struct TextInferenceEngine { engine: cxx::SharedPtr<ffi::TextInferenceEngine>, tokenizer: Tokenizer, + stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, } impl TextInferenceEngine { @@ -104,6 +106,7 @@ impl TextInferenceEngine { ); return TextInferenceEngine { engine, + stop_regex_cache: DashMap::new(), tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), }; } @@ -116,24 +119,29 @@ impl TextInferenceEngine { let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - let stop_re = if options.stop_words.is_empty() { + let stop_re: Option<Regex> = if options.stop_words.is_empty() { None } else 
{ - // FIXME(meng): consider cache the regexp. - let encodings = self - .tokenizer - .encode_batch(options.stop_words.clone(), false) - .unwrap(); - let stop_tokens: Vec<String> = encodings - .iter() - .map(|x| x.get_tokens().join("")) - // Reverse for efficient suffix matching. - .map(reverse) - .collect(); + let mut re = self.stop_regex_cache.get(options.stop_words); + if re.is_none() { + let encodings = self + .tokenizer + .encode_batch(options.stop_words.clone(), false) + .unwrap(); + let stop_tokens: Vec<String> = encodings + .iter() + .map(|x| x.get_tokens().join("")) + // Reverse for efficient suffix matching. + .map(reverse) + .collect(); - // \A means absolute begins of string. - let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); - Some(Regex::new(&regex_string).unwrap()) + // \A means absolute begins of string. + let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); + let regex = Regex::new(&regex_string).unwrap(); + self.stop_regex_cache.insert(options.stop_words, regex); + re = self.stop_regex_cache.get(options.stop_words); + } + re.map(|x| x.value().clone()) }; let context = InferenceContext::new(stop_re, cancel_for_inference);