refactor: use RegexSet for cleaer stop regex construction (#499)
* fix: add a regression test cased for stop words regex matching * refactor: use RegexSet for cleaer stop regex constructionrelease-0.2
parent
63612d5a67
commit
ce20bd6154
|
|
@ -1,11 +1,11 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use regex::Regex;
|
use regex::RegexSet;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
|
||||||
pub struct DecodingFactory {
|
pub struct DecodingFactory {
|
||||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
stop_regex_cache: DashMap<&'static Vec<&'static str>, RegexSet>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reverse<T>(s: T) -> String
|
fn reverse<T>(s: T) -> String
|
||||||
|
|
@ -33,7 +33,7 @@ impl DecodingFactory {
|
||||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<RegexSet> {
|
||||||
if stop_words.is_empty() {
|
if stop_words.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -48,18 +48,19 @@ impl DecodingFactory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
fn create_stop_regex(stop_words: &[&str]) -> RegexSet {
|
||||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
|
||||||
|
|
||||||
// (?m) enables multi-line matching mode.
|
// (?m) enables multi-line matching mode.
|
||||||
// \A means absolute begins of string.
|
// \A means absolute begins of string.
|
||||||
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
let tokens: Vec<String> = stop_words
|
||||||
Regex::new(®ex_string).unwrap()
|
.iter()
|
||||||
|
.map(|x| r"(?m)\A".to_owned() + &reverse(*x))
|
||||||
|
.collect();
|
||||||
|
RegexSet::new(tokens).expect("Failed to create regex set")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct IncrementalDecoding {
|
pub struct IncrementalDecoding {
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
stop_re: Option<Regex>,
|
stop_re: Option<RegexSet>,
|
||||||
|
|
||||||
token_ids: Vec<u32>,
|
token_ids: Vec<u32>,
|
||||||
prefix_offset: usize,
|
prefix_offset: usize,
|
||||||
|
|
@ -69,7 +70,11 @@ pub struct IncrementalDecoding {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IncrementalDecoding {
|
impl IncrementalDecoding {
|
||||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
pub fn new(
|
||||||
|
tokenizer: Arc<Tokenizer>,
|
||||||
|
stop_re: Option<RegexSet>,
|
||||||
|
input_token_ids: &[u32],
|
||||||
|
) -> Self {
|
||||||
let text = tokenizer
|
let text = tokenizer
|
||||||
.decode(input_token_ids, /* skip_special_token = */ true)
|
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||||
.expect("Cannot decode token from tokenizer.");
|
.expect("Cannot decode token from tokenizer.");
|
||||||
|
|
@ -112,7 +117,7 @@ impl IncrementalDecoding {
|
||||||
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||||
|
|
||||||
if let Some(re) = &self.stop_re {
|
if let Some(re) = &self.stop_re {
|
||||||
if re.find(&self.reversed_text).is_some() {
|
if re.is_match(&self.reversed_text) {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -121,3 +126,16 @@ impl IncrementalDecoding {
|
||||||
Some(new_text.to_owned())
|
Some(new_text.to_owned())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_it_should_not_match() {
|
||||||
|
let stop_words = vec!["\n\n", "\n\n "];
|
||||||
|
let re = create_stop_regex(&stop_words);
|
||||||
|
let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
|
||||||
|
assert!(!re.is_match(&text))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ lazy_static! {
|
||||||
"\n\n ",
|
"\n\n ",
|
||||||
"\n\n ",
|
"\n\n ",
|
||||||
"\n\n ",
|
"\n\n ",
|
||||||
"\n\n",
|
|
||||||
"\n\n\t",
|
"\n\n\t",
|
||||||
"\n\n\t\t",
|
"\n\n\t\t",
|
||||||
"\n\n\t\t\t",
|
"\n\n\t\t\t",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue