diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index ffeece7..a29ae71 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -75,12 +75,12 @@ pub struct TextInferenceOptions { pub struct InferenceContext { stop_re: Option, cancel: CancellationToken, - output_text: String + reversed_output_text: String } impl InferenceContext { fn new(stop_re: Option, cancel: CancellationToken) -> Self { - InferenceContext { stop_re, cancel, output_text: "".to_owned() } + InferenceContext { stop_re, cancel, reversed_output_text: "".to_owned() } } } @@ -115,9 +115,17 @@ impl TextInferenceEngine { let stop_re = if options.stop_words.is_empty() { None } else { + // FIXME(meng): consider cache the regexp. let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap(); - let stop_tokens : Vec = encodings.iter().map(|x| x.get_tokens().join("")).collect(); - let regex_string = r"(?m)".to_owned() + &stop_tokens.join("|"); + let stop_tokens : Vec = encodings + .iter() + .map(|x| x.get_tokens().join("")) + // Reverse for efficient suffix matching. + .map(reverse) + .collect(); + + // \A means absolute begins of string. + let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); Some(Regex::new(®ex_string).unwrap()) }; @@ -143,10 +151,16 @@ fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u true } else { if let Some(re) = &context.stop_re { - context.output_text.push_str(&token); - re.find(&context.output_text).is_some() + let mut new_token = reverse(token); + new_token.push_str(&context.reversed_output_text); + context.reversed_output_text = new_token; + re.find(&context.reversed_output_text).is_some() } else { false } } } + +fn reverse(s: String) -> String { + s.chars().rev().collect() +}