use std::sync::Arc;

use dashmap::DashMap;
use regex::Regex;
use tokenizers::tokenizer::Tokenizer;

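/// Creates incremental decoders and caches the compiled regex for each
/// static stop-word list, so each list is compiled at most once per process.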
pub struct DecodingFactory {
    stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
}

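/// Reverses a string. Generated text is kept reversed so stop words can be
/// matched with a regex anchored at the start of the string.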
fn reverse<T>(s: T) -> String
where
    T: Into<String>,
{
    s.into().chars().rev().collect()
}

impl Default for DecodingFactory {
    fn default() -> Self {
        Self {
            stop_regex_cache: DashMap::new(),
        }
    }
}

impl DecodingFactory {
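    /// Builds an `IncrementalDecoding` session for one request, combining the
    /// cached regex for the static stop words with a freshly compiled regex
    /// for the per-request stop words.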
    pub fn create(
        &self,
        tokenizer: Arc<Tokenizer>,
        input_token_ids: &[u32],
        stop_words: &[String],
        static_stop_words: &'static Vec<&'static str>,
    ) -> IncrementalDecoding {
        IncrementalDecoding::new(
            tokenizer,
            vec![
                self.get_static_re(static_stop_words),
                self.get_re(stop_words),
            ]
            .into_iter()
            .flatten()
            .collect(),
            input_token_ids,
        )
    }

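    /// Compiles a regex for per-request stop words; returns `None` when the
    /// list is empty.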
    fn get_re(&self, stop_words: &[String]) -> Option<Regex> {
        if !stop_words.is_empty() {
            Some(create_stop_regex(stop_words))
        } else {
            None
        }
    }

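    /// Returns the cached regex for a static stop-word list, compiling and
    /// caching it on first use.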
    fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
        if stop_words.is_empty() {
            None
        } else {
            // `entry` keeps the lookup-or-compile step atomic, so the regex
            // for a given stop-word list is compiled at most once and then
            // served from the cache.
            Some(
                self.stop_regex_cache
                    .entry(stop_words)
                    .or_insert_with(|| create_stop_regex(stop_words))
                    .value()
                    .clone(),
            )
        }
    }
}

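/// Compiles a single regex that matches any of the (reversed) stop words at
/// the start of the reversed generated text.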
fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
    // Reverse each stop word and escape regex metacharacters so the words
    // are matched literally.
    let tokens: Vec<String> = stop_words
        .iter()
        .map(|x| regex::escape(&reverse(x.as_ref())))
        .collect();

    // (?m) enables multi-line matching mode.
    // \A anchors the match to the absolute beginning of the string; the
    // alternation is grouped so that every stop word is anchored, not just
    // the first alternative.
    let regex_string = r"(?m)\A(".to_owned() + &tokens.join("|") + ")";
    Regex::new(&regex_string).expect("Failed to compile stop-word regex")
}

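/// Decodes a stream of token ids into text, emitting only the newly stable
/// suffix on each step and signalling (by returning `None`) when a stop word
/// has been generated.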
pub struct IncrementalDecoding {
    tokenizer: Arc<Tokenizer>,
    stop_re: Vec<Regex>,

    token_ids: Vec<u32>,
    prefix_offset: usize,
    read_offset: usize,

    reversed_text: String,
}

impl IncrementalDecoding {
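    /// Seeds the decoder with the prompt tokens; `prefix_offset` and
    /// `read_offset` delimit the window of tokens already rendered as text.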
    pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
        let text = tokenizer
            .decode(input_token_ids, /* skip_special_token = */ true)
            .expect("Cannot decode token from tokenizer.");
        Self {
            tokenizer,
            stop_re,
            token_ids: input_token_ids.to_owned(),
            prefix_offset: 0,
            read_offset: input_token_ids.len(),
            reversed_text: reverse(text),
        }
    }

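    /// Feeds one newly sampled token id. Returns `Some(text)` with the newly
    /// decoded text (possibly empty while a multi-byte character is still
    /// incomplete), or `None` once a stop word has been matched and decoding
    /// should halt.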
    pub fn next_token(&mut self, token_id: u32) -> Option<String> {
        let skip_special_token = true;
        self.token_ids.push(token_id);

        let prefix_text = self
            .tokenizer
            .decode(
                &self.token_ids[self.prefix_offset..self.read_offset],
                skip_special_token,
            )
            .expect("Cannot decode token from tokenizer.");

        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], skip_special_token)
            .expect("Cannot decode token from tokenizer.");

        // A trailing U+FFFD (the Unicode replacement character) means the
        // newest token ends midway through a multi-byte UTF-8 sequence; hold
        // the text back until the character is complete.
        let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('\u{FFFD}') {
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_ids.len();
            &new_text[prefix_text.len()..]
        } else {
            ""
        };

        if !new_text.is_empty() {
            // Prepend the reversed new text and test every stop regex against
            // the start of the reversed stream.
            self.reversed_text = reverse(new_text) + &self.reversed_text;
            for re in &self.stop_re {
                if re.find(&self.reversed_text).is_some() {
                    return None;
                }
            }
        }

        Some(new_text.to_owned())
    }
}

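// --- Usage sketch (illustrative only) ---
// A minimal sketch of the intended call flow, assuming a local
// `tokenizer.json`, the `once_cell` crate, and made-up stop words and token
// ids; none of these names come from this file.
//
// use once_cell::sync::Lazy;
//
// static STATIC_STOP_WORDS: Lazy<Vec<&'static str>> = Lazy::new(|| vec!["\n\n\n"]);
//
// fn stream(factory: &DecodingFactory, sampled: impl IntoIterator<Item = u32>) {
//     let tokenizer = Arc::new(Tokenizer::from_file("tokenizer.json").unwrap());
//     let mut decoding = factory.create(
//         tokenizer,
//         &[1, 2, 3],                    // prompt token ids (made up)
//         &["<|endoftext|>".to_owned()], // per-request stop words (made up)
//         &STATIC_STOP_WORDS,
//     );
//     for token_id in sampled {
//         match decoding.next_token(token_id) {
//             Some(text) => print!("{text}"),
//             None => break, // a stop word was generated
//         }
//     }
// }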