use std::sync::Arc; use dashmap::DashMap; use regex::Regex; use tokenizers::tokenizer::Tokenizer; pub struct DecodingFactory { stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, } fn reverse(s: T) -> String where T: Into, { s.into().chars().rev().collect() } impl Default for DecodingFactory { fn default() -> Self { Self { stop_regex_cache: DashMap::new(), } } } impl DecodingFactory { pub fn create( &self, tokenizer: Arc, input_token_ids: &[u32], stop_words: &Vec, static_stop_words: &'static Vec<&'static str>, ) -> IncrementalDecoding { IncrementalDecoding::new( tokenizer, vec![ self.get_static_re(static_stop_words), self.get_re(stop_words), ] .into_iter() .flatten() .collect(), input_token_ids, ) } fn get_re(&self, stop_words: &Vec) -> Option { if !stop_words.is_empty() { Some(create_stop_regex(stop_words)) } else { None } } fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option { if stop_words.is_empty() { None } else { let mut re = self.stop_regex_cache.get(stop_words); if re.is_none() { self.stop_regex_cache .insert(stop_words, create_stop_regex(stop_words)); re = self.stop_regex_cache.get(stop_words); } re.map(|x| x.value().clone()) } } } fn create_stop_regex>(stop_words: &[T]) -> Regex { let tokens: Vec = stop_words.iter().map(|x| reverse(x.as_ref())).collect(); // (?m) enables multi-line matching mode. // \A means absolute begins of string. let regex_string = r"(?m)\A".to_owned() + &tokens.join("|"); Regex::new(®ex_string).unwrap() } pub struct IncrementalDecoding { tokenizer: Arc, stop_re: Vec, token_ids: Vec, prefix_offset: usize, read_offset: usize, reversed_text: String, } impl IncrementalDecoding { pub fn new(tokenizer: Arc, stop_re: Vec, input_token_ids: &[u32]) -> Self { let text = tokenizer .decode(input_token_ids, /* skip_special_token = */ true) .expect("Cannot decode token from tokenizer."); Self { tokenizer, stop_re, token_ids: input_token_ids.to_owned(), prefix_offset: 0, read_offset: input_token_ids.len(), reversed_text: reverse(text), } } pub fn next_token(&mut self, token_id: u32) -> Option { let skip_special_token = true; self.token_ids.push(token_id); let prefix_text = self .tokenizer .decode( &self.token_ids[self.prefix_offset..self.read_offset], skip_special_token, ) .expect("Cannot decode token from tokenizer."); let new_text = self .tokenizer .decode(&self.token_ids[self.prefix_offset..], skip_special_token) .expect("Cannot decode token from tokenizer."); let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('�') { self.prefix_offset = self.read_offset; self.read_offset = self.token_ids.len(); &new_text[prefix_text.len()..] } else { "" }; if !new_text.is_empty() { self.reversed_text = reverse(new_text) + &self.reversed_text; for re in &self.stop_re { if re.find(&self.reversed_text).is_some() { return None; } } } Some(new_text.to_owned()) } }