refactor: extract create_stop_regex
parent
2430a18599
commit
dee5c99182
|
|
@ -124,21 +124,10 @@ impl TextInferenceEngine {
|
|||
} else {
|
||||
let mut re = self.stop_regex_cache.get(options.stop_words);
|
||||
if re.is_none() {
|
||||
let encodings = self
|
||||
.tokenizer
|
||||
.encode_batch(options.stop_words.clone(), false)
|
||||
.unwrap();
|
||||
let stop_tokens: Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| x.get_tokens().join(""))
|
||||
// Reverse for efficient suffix matching.
|
||||
.map(reverse)
|
||||
.collect();
|
||||
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
|
||||
let regex = Regex::new(®ex_string).unwrap();
|
||||
self.stop_regex_cache.insert(options.stop_words, regex);
|
||||
self.stop_regex_cache.insert(
|
||||
options.stop_words,
|
||||
create_stop_regex(&self.tokenizer, options.stop_words),
|
||||
);
|
||||
re = self.stop_regex_cache.get(options.stop_words);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
|
|
@ -182,3 +171,17 @@ fn inference_callback(
|
|||
fn reverse(s: String) -> String {
|
||||
s.chars().rev().collect()
|
||||
}
|
||||
|
||||
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex {
|
||||
let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap();
|
||||
let stop_tokens: Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| x.get_tokens().join(""))
|
||||
// Reverse for efficient suffix matching.
|
||||
.map(reverse)
|
||||
.collect();
|
||||
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
|
||||
Regex::new(®ex_string).unwrap()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue