2023-09-29 13:06:47 +00:00
|
|
|
use dashmap::DashMap;
|
2023-10-06 09:04:37 +00:00
|
|
|
use regex::Regex;
|
2023-10-16 00:24:44 +00:00
|
|
|
use tabby_common::languages::Language;
|
2023-09-29 13:06:47 +00:00
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
pub struct StopConditionFactory {
|
2023-10-16 00:24:44 +00:00
|
|
|
stop_regex_cache: DashMap<String, Regex>,
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Returns a new `String` containing the characters of `s` in reverse order.
///
/// Reversal is done per Unicode scalar value (`char`), so multi-byte UTF-8
/// characters stay intact; multi-`char` grapheme clusters (e.g. a letter plus
/// a combining accent) are not kept together.
fn reverse<T: Into<String>>(s: T) -> String {
    let source: String = s.into();
    let mut reversed = String::with_capacity(source.len());
    reversed.extend(source.chars().rev());
    reversed
}
|
|
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
impl Default for StopConditionFactory {
|
2023-09-29 13:06:47 +00:00
|
|
|
fn default() -> Self {
|
|
|
|
|
Self {
|
|
|
|
|
stop_regex_cache: DashMap::new(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
impl StopConditionFactory {
|
2023-11-28 08:57:16 +00:00
|
|
|
pub fn create(
|
|
|
|
|
&self,
|
|
|
|
|
text: &str,
|
|
|
|
|
max_decoding_length: usize,
|
|
|
|
|
language: Option<&'static Language>,
|
|
|
|
|
) -> StopCondition {
|
2023-11-26 03:17:31 +00:00
|
|
|
if let Some(language) = language {
|
2023-11-28 08:57:16 +00:00
|
|
|
StopCondition::new(self.get_re(language), max_decoding_length, text)
|
2023-11-26 03:17:31 +00:00
|
|
|
} else {
|
2023-11-28 08:57:16 +00:00
|
|
|
StopCondition::new(None, max_decoding_length, text)
|
2023-11-26 03:17:31 +00:00
|
|
|
}
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
|
2023-10-16 00:24:44 +00:00
|
|
|
fn get_re(&self, language: &'static Language) -> Option<Regex> {
|
|
|
|
|
let stop_words = language.get_stop_words();
|
2023-09-29 13:06:47 +00:00
|
|
|
if stop_words.is_empty() {
|
|
|
|
|
None
|
|
|
|
|
} else {
|
2023-10-16 00:24:44 +00:00
|
|
|
let hashkey = language.get_hashkey();
|
|
|
|
|
let mut re = self.stop_regex_cache.get(&hashkey);
|
2023-09-29 13:06:47 +00:00
|
|
|
if re.is_none() {
|
|
|
|
|
self.stop_regex_cache
|
2023-10-16 00:24:44 +00:00
|
|
|
.insert(hashkey.clone(), create_stop_regex(stop_words));
|
|
|
|
|
re = self.stop_regex_cache.get(&hashkey);
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
re.map(|x| x.value().clone())
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-11-28 08:57:16 +00:00
|
|
|
|
|
|
|
|
pub fn trim_stop_words(&self, language: &'static Language, text: &str) -> Option<String> {
|
|
|
|
|
let Some(re) = self.get_re(language) else {
|
|
|
|
|
return None;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let text = reverse(text);
|
|
|
|
|
|
|
|
|
|
let text = if let Some(m) = re.find_at(&text, 0) {
|
|
|
|
|
&text[m.end()..]
|
|
|
|
|
} else {
|
|
|
|
|
&text
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Some(reverse(text))
|
|
|
|
|
}
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
|
2023-10-16 00:24:44 +00:00
|
|
|
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
|
2023-09-29 13:06:47 +00:00
|
|
|
// (?m) enables multi-line matching mode.
|
|
|
|
|
// \A means absolute begins of string.
|
2023-10-06 09:04:37 +00:00
|
|
|
let reversed_stop_words: Vec<_> = stop_words
|
2023-10-02 23:21:51 +00:00
|
|
|
.iter()
|
2023-10-16 00:24:44 +00:00
|
|
|
.map(|x| regex::escape(&reverse(x)))
|
2023-10-02 23:21:51 +00:00
|
|
|
.collect();
|
2023-10-06 09:04:37 +00:00
|
|
|
let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
|
|
|
|
|
Regex::new(®ex_string).expect("Failed to create regex")
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
pub struct StopCondition {
|
2023-10-06 09:04:37 +00:00
|
|
|
stop_re: Option<Regex>,
|
2023-11-28 08:57:16 +00:00
|
|
|
max_decoding_length: usize,
|
2023-09-29 13:06:47 +00:00
|
|
|
reversed_text: String,
|
2023-11-28 08:57:16 +00:00
|
|
|
num_decoded: usize,
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
impl StopCondition {
|
2023-11-28 08:57:16 +00:00
|
|
|
pub fn new(stop_re: Option<Regex>, max_decoding_length: usize, text: &str) -> Self {
|
2023-09-29 13:06:47 +00:00
|
|
|
Self {
|
|
|
|
|
stop_re,
|
2023-11-28 08:57:16 +00:00
|
|
|
max_decoding_length,
|
2023-09-29 13:06:47 +00:00
|
|
|
reversed_text: reverse(text),
|
2023-11-28 08:57:16 +00:00
|
|
|
num_decoded: 0,
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-10-31 22:16:09 +00:00
|
|
|
pub fn should_stop(&mut self, new_text: &str) -> bool {
|
2023-09-29 13:06:47 +00:00
|
|
|
if !new_text.is_empty() {
|
|
|
|
|
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
2023-10-02 15:39:15 +00:00
|
|
|
|
|
|
|
|
if let Some(re) = &self.stop_re {
|
2023-10-02 23:21:51 +00:00
|
|
|
if re.is_match(&self.reversed_text) {
|
2023-10-31 22:16:09 +00:00
|
|
|
return true;
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-28 08:57:16 +00:00
|
|
|
self.num_decoded += 1;
|
|
|
|
|
self.num_decoded >= self.max_decoding_length
|
2023-09-29 13:06:47 +00:00
|
|
|
}
|
|
|
|
|
}
|
2023-10-02 23:21:51 +00:00
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_it_works() {
        // Stop regexes are matched against reversed text, so reverse the sample once.
        let reversed = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");

        // The sample ends with "\n\nvoid", not a blank line, so blank-line stop words alone
        // must not match.
        let blank_line_words = create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]);
        assert!(!blank_line_words.is_match(&reversed));

        // Adding "\nvoid" makes the regex match, since the sample ends with it.
        let with_void = create_stop_regex(vec![
            "\n\n".to_owned(),
            "\n\n ".to_owned(),
            "\nvoid".to_owned(),
        ]);
        assert!(with_void.is_match(&reversed));
    }
}
|