fix: switch back to regex based implementation for stop words (#513)

wsxiaoys-patch-1
Meng Zhang 2023-10-06 02:04:37 -07:00 committed by GitHub
parent 4c00ac06fb
commit fd2a1ab865
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 17 deletions

View File

@ -1,11 +1,11 @@
use std::sync::Arc; use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use regex::RegexSet; use regex::Regex;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
pub struct DecodingFactory { pub struct DecodingFactory {
stop_regex_cache: DashMap<&'static Vec<&'static str>, RegexSet>, stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
} }
fn reverse<T>(s: T) -> String fn reverse<T>(s: T) -> String
@ -33,7 +33,7 @@ impl DecodingFactory {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
} }
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<RegexSet> { fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if stop_words.is_empty() { if stop_words.is_empty() {
None None
} else { } else {
@ -48,19 +48,20 @@ impl DecodingFactory {
} }
} }
fn create_stop_regex(stop_words: &[&str]) -> RegexSet { fn create_stop_regex(stop_words: &[&str]) -> Regex {
// (?m) enables multi-line matching mode. // (?m) enables multi-line matching mode.
// \A means absolute begins of string. // \A means absolute begins of string.
let tokens: Vec<String> = stop_words let reversed_stop_words: Vec<_> = stop_words
.iter() .iter()
.map(|x| r"(?m)\A".to_owned() + &reverse(*x)) .map(|x| regex::escape(&reverse(*x)))
.collect(); .collect();
RegexSet::new(tokens).expect("Failed to create regex set") let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
Regex::new(&regex_string).expect("Failed to create regex")
} }
pub struct IncrementalDecoding { pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
stop_re: Option<RegexSet>, stop_re: Option<Regex>,
token_ids: Vec<u32>, token_ids: Vec<u32>,
prefix_offset: usize, prefix_offset: usize,
@ -70,11 +71,7 @@ pub struct IncrementalDecoding {
} }
impl IncrementalDecoding { impl IncrementalDecoding {
pub fn new( pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
tokenizer: Arc<Tokenizer>,
stop_re: Option<RegexSet>,
input_token_ids: &[u32],
) -> Self {
let text = tokenizer let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true) .decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer."); .expect("Cannot decode token from tokenizer.");
@ -132,10 +129,9 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_it_should_not_match() { fn test_it_works() {
let stop_words = vec!["\n\n", "\n\n "];
let re = create_stop_regex(&stop_words);
let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid"); let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
assert!(!re.is_match(&text)) assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text));
assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text));
} }
} }