add cache for stop words regexp

support-stop-sequences
Meng Zhang 2023-06-06 16:11:31 -07:00
parent 030f694261
commit 2430a18599
3 changed files with 38 additions and 15 deletions

14
Cargo.lock generated
View File

@ -578,6 +578,7 @@ dependencies = [
"cmake",
"cxx",
"cxx-build",
"dashmap",
"derive_builder",
"regex",
"rust-cxx-cmake-bridge",
@ -665,6 +666,19 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "derive_builder"
version = "0.12.0"

View File

@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
cxx = "1.0"
dashmap = "5.4.0"
derive_builder = "0.12.0"
regex = "1.8.4"
tokenizers = "0.13.3"

View File

@ -1,3 +1,4 @@
use dashmap::DashMap;
use regex::Regex;
use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken;
@ -91,6 +92,7 @@ impl InferenceContext {
/// Text inference engine that pairs an FFI-backed engine handle with a tokenizer
/// and a cache of compiled stop-word regexes.
pub struct TextInferenceEngine {
// Shared pointer (via `cxx`) to the engine object on the other side of the FFI boundary.
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
// Tokenizer loaded from `options.tokenizer_path`; also used to encode stop words
// before they are turned into a regex.
tokenizer: Tokenizer,
// Compiled stop-word regexes keyed by the stop-word list, so that a list seen before
// is not re-tokenized and re-compiled on every inference call.
// The key is a `&'static` reference, so only stop-word lists that live for the whole
// program can be cached.
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
}
impl TextInferenceEngine {
@ -104,6 +106,7 @@ impl TextInferenceEngine {
);
return TextInferenceEngine {
engine,
stop_regex_cache: DashMap::new(),
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
};
}
@ -116,24 +119,29 @@ impl TextInferenceEngine {
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let stop_re = if options.stop_words.is_empty() {
let stop_re: Option<Regex> = if options.stop_words.is_empty() {
None
} else {
// FIXME(meng): consider cache the regexp.
let encodings = self
.tokenizer
.encode_batch(options.stop_words.clone(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
.map(reverse)
.collect();
let mut re = self.stop_regex_cache.get(options.stop_words);
if re.is_none() {
let encodings = self
.tokenizer
.encode_batch(options.stop_words.clone(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
.map(reverse)
.collect();
// \A anchors the match at the absolute beginning of the string.
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
Some(Regex::new(&regex_string).unwrap())
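// \A anchors the match at the absolute beginning of the string.
// NOTE(review): alternation binds more loosely than concatenation, so `\A` appears to
// anchor only the first stop word here; consider wrapping the alternatives in `(?:...)`.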
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
let regex = Regex::new(&regex_string).unwrap();
self.stop_regex_cache.insert(options.stop_words, regex);
re = self.stop_regex_cache.get(options.stop_words);
}
re.map(|x| x.value().clone())
};
let context = InferenceContext::new(stop_re, cancel_for_inference);