refactor: extract create_stop_regex

support-stop-sequences
Meng Zhang 2023-06-06 16:20:38 -07:00
parent 2430a18599
commit dee5c99182
1 changed files with 18 additions and 15 deletions

View File

@ -124,21 +124,10 @@ impl TextInferenceEngine {
} else {
let mut re = self.stop_regex_cache.get(options.stop_words);
if re.is_none() {
let encodings = self
.tokenizer
.encode_batch(options.stop_words.clone(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
.map(reverse)
.collect();
// \A means absolute begins of string.
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
let regex = Regex::new(&regex_string).unwrap();
self.stop_regex_cache.insert(options.stop_words, regex);
self.stop_regex_cache.insert(
options.stop_words,
create_stop_regex(&self.tokenizer, options.stop_words),
);
re = self.stop_regex_cache.get(options.stop_words);
}
re.map(|x| x.value().clone())
@ -182,3 +171,17 @@ fn inference_callback(
/// Returns a new string holding the characters of `s` in reverse order.
///
/// Reversal is done per `char` (Unicode scalar value), not per byte.
fn reverse(s: String) -> String {
    let mut flipped = String::with_capacity(s.len());
    for ch in s.chars().rev() {
        flipped.push(ch);
    }
    flipped
}
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex {
let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
.map(reverse)
.collect();
// \A means absolute begins of string.
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
Regex::new(&regex_string).unwrap()
}