|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
|
|
|
|
use dashmap::DashMap;
|
|
|
|
|
|
use regex::Regex;
|
|
|
|
|
|
use tokenizers::tokenizer::Tokenizer;
|
|
|
|
|
|
|
|
|
|
|
|
/// Factory for [`IncrementalDecoding`] instances.
///
/// Caches the compiled stop-word regex for each distinct stop-word list so a
/// list is compiled at most once and shared across requests.
pub struct DecodingFactory {
    // Maps a 'static stop-word list to its compiled stop regex (the regex is
    // built over *reversed* stop words; see `create_stop_regex`).
    // NOTE(review): keying on `&'static Vec<&'static str>` works but
    // `&'static [&'static str]` would be the more idiomatic borrow — confirm
    // with the call sites before changing.
    stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Returns the input with its characters in reverse order.
///
/// Accepts anything convertible into a `String` (e.g. `&str` or `String`).
/// Reversal is per `char` (Unicode scalar value), not per byte, so valid
/// UTF-8 input always yields valid UTF-8 output.
fn reverse<T>(s: T) -> String
where
    T: Into<String>,
{
    let text: String = s.into();
    // Pre-size the output: reversing never changes the byte length.
    let mut reversed = String::with_capacity(text.len());
    for ch in text.chars().rev() {
        reversed.push(ch);
    }
    reversed
}
|
|
|
|
|
|
|
|
|
|
|
|
impl Default for DecodingFactory {
|
|
|
|
|
|
fn default() -> Self {
|
|
|
|
|
|
Self {
|
|
|
|
|
|
stop_regex_cache: DashMap::new(),
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl DecodingFactory {
|
2023-10-02 15:39:15 +00:00
|
|
|
|
pub fn create_incremental_decoding(
|
2023-09-29 13:06:47 +00:00
|
|
|
|
&self,
|
|
|
|
|
|
tokenizer: Arc<Tokenizer>,
|
|
|
|
|
|
input_token_ids: &[u32],
|
2023-10-02 15:39:15 +00:00
|
|
|
|
stop_words: &'static Vec<&'static str>,
|
2023-09-29 13:06:47 +00:00
|
|
|
|
) -> IncrementalDecoding {
|
2023-10-02 15:39:15 +00:00
|
|
|
|
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
2023-09-29 13:06:47 +00:00
|
|
|
|
}
|
|
|
|
|
|
|
2023-10-02 15:39:15 +00:00
|
|
|
|
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
2023-09-29 13:06:47 +00:00
|
|
|
|
if stop_words.is_empty() {
|
|
|
|
|
|
None
|
|
|
|
|
|
} else {
|
|
|
|
|
|
let mut re = self.stop_regex_cache.get(stop_words);
|
|
|
|
|
|
if re.is_none() {
|
|
|
|
|
|
self.stop_regex_cache
|
|
|
|
|
|
.insert(stop_words, create_stop_regex(stop_words));
|
|
|
|
|
|
re = self.stop_regex_cache.get(stop_words);
|
|
|
|
|
|
}
|
|
|
|
|
|
re.map(|x| x.value().clone())
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
|
|
|
|
|
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
2023-09-29 13:06:47 +00:00
|
|
|
|
|
|
|
|
|
|
// (?m) enables multi-line matching mode.
|
|
|
|
|
|
// \A means absolute begins of string.
|
|
|
|
|
|
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
|
|
|
|
|
Regex::new(®ex_string).unwrap()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Streaming detokenizer: turns one generated token id at a time into the
/// newly produced text, and detects stop-word matches as text accumulates.
pub struct IncrementalDecoding {
    tokenizer: Arc<Tokenizer>,
    // Compiled stop regex over the reversed output; `None` disables stop
    // detection.
    stop_re: Option<Regex>,

    // All token ids seen so far (prompt followed by generated tokens).
    token_ids: Vec<u32>,
    // Start of the decode window within `token_ids`: tokens before this are
    // already emitted and stable.
    prefix_offset: usize,
    // End of the finalized region; tokens at/after this index may still be an
    // incomplete character sequence.
    read_offset: usize,

    // Accumulated text stored *reversed*, so the `\A`-anchored stop regex
    // always inspects the newest characters first.
    reversed_text: String,
}
|
|
|
|
|
|
|
|
|
|
|
|
impl IncrementalDecoding {
|
2023-10-02 15:39:15 +00:00
|
|
|
|
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
2023-09-29 13:06:47 +00:00
|
|
|
|
let text = tokenizer
|
|
|
|
|
|
.decode(input_token_ids, /* skip_special_token = */ true)
|
|
|
|
|
|
.expect("Cannot decode token from tokenizer.");
|
|
|
|
|
|
Self {
|
|
|
|
|
|
tokenizer,
|
|
|
|
|
|
stop_re,
|
|
|
|
|
|
token_ids: input_token_ids.to_owned(),
|
|
|
|
|
|
prefix_offset: 0,
|
|
|
|
|
|
read_offset: input_token_ids.len(),
|
|
|
|
|
|
reversed_text: reverse(text),
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub fn next_token(&mut self, token_id: u32) -> Option<String> {
|
|
|
|
|
|
let skip_special_token = true;
|
|
|
|
|
|
self.token_ids.push(token_id);
|
|
|
|
|
|
|
|
|
|
|
|
let prefix_text = self
|
|
|
|
|
|
.tokenizer
|
|
|
|
|
|
.decode(
|
|
|
|
|
|
&self.token_ids[self.prefix_offset..self.read_offset],
|
|
|
|
|
|
skip_special_token,
|
|
|
|
|
|
)
|
|
|
|
|
|
.expect("Cannot decode token from tokenizer.");
|
|
|
|
|
|
|
|
|
|
|
|
let new_text = self
|
|
|
|
|
|
.tokenizer
|
|
|
|
|
|
.decode(&self.token_ids[self.prefix_offset..], skip_special_token)
|
|
|
|
|
|
.expect("Cannot decode token from tokenizer.");
|
|
|
|
|
|
|
|
|
|
|
|
let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('<27>') {
|
|
|
|
|
|
self.prefix_offset = self.read_offset;
|
|
|
|
|
|
self.read_offset = self.token_ids.len();
|
|
|
|
|
|
&new_text[prefix_text.len()..]
|
|
|
|
|
|
} else {
|
|
|
|
|
|
""
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
if !new_text.is_empty() {
|
|
|
|
|
|
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
2023-10-02 15:39:15 +00:00
|
|
|
|
|
|
|
|
|
|
if let Some(re) = &self.stop_re {
|
2023-09-29 13:06:47 +00:00
|
|
|
|
if re.find(&self.reversed_text).is_some() {
|
|
|
|
|
|
return None;
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Some(new_text.to_owned())
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|