diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 2a61337..01f57b3 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -124,21 +124,10 @@ impl TextInferenceEngine { } else { let mut re = self.stop_regex_cache.get(options.stop_words); if re.is_none() { - let encodings = self - .tokenizer - .encode_batch(options.stop_words.clone(), false) - .unwrap(); - let stop_tokens: Vec<String> = encodings - .iter() - .map(|x| x.get_tokens().join("")) - // Reverse for efficient suffix matching. - .map(reverse) - .collect(); - - // \A means absolute begins of string. - let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); - let regex = Regex::new(&regex_string).unwrap(); - self.stop_regex_cache.insert(options.stop_words, regex); + self.stop_regex_cache.insert( + options.stop_words, + create_stop_regex(&self.tokenizer, options.stop_words), + ); re = self.stop_regex_cache.get(options.stop_words); } re.map(|x| x.value().clone()) @@ -182,3 +171,17 @@ fn inference_callback( fn reverse(s: String) -> String { s.chars().rev().collect() } + +fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex { + let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap(); + let stop_tokens: Vec<String> = encodings + .iter() + .map(|x| x.get_tokens().join("")) + // Reverse for efficient suffix matching. + .map(reverse) + .collect(); + + // \A means absolute begins of string. + let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); + Regex::new(&regex_string).unwrap() +}