From fd2a1ab86509fdec297d5023ee419a277e86c973 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 6 Oct 2023 02:04:37 -0700 Subject: [PATCH] fix: switch back to regex based implementation for stop words (#513) --- crates/tabby-inference/src/decoding.rs | 30 +++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index fddd308..ffefe94 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -1,11 +1,11 @@ use std::sync::Arc; use dashmap::DashMap; -use regex::RegexSet; +use regex::Regex; use tokenizers::tokenizer::Tokenizer; pub struct DecodingFactory { - stop_regex_cache: DashMap<&'static Vec<&'static str>, RegexSet>, + stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, } fn reverse<T>(s: T) -> String @@ -33,7 +33,7 @@ impl DecodingFactory { IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) } - fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<RegexSet> { + fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> { if stop_words.is_empty() { None } else { @@ -48,19 +48,20 @@ impl DecodingFactory { } } -fn create_stop_regex(stop_words: &[&str]) -> RegexSet { +fn create_stop_regex(stop_words: &[&str]) -> Regex { // (?m) enables multi-line matching mode. // \A means absolute begins of string. 
- let tokens: Vec<String> = stop_words + let reversed_stop_words: Vec<_> = stop_words .iter() - .map(|x| r"(?m)\A".to_owned() + &reverse(*x)) + .map(|x| regex::escape(&reverse(*x))) .collect(); - RegexSet::new(tokens).expect("Failed to create regex set") + let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))"; + Regex::new(&regex_string).expect("Failed to create regex") } pub struct IncrementalDecoding { tokenizer: Arc<Tokenizer>, - stop_re: Option<RegexSet>, + stop_re: Option<Regex>, token_ids: Vec<u32>, prefix_offset: usize, @@ -70,11 +71,7 @@ pub struct IncrementalDecoding { } impl IncrementalDecoding { - pub fn new( - tokenizer: Arc<Tokenizer>, - stop_re: Option<RegexSet>, - input_token_ids: &[u32], - ) -> Self { + pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self { let text = tokenizer .decode(input_token_ids, /* skip_special_token = */ true) .expect("Cannot decode token from tokenizer."); @@ -132,10 +129,9 @@ mod tests { use super::*; #[test] - fn test_it_should_not_match() { - let stop_words = vec!["\n\n", "\n\n "]; - let re = create_stop_regex(&stop_words); + fn test_it_works() { let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid"); - assert!(!re.is_match(&text)) + assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text)); + assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text)); } }