improve effiency of regexp match with reversed regex
parent
301c86a985
commit
c3e57147cf
|
|
@ -75,12 +75,12 @@ pub struct TextInferenceOptions {
|
|||
pub struct InferenceContext {
|
||||
stop_re: Option<Regex>,
|
||||
cancel: CancellationToken,
|
||||
output_text: String
|
||||
reversed_output_text: String
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
||||
InferenceContext { stop_re, cancel, output_text: "".to_owned() }
|
||||
InferenceContext { stop_re, cancel, reversed_output_text: "".to_owned() }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -115,9 +115,17 @@ impl TextInferenceEngine {
|
|||
let stop_re = if options.stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
// FIXME(meng): consider cache the regexp.
|
||||
let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
|
||||
let stop_tokens : Vec<String> = encodings.iter().map(|x| x.get_tokens().join("")).collect();
|
||||
let regex_string = r"(?m)".to_owned() + &stop_tokens.join("|");
|
||||
let stop_tokens : Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| x.get_tokens().join(""))
|
||||
// Reverse for efficient suffix matching.
|
||||
.map(reverse)
|
||||
.collect();
|
||||
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
|
||||
Some(Regex::new(®ex_string).unwrap())
|
||||
};
|
||||
|
||||
|
|
@ -143,10 +151,16 @@ fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u
|
|||
true
|
||||
} else {
|
||||
if let Some(re) = &context.stop_re {
|
||||
context.output_text.push_str(&token);
|
||||
re.find(&context.output_text).is_some()
|
||||
let mut new_token = reverse(token);
|
||||
new_token.push_str(&context.reversed_output_text);
|
||||
context.reversed_output_text = new_token;
|
||||
re.find(&context.reversed_output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reverse(s: String) -> String {
|
||||
s.chars().rev().collect()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue