diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index 19dcf5f..87be81b 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -63,6 +63,8 @@ pub struct CTranslate2EngineOptions {
     num_replicas_per_device: usize,
 
     compute_type: String,
+
+    stop_words_encoding_offset: Option<usize>,
 }
 
 pub struct InferenceContext {
@@ -85,6 +87,7 @@ pub struct CTranslate2Engine {
     engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
     tokenizer: Tokenizer,
     stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
+    stop_words_encoding_offset: Option<usize>,
 }
 
 impl CTranslate2Engine {
@@ -102,6 +105,7 @@ impl CTranslate2Engine {
             engine,
             stop_regex_cache: DashMap::new(),
             tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
+            stop_words_encoding_offset: options.stop_words_encoding_offset,
         };
     }
 }
@@ -123,7 +127,11 @@ impl TextGeneration for CTranslate2Engine {
         if re.is_none() {
             self.stop_regex_cache.insert(
                 options.stop_words,
-                create_stop_regex(&self.tokenizer, options.stop_words),
+                create_stop_regex(
+                    &self.tokenizer,
+                    options.stop_words,
+                    self.stop_words_encoding_offset,
+                ),
             );
             re = self.stop_regex_cache.get(options.stop_words);
         }
@@ -156,7 +164,7 @@ fn inference_callback(
     if context.cancel.is_cancelled() {
         true
     } else if let Some(re) = &context.stop_re {
-        let mut new_token = reverse(token);
+        let mut new_token = reverse(&token);
         new_token.push_str(&context.reversed_output_text);
         context.reversed_output_text = new_token;
         re.find(&context.reversed_output_text).is_some()
@@ -165,19 +173,37 @@ fn inference_callback(
     }
 }
 
-fn reverse(s: String) -> String {
-    s.chars().rev().collect()
+fn reverse(s: &String) -> String {
+    // Special treatment for byte fallback token.
+    // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
+    if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
+        // Keep byte fallback tokens like <0x0A> as is, do not reverse it.
+        // This won't really affect stop words regex logic, but brings more readability when
+        // debugging decoding steps.
+        s.to_owned()
+    } else {
+        s.chars().rev().collect()
+    }
 }
 
-fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex {
+fn create_stop_regex(
+    tokenizer: &Tokenizer,
+    stop_words: &[&str],
+    stop_words_encoding_offset: Option<usize>,
+) -> Regex {
     let encodings = tokenizer
         .encode_batch(stop_words.to_owned(), false)
         .unwrap();
     let stop_tokens: Vec<String> = encodings
         .iter()
-        .map(|x| x.get_tokens().join(""))
-        // Reverse for efficient suffix matching.
-        .map(reverse)
+        .map(|x| {
+            x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..]
+                .iter()
+                .rev()
+                .map(reverse)
+                .collect::<Vec<String>>()
+                .join("")
+        })
         .collect();
 
     // (?m) enables multi-line matching mode.
diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs
index 5d682ff..62014d0 100644
--- a/crates/tabby/src/serve/completions.rs
+++ b/crates/tabby/src/serve/completions.rs
@@ -140,6 +140,7 @@ impl CompletionState {
             .device_indices(args.device_indices.clone())
             .num_replicas_per_device(args.num_replicas_per_device)
             .compute_type(compute_type)
+            .stop_words_encoding_offset(metadata.stop_words_encoding_offset)
             .build()
             .unwrap();
         let engine = CTranslate2Engine::create(options);
@@ -165,6 +166,7 @@ fn get_model_dir(model: &str) -> ModelDir {
 struct Metadata {
     auto_model: String,
     prompt_template: Option<String>,
+    stop_words_encoding_offset: Option<usize>,
 }
 
 fn read_metadata(model_dir: &ModelDir) -> Metadata {
diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs
index 7cbd537..ba9fe94 100644
--- a/crates/tabby/src/serve/completions/prompt.rs
+++ b/crates/tabby/src/serve/completions/prompt.rs
@@ -52,7 +52,7 @@ impl PromptBuilder {
         if let Some(suffix) = segments.suffix {
             self.build_prompt(segments.prefix, suffix)
         } else {
-            self.build_prompt(segments.prefix, "".to_owned())
+            self.build_prompt(segments.prefix, "\n".to_owned())
         }
     }
 