diff --git a/Cargo.lock b/Cargo.lock
index 051d658..8ec5e7a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2504,7 +2504,6 @@ dependencies = [
  "hyper",
  "lazy_static",
  "mime_guess",
- "regex",
  "rust-embed",
  "serde",
  "serde_json",
diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index 5202a65..3dd9613 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -69,20 +69,18 @@ pub struct TextInferenceOptions {
     #[builder(default = "1.0")]
     sampling_temperature: f32,
 
-    #[builder(default = "vec!()")]
-    stop_words: Vec<String>
+    stop_words: &'static Vec<&'static str>
 }
 
 pub struct InferenceContext {
-    stop_regexp: Regex,
+    stop_re: Option<Regex>,
     cancel: CancellationToken,
     output_text: String
 }
 
 impl InferenceContext {
-    fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self {
-        let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap();
-        InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
+    fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
+        InferenceContext { stop_re, cancel, output_text: "".to_owned() }
     }
 }
 
@@ -114,7 +112,16 @@ impl TextInferenceEngine {
         let cancel_for_inference = cancel.clone();
         let _guard = cancel.drop_guard();
 
-        let context = InferenceContext::new(options.stop_words, cancel_for_inference);
+        let stop_re = if options.stop_words.is_empty() {
+            None
+        } else {
+            let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
+            let stop_tokens : Vec<String> = encodings.iter().map(|x| x.get_tokens().join("")).collect();
+            let regex_string = r"(?m)".to_owned() + &stop_tokens.join("|");
+            Some(Regex::new(&regex_string).unwrap())
+        };
+
+        let context = InferenceContext::new(stop_re, cancel_for_inference);
         let output_tokens = tokio::task::spawn_blocking(move || {
             let context = Box::new(context);
             engine.inference(
@@ -137,7 +144,13 @@ impl TextInferenceEngine {
                 }
             })
             .collect();
-        self.tokenizer.decode(output_ids, true).unwrap()
+        let output_text = self.tokenizer.decode(output_ids, true).unwrap();
+        for stop_word in options.stop_words {
+            if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
+                return stripped_text.to_string();
+            }
+        }
+        output_text
     }
 }
 
@@ -145,9 +158,9 @@ fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u
     if context.cancel.is_cancelled() {
         true
     } else {
-        context.output_text.push_str(&token);
-        if let Some(_) = context.stop_regexp.find(&context.output_text) {
-            true
+        if let Some(re) = &context.stop_re {
+            context.output_text.push_str(&token);
+            re.find(&context.output_text).is_some()
         } else {
             false
         }
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index 19b2abf..9714b15 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -19,7 +19,6 @@ serdeconv = { workspace = true }
 serde_json = "1.0"
 tower-http = { version = "0.4.0", features = ["cors"] }
 clap = { version = "4.3.0", features = ["derive"] }
-regex = "1.8.3"
 lazy_static = { workspace = true }
 rust-embed = "6.6.1"
 mime_guess = "2.0.4"
diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs
index 49af6d0..8edaf87 100644
--- a/crates/tabby/src/serve/completions.rs
+++ b/crates/tabby/src/serve/completions.rs
@@ -9,6 +9,8 @@ use strfmt::{strfmt, strfmt_builder};
 use tabby_common::{events, path::ModelDir};
 use utoipa::ToSchema;
 
+use self::languages::get_stop_words;
+
 mod languages;
 
 #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@@ -57,9 +59,11 @@ pub async fn completion(
     State(state): State<Arc<CompletionState>>,
     Json(request): Json<CompletionRequest>,
 ) -> Json<CompletionResponse> {
+    let language = request.language.unwrap_or("unknown".into());
     let options = TextInferenceOptionsBuilder::default()
-        .max_decoding_length(64)
-        .sampling_temperature(0.2)
+        .max_decoding_length(128)
+        .sampling_temperature(0.1)
+        .stop_words(get_stop_words(&language))
         .build()
         .expect("Invalid TextInferenceOptions");
 
@@ -80,30 +84,27 @@ pub async fn completion(
         request.prompt.expect("No prompt is set")
     };
 
+    let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
     let text = state.engine.inference(&prompt, options).await;
-    let language = request.language.unwrap_or("unknown".into());
-    let filtered_text = languages::remove_stop_words(&language, &text);
-
-    let response = CompletionResponse {
-        id: format!("cmpl-{}", uuid::Uuid::new_v4()),
-        choices: vec![Choice {
-            index: 0,
-            text: filtered_text.to_string(),
-        }],
-    };
 
     events::Event::Completion {
-        completion_id: &response.id,
+        completion_id: &completion_id,
         language: &language,
         prompt: &prompt,
         choices: vec![events::Choice {
             index: 0,
-            text: filtered_text,
+            text: &text,
         }],
     }
     .log();
 
-    Json(response)
+    Json(CompletionResponse {
+        id: completion_id,
+        choices: vec![Choice {
+            index: 0,
+            text,
+        }],
+    })
 }
 
 pub struct CompletionState {
diff --git a/crates/tabby/src/serve/completions/languages.rs b/crates/tabby/src/serve/completions/languages.rs
index 004f964..a0f4ef3 100644
--- a/crates/tabby/src/serve/completions/languages.rs
+++ b/crates/tabby/src/serve/completions/languages.rs
@@ -1,26 +1,19 @@
 use std::collections::HashMap;
 
 use lazy_static::lazy_static;
-use regex::Regex;
 
 lazy_static! {
-    static ref DEFAULT: Regex = Regex::new(r"(?m)\n\n").unwrap();
-    static ref LANGUAGES: HashMap<&'static str, Regex> = {
+    static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
+    static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
         let mut map = HashMap::new();
         map.insert(
             "python",
-            Regex::new(r"(?m)(\n\n|^def|^#|^from|^class)").unwrap(),
+            vec!("\n\n", "\ndef", "\n#", "\nfrom", "\nclass")
         );
         map
     };
 }
 
-pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str {
-    let re = LANGUAGES.get(language).unwrap_or(&DEFAULT);
-    let position = re.find_iter(text).next();
-    if let Some(m) = position {
-        &text[..m.start()]
-    } else {
-        text
-    }
+pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
+    LANGUAGES.get(language).unwrap_or(&DEFAULT)
 }