From fd1baff8d5f51d18eec6ba71bb49da811265b874 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 6 Jun 2023 16:28:58 -0700 Subject: [PATCH] feat: support stop sequences [TAB-52] (#212) * refactor: pass step and string token to callback * add token to callback * add stop regexp * implement stop words logic * pass token_ids from inference * improve effiency of regexp match with reversed regex * fmt * add typescript and javascript stop words * add cache for stop words regexp --- Cargo.lock | 20 +++- crates/ctranslate2-bindings/Cargo.toml | 2 + .../include/ctranslate2.h | 6 +- .../ctranslate2-bindings/src/ctranslate2.cc | 49 +++++---- crates/ctranslate2-bindings/src/lib.rs | 101 +++++++++++++++--- crates/tabby/Cargo.toml | 1 - crates/tabby/src/serve/completions.rs | 28 +++-- .../tabby/src/serve/completions/languages.rs | 32 +++--- 8 files changed, 169 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ebd2492..149a2d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -578,7 +578,9 @@ dependencies = [ "cmake", "cxx", "cxx-build", + "dashmap", "derive_builder", + "regex", "rust-cxx-cmake-bridge", "tokenizers", "tokio", @@ -664,6 +666,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "derive_builder" version = "0.12.0" @@ -2022,9 +2037,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.3" +version = "1.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" dependencies = [ "aho-corasick 1.0.1", "memchr", @@ -2503,7 +2518,6 @@ dependencies = [ "hyper", "lazy_static", "mime_guess", - "regex", "rust-embed", "serde", "serde_json", diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 046e303..b5a45fb 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -5,7 +5,9 @@ edition = "2021" [dependencies] cxx = "1.0" +dashmap = "5.4.0" derive_builder = "0.12.0" +regex = "1.8.4" tokenizers = "0.13.3" tokio = { workspace = true, features = ["rt"] } tokio-util = { workspace = true } diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index b1f67db..aa67db2 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -7,12 +7,14 @@ namespace tabby { struct InferenceContext; +typedef rust::Fn InferenceCallback; + class TextInferenceEngine { public: virtual ~TextInferenceEngine(); - virtual rust::Vec inference( + virtual rust::Vec inference( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, rust::Slice tokens, size_t max_decoding_length, float sampling_temperature diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index beb289c..3b71b19 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -15,27 +15,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine { }; public: - rust::Vec inference( + rust::Vec inference( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, rust::Slice tokens, size_t max_decoding_length, float sampling_temperature ) const { // Inference. std::vector input_tokens(tokens.begin(), tokens.end()); - const auto output_tokens = process( + return process( std::move(context), - std::move(is_context_cancelled), + std::move(callback), input_tokens, Options{max_decoding_length, sampling_temperature} ); - - // Convert to rust vec. - rust::Vec output; - output.reserve(output_tokens.size()); - std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output)); - return output; } static std::unique_ptr create(const ctranslate2::models::ModelLoader& loader) { @@ -45,9 +39,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine { } protected: - virtual std::vector process( + virtual rust::Vec process( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, const std::vector& tokens, const Options& options) const = 0; std::unique_ptr model_; @@ -55,28 +49,35 @@ class TextInferenceEngineImpl : public TextInferenceEngine { class EncoderDecoderImpl : public TextInferenceEngineImpl { protected: - virtual std::vector process( + virtual rust::Vec process( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, const std::vector& tokens, const Options& options) const override { ctranslate2::TranslationOptions x; x.max_decoding_length = options.max_decoding_length; x.sampling_temperature = options.sampling_temperature; x.beam_size = 1; + rust::Vec output_ids; x.callback = [&](ctranslate2::GenerationStepResult result) { - return is_context_cancelled(*context); + bool stop = callback(*context, result.step, result.token_id, result.token); + if (!stop) { + output_ids.push_back(result.token_id); + } else if (result.is_last) { + output_ids.push_back(result.token_id); + } + return stop; }; ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; - return std::move(result.output()); + return output_ids; } }; class DecoderImpl : public TextInferenceEngineImpl { protected: - virtual std::vector process( + virtual rust::Vec process( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, const std::vector& tokens, const Options& options) const override { ctranslate2::GenerationOptions x; @@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl output_ids; x.callback = [&](ctranslate2::GenerationStepResult result) { - return is_context_cancelled(*context); + bool stop = callback(*context, result.step, result.token_id, result.token); + if (!stop) { + output_ids.push_back(result.token_id); + } else if (result.is_last) { + output_ids.push_back(result.token_id); + } + return stop; }; ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); - return std::move(result.sequences[0]); + return output_ids; } }; diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 2861972..62ed19b 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,3 +1,5 @@ +use dashmap::DashMap; +use regex::Regex; use tokenizers::tokenizer::Tokenizer; use tokio_util::sync::CancellationToken; @@ -26,11 +28,19 @@ mod ffi { fn inference( &self, context: Box, - is_context_cancelled: fn(&InferenceContext) -> bool, + callback: fn( + &mut InferenceContext, + // step + usize, + // token_id + u32, + // token + String, + ) -> bool, tokens: &[String], max_decoding_length: usize, sampling_temperature: f32, - ) -> Vec; + ) -> Vec; } } @@ -59,13 +69,30 @@ pub struct TextInferenceOptions { #[builder(default = "1.0")] sampling_temperature: f32, + + stop_words: &'static Vec<&'static str>, } -struct InferenceContext(CancellationToken); +pub struct InferenceContext { + stop_re: Option, + cancel: CancellationToken, + reversed_output_text: String, +} + +impl InferenceContext { + fn new(stop_re: Option, cancel: CancellationToken) -> Self { + InferenceContext { + stop_re, + cancel, + reversed_output_text: "".to_owned(), + } + } +} pub struct TextInferenceEngine { engine: cxx::SharedPtr, tokenizer: Tokenizer, + stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, } impl TextInferenceEngine { @@ -79,6 +106,7 @@ impl TextInferenceEngine { ); return TextInferenceEngine { engine, + stop_regex_cache: DashMap::new(), tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), }; } @@ -91,12 +119,26 @@ impl TextInferenceEngine { let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - let context = InferenceContext(cancel_for_inference); - let output_tokens = tokio::task::spawn_blocking(move || { + let stop_re: Option = if options.stop_words.is_empty() { + None + } else { + let mut re = self.stop_regex_cache.get(options.stop_words); + if re.is_none() { + self.stop_regex_cache.insert( + options.stop_words, + create_stop_regex(&self.tokenizer, options.stop_words), + ); + re = self.stop_regex_cache.get(options.stop_words); + } + re.map(|x| x.value().clone()) + }; + + let context = InferenceContext::new(stop_re, cancel_for_inference); + let output_ids = tokio::task::spawn_blocking(move || { let context = Box::new(context); engine.inference( context, - |context| context.0.is_cancelled(), + inference_callback, encoding.get_tokens(), options.max_decoding_length, options.sampling_temperature, @@ -104,16 +146,43 @@ impl TextInferenceEngine { }) .await .expect("Inference failed"); - let output_ids: Vec = output_tokens - .iter() - .filter_map(|x| match self.tokenizer.token_to_id(x) { - Some(y) => Some(y), - None => { - println!("Warning: token ({}) missed in vocab", x); - None - } - }) - .collect(); self.tokenizer.decode(output_ids, true).unwrap() } } + +fn inference_callback( + context: &mut InferenceContext, + _step: usize, + _token_id: u32, + token: String, +) -> bool { + if context.cancel.is_cancelled() { + true + } else if let Some(re) = &context.stop_re { + let mut new_token = reverse(token); + new_token.push_str(&context.reversed_output_text); + context.reversed_output_text = new_token; + re.find(&context.reversed_output_text).is_some() + } else { + false + } +} + +fn reverse(s: String) -> String { + s.chars().rev().collect() +} + +fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex { + let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap(); + let stop_tokens: Vec = encodings + .iter() + .map(|x| x.get_tokens().join("")) + // Reverse for efficient suffix matching. + .map(reverse) + .collect(); + + // (?m) enables multi-line matching mode. + // \A means absolute begins of string. + let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); + Regex::new(®ex_string).unwrap() +} diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 19b2abf..9714b15 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -19,7 +19,6 @@ serdeconv = { workspace = true } serde_json = "1.0" tower-http = { version = "0.4.0", features = ["cors"] } clap = { version = "4.3.0", features = ["derive"] } -regex = "1.8.3" lazy_static = { workspace = true } rust-embed = "6.6.1" mime_guess = "2.0.4" diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 49af6d0..0e4c136 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -9,6 +9,8 @@ use strfmt::{strfmt, strfmt_builder}; use tabby_common::{events, path::ModelDir}; use utoipa::ToSchema; +use self::languages::get_stop_words; + mod languages; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -57,9 +59,11 @@ pub async fn completion( State(state): State>, Json(request): Json, ) -> Json { + let language = request.language.unwrap_or("unknown".into()); let options = TextInferenceOptionsBuilder::default() - .max_decoding_length(64) - .sampling_temperature(0.2) + .max_decoding_length(128) + .sampling_temperature(0.1) + .stop_words(get_stop_words(&language)) .build() .expect("Invalid TextInferenceOptions"); @@ -80,30 +84,24 @@ pub async fn completion( request.prompt.expect("No prompt is set") }; + let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let text = state.engine.inference(&prompt, options).await; - let language = request.language.unwrap_or("unknown".into()); - let filtered_text = languages::remove_stop_words(&language, &text); - - let response = CompletionResponse { - id: format!("cmpl-{}", uuid::Uuid::new_v4()), - choices: vec![Choice { - index: 0, - text: filtered_text.to_string(), - }], - }; events::Event::Completion { - completion_id: &response.id, + completion_id: &completion_id, language: &language, prompt: &prompt, choices: vec![events::Choice { index: 0, - text: filtered_text, + text: &text, }], } .log(); - Json(response) + Json(CompletionResponse { + id: completion_id, + choices: vec![Choice { index: 0, text }], + }) } pub struct CompletionState { diff --git a/crates/tabby/src/serve/completions/languages.rs b/crates/tabby/src/serve/completions/languages.rs index 004f964..b9a5713 100644 --- a/crates/tabby/src/serve/completions/languages.rs +++ b/crates/tabby/src/serve/completions/languages.rs @@ -1,26 +1,32 @@ use std::collections::HashMap; use lazy_static::lazy_static; -use regex::Regex; lazy_static! { - static ref DEFAULT: Regex = Regex::new(r"(?m)\n\n").unwrap(); - static ref LANGUAGES: HashMap<&'static str, Regex> = { + static ref DEFAULT: Vec<&'static str> = vec!("\n\n"); + static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = { let mut map = HashMap::new(); + map.insert("python", vec!["\n\n", "\ndef", "\n#", "\nfrom", "\nclass"]); map.insert( - "python", - Regex::new(r"(?m)(\n\n|^def|^#|^from|^class)").unwrap(), + "javascript", + vec!["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"], + ); + map.insert( + "typescript", + vec![ + "\n\n", + "\nfunction", + "\n//", + "\nimport", + "\nclass", + "\ninterface", + "\ntype", + ], ); map }; } -pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str { - let re = LANGUAGES.get(language).unwrap_or(&DEFAULT); - let position = re.find_iter(text).next(); - if let Some(m) = position { - &text[..m.start()] - } else { - text - } +pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> { + LANGUAGES.get(language).unwrap_or(&DEFAULT) }