From 486e507079bf79d190dbb358f099e6e1736565ba Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 29 Sep 2023 06:06:47 -0700 Subject: [PATCH] fix: correct Decoding behavior in incremental manner (#491) * feat: implement IncrementalDecoding * refactor: use IncrementalDecoding for ctranslate2 * refactor: rename StopWords to DecodingFactory * refactor: move decoding logic to tabby-inference * feat: optimize decoding range * cleanup --- Cargo.lock | 14 +-- Cargo.toml | 1 - crates/ctranslate2-bindings/Cargo.toml | 1 - crates/ctranslate2-bindings/src/lib.rs | 41 +++---- crates/llama-cpp-bindings/Cargo.toml | 1 - crates/llama-cpp-bindings/include/engine.h | 4 +- crates/llama-cpp-bindings/src/engine.cc | 11 +- crates/llama-cpp-bindings/src/lib.rs | 60 +++++----- crates/stop-words/Cargo.toml | 11 -- crates/stop-words/src/lib.rs | 80 -------------- crates/tabby-inference/Cargo.toml | 3 + crates/tabby-inference/src/decoding.rs | 123 +++++++++++++++++++++ crates/tabby-inference/src/lib.rs | 2 + 13 files changed, 191 insertions(+), 161 deletions(-) delete mode 100644 crates/stop-words/Cargo.toml delete mode 100644 crates/stop-words/src/lib.rs create mode 100644 crates/tabby-inference/src/decoding.rs diff --git a/Cargo.lock b/Cargo.lock index 5014842..20cc422 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -700,7 +700,6 @@ dependencies = [ "derive_builder", "futures", "rust-cxx-cmake-bridge", - "stop-words", "tabby-inference", "tokenizers", "tokio", @@ -1661,7 +1660,6 @@ dependencies = [ "cxx-build", "derive_builder", "futures", - "stop-words", "tabby-inference", "tokenizers", "tokio", @@ -2940,15 +2938,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "stop-words" -version = "0.1.0" -dependencies = [ - "dashmap", - "regex", - "tokenizers", -] - [[package]] name = "strfmt" version = "0.2.4" @@ -3122,8 +3111,11 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", + "dashmap", "derive_builder", "futures", + "regex", + "tokenizers", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a54eb4b..51e543b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,6 @@ members = [ "crates/ctranslate2-bindings", "crates/rust-cxx-cmake-bridge", "crates/llama-cpp-bindings", - "crates/stop-words", "crates/http-api-bindings", ] diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 9c96506..81d8d7b 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -11,7 +11,6 @@ tokio = { workspace = true, features = ["rt"] } tokio-util = { workspace = true } tabby-inference = { path = "../tabby-inference" } async-trait = { workspace = true } -stop-words = { path = "../stop-words" } futures.workspace = true async-stream.workspace = true diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index afb8b42..25ce843 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -4,8 +4,10 @@ use async_stream::stream; use async_trait::async_trait; use derive_builder::Builder; use futures::stream::BoxStream; -use stop_words::{StopWords, StopWordsCondition}; -use tabby_inference::{helpers, TextGeneration, TextGenerationOptions}; +use tabby_inference::{ + decoding::{DecodingFactory, IncrementalDecoding}, + helpers, TextGeneration, TextGenerationOptions, +}; use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc::{channel, Sender}; use tokio_util::sync::CancellationToken; @@ -70,20 +72,20 @@ pub struct CTranslate2EngineOptions { } pub struct InferenceContext { - sender: Sender, - stop_condition: StopWordsCondition, + sender: Sender, + decoding: IncrementalDecoding, cancel: CancellationToken, } impl InferenceContext { fn new( - sender: Sender, - stop_condition: StopWordsCondition, + sender: Sender, + decoding: IncrementalDecoding, cancel: CancellationToken, ) -> Self { InferenceContext { sender, - stop_condition, + decoding, cancel, } } @@ -91,7 +93,7 @@ impl InferenceContext { pub struct CTranslate2Engine { engine: cxx::SharedPtr, - stop_words: StopWords, + decoding_factory: DecodingFactory, tokenizer: Arc, } @@ -108,7 +110,7 @@ impl CTranslate2Engine { return Self { engine, - stop_words: StopWords::default(), + decoding_factory: DecodingFactory::default(), tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), }; } @@ -133,12 +135,12 @@ impl TextGeneration for CTranslate2Engine { let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - let stop_condition = self - .stop_words - .create_condition(self.tokenizer.clone(), options.stop_words); + let decoding = self + .decoding_factory + .create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words); - let (sender, mut receiver) = channel::(8); - let context = InferenceContext::new(sender, stop_condition, cancel_for_inference); + let (sender, mut receiver) = channel::(8); + let context = InferenceContext::new(sender, decoding, cancel_for_inference); tokio::task::spawn(async move { let context = Box::new(context); engine.inference( @@ -150,8 +152,7 @@ impl TextGeneration for CTranslate2Engine { ); }); - while let Some(next_token_id) = receiver.recv().await { - let text = self.tokenizer.decode(&[next_token_id], true).unwrap(); + while let Some(text) = receiver.recv().await { yield text; } }; @@ -159,7 +160,7 @@ impl TextGeneration for CTranslate2Engine { } } -fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] { +fn truncate_tokens(tokens: &[T], max_length: usize) -> &[T] { if max_length < tokens.len() { let start = tokens.len() - max_length; &tokens[start..] @@ -174,10 +175,12 @@ fn inference_callback( token_id: u32, _token: String, ) -> bool { - let _ = context.sender.blocking_send(token_id); if context.cancel.is_cancelled() { true + } else if let Some(new_text) = context.decoding.next_token(token_id) { + let _ = context.sender.blocking_send(new_text); + false } else { - context.stop_condition.next_token(token_id) + true } } diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml index 65d7d13..8e4d840 100644 --- a/crates/llama-cpp-bindings/Cargo.toml +++ b/crates/llama-cpp-bindings/Cargo.toml @@ -14,7 +14,6 @@ tokio = { workspace = true, features = ["rt"] } tabby-inference = { path = "../tabby-inference" } derive_builder = { workspace = true } tokenizers = { workspace = true } -stop-words = { version = "0.1.0", path = "../stop-words" } tokio-util = { workspace = true } futures.workspace = true async-stream.workspace = true diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h index 0fae02f..1dad36d 100644 --- a/crates/llama-cpp-bindings/include/engine.h +++ b/crates/llama-cpp-bindings/include/engine.h @@ -9,8 +9,8 @@ class TextInferenceEngine { public: virtual ~TextInferenceEngine(); - virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0; - virtual uint32_t step(uint32_t next_token_id) const = 0; + virtual void start(rust::Slice input_token_ids) const = 0; + virtual uint32_t step() const = 0; virtual void end() const = 0; virtual uint32_t eos_token() const = 0; diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index 5c380e9..e433c41 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -45,22 +45,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine { ctx_(std::move(ctx)) { } - uint32_t start(const rust::Str prompt, size_t max_input_length) const override { + void start(rust::Slice input_token_ids) const override { auto* ctx = ctx_.get(); llama_reset_timings(ctx); - std::vector tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ false); + std::vector tokens_list(input_token_ids.begin(), input_token_ids.end()); for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) { const size_t size = std::min(N_BATCH, tokens_list.size() - i); eval(tokens_list.data() + i, size, /* reset = */ i == 0); } - return sample(); } - uint32_t step(uint32_t next_token_id) const override { - const llama_token id = next_token_id; + uint32_t step() const override { + const llama_token id = sample(); eval(const_cast(&id), 1, /* reset = */ false); - return sample(); + return id; } void end() const override { diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index c9aab91..ec42066 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -5,8 +5,7 @@ use async_trait::async_trait; use derive_builder::Builder; use ffi::create_engine; use futures::{lock::Mutex, stream::BoxStream}; -use stop_words::StopWords; -use tabby_inference::{helpers, TextGeneration, TextGenerationOptions}; +use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions}; use tokenizers::tokenizer::Tokenizer; #[cxx::bridge(namespace = "llama")] @@ -18,8 +17,8 @@ mod ffi { fn create_engine(model_path: &str) -> SharedPtr; - fn start(&self, prompt: &str, max_input_length: usize) -> u32; - fn step(&self, next_token_id: u32) -> u32; + fn start(&self, input_token_ids: &[u32]); + fn step(&self) -> u32; fn end(&self); fn eos_token(&self) -> u32; @@ -38,7 +37,7 @@ pub struct LlamaEngineOptions { pub struct LlamaEngine { engine: Mutex>, tokenizer: Arc, - stop_words: StopWords, + decoding_factory: DecodingFactory, } impl LlamaEngine { @@ -46,7 +45,7 @@ impl LlamaEngine { LlamaEngine { engine: Mutex::new(create_engine(&options.model_path)), tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), - stop_words: StopWords::default(), + decoding_factory: DecodingFactory::default(), } } } @@ -63,35 +62,29 @@ impl TextGeneration for LlamaEngine { prompt: &str, options: TextGenerationOptions, ) -> BoxStream { - let prompt = prompt.to_owned(); - let mut stop_condition = self - .stop_words - .create_condition(self.tokenizer.clone(), options.stop_words); + let encoding = self.tokenizer.encode(prompt, true).unwrap(); let s = stream! { let engine = self.engine.lock().await; let eos_token = engine.eos_token(); - let mut next_token_id = engine.start(&prompt, options.max_input_length); - if next_token_id == eos_token { - yield "".to_owned(); - } else { - let mut n_remains = options.max_decoding_length - 1; - - while n_remains > 0 { - next_token_id = engine.step(next_token_id); - if next_token_id == eos_token { - break; - } - - if stop_condition.next_token(next_token_id) { - break; - } - - let text = self.tokenizer.decode(&[next_token_id], true).unwrap(); - yield text; - n_remains -= 1; + let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); + engine.start(input_token_ids); + let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words); + let mut n_remains = options.max_decoding_length ; + while n_remains > 0 { + let next_token_id = engine.step(); + if next_token_id == eos_token { + break; } + + if let Some(new_text) = decoding.next_token(next_token_id) { + yield new_text; + } else { + break; + } + + n_remains -= 1; } engine.end(); @@ -100,3 +93,12 @@ impl TextGeneration for LlamaEngine { Box::pin(s) } } + +fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] { + if max_length < tokens.len() { + let start = tokens.len() - max_length; + &tokens[start..] + } else { + tokens + } +} diff --git a/crates/stop-words/Cargo.toml b/crates/stop-words/Cargo.toml deleted file mode 100644 index fea97ae..0000000 --- a/crates/stop-words/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "stop-words" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -dashmap = "5.5.3" -regex = "1.9.5" -tokenizers.workspace = true diff --git a/crates/stop-words/src/lib.rs b/crates/stop-words/src/lib.rs deleted file mode 100644 index b18a362..0000000 --- a/crates/stop-words/src/lib.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::sync::Arc; - -use dashmap::DashMap; -use regex::Regex; -use tokenizers::tokenizer::Tokenizer; - -pub struct StopWords { - stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, -} - -fn reverse(s: &&str) -> String { - s.chars().rev().collect() -} - -impl Default for StopWords { - fn default() -> Self { - Self { - stop_regex_cache: DashMap::new(), - } - } -} - -impl StopWords { - pub fn create_condition( - &self, - tokenizer: Arc, - stop_words: &'static Vec<&'static str>, - ) -> StopWordsCondition { - let re = if stop_words.is_empty() { - None - } else { - let mut re = self.stop_regex_cache.get(stop_words); - if re.is_none() { - self.stop_regex_cache - .insert(stop_words, create_stop_regex(stop_words)); - re = self.stop_regex_cache.get(stop_words); - } - re.map(|x| x.value().clone()) - }; - - StopWordsCondition::new(tokenizer, re) - } -} - -fn create_stop_regex(stop_words: &[&str]) -> Regex { - let tokens: Vec = stop_words.iter().map(reverse).collect(); - - // (?m) enables multi-line matching mode. - // \A means absolute begins of string. - let regex_string = r"(?m)\A".to_owned() + &tokens.join("|"); - Regex::new(®ex_string).unwrap() -} - -pub struct StopWordsCondition { - tokenizer: Arc, - stop_re: Option, - reversed_output_text: String, -} - -impl StopWordsCondition { - pub fn new(tokenizer: Arc, stop_re: Option) -> Self { - Self { - tokenizer, - stop_re, - reversed_output_text: String::new(), - } - } - - pub fn next_token(&mut self, token_id: u32) -> bool { - if let Some(re) = &self.stop_re { - let token = self.tokenizer.decode(&[token_id], false).unwrap(); - let mut new_token = reverse(&token.as_str()); - new_token.push_str(&self.reversed_output_text); - self.reversed_output_text = new_token; - re.find(&self.reversed_output_text).is_some() - } else { - false - } - } -} diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index fa29afa..9ca1af5 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -8,5 +8,8 @@ edition = "2021" [dependencies] async-stream = { workspace = true } async-trait = { workspace = true } +dashmap = "5.5.3" derive_builder = "0.12.0" futures = { workspace = true } +regex = "1.9.5" +tokenizers.workspace = true diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs new file mode 100644 index 0000000..78cb1a7 --- /dev/null +++ b/crates/tabby-inference/src/decoding.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use regex::Regex; +use tokenizers::tokenizer::Tokenizer; + +pub struct DecodingFactory { + stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, +} + +fn reverse(s: T) -> String +where + T: Into, +{ + s.into().chars().rev().collect() +} + +impl Default for DecodingFactory { + fn default() -> Self { + Self { + stop_regex_cache: DashMap::new(), + } + } +} + +impl DecodingFactory { + pub fn create_incremental_decoding( + &self, + tokenizer: Arc, + input_token_ids: &[u32], + stop_words: &'static Vec<&'static str>, + ) -> IncrementalDecoding { + IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) + } + + fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option { + if stop_words.is_empty() { + None + } else { + let mut re = self.stop_regex_cache.get(stop_words); + if re.is_none() { + self.stop_regex_cache + .insert(stop_words, create_stop_regex(stop_words)); + re = self.stop_regex_cache.get(stop_words); + } + re.map(|x| x.value().clone()) + } + } +} + +fn create_stop_regex(stop_words: &[&str]) -> Regex { + let tokens: Vec = stop_words.iter().map(|x| reverse(*x)).collect(); + + // (?m) enables multi-line matching mode. + // \A means absolute begins of string. + let regex_string = r"(?m)\A".to_owned() + &tokens.join("|"); + Regex::new(®ex_string).unwrap() +} + +pub struct IncrementalDecoding { + tokenizer: Arc, + stop_re: Option, + + token_ids: Vec, + prefix_offset: usize, + read_offset: usize, + + reversed_text: String, +} + +impl IncrementalDecoding { + pub fn new(tokenizer: Arc, stop_re: Option, input_token_ids: &[u32]) -> Self { + let text = tokenizer + .decode(input_token_ids, /* skip_special_token = */ true) + .expect("Cannot decode token from tokenizer."); + Self { + tokenizer, + stop_re, + token_ids: input_token_ids.to_owned(), + prefix_offset: 0, + read_offset: input_token_ids.len(), + reversed_text: reverse(text), + } + } + + pub fn next_token(&mut self, token_id: u32) -> Option { + let skip_special_token = true; + self.token_ids.push(token_id); + + let prefix_text = self + .tokenizer + .decode( + &self.token_ids[self.prefix_offset..self.read_offset], + skip_special_token, + ) + .expect("Cannot decode token from tokenizer."); + + let new_text = self + .tokenizer + .decode(&self.token_ids[self.prefix_offset..], skip_special_token) + .expect("Cannot decode token from tokenizer."); + + let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('�') { + self.prefix_offset = self.read_offset; + self.read_offset = self.token_ids.len(); + &new_text[prefix_text.len()..] + } else { + "" + }; + + if !new_text.is_empty() { + self.reversed_text = reverse(new_text) + &self.reversed_text; + + if let Some(re) = &self.stop_re { + if re.find(&self.reversed_text).is_some() { + return None; + } + } + } + + Some(new_text.to_owned()) + } +} diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 04cad0d..495785e 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -1,3 +1,5 @@ +pub mod decoding; + use async_trait::async_trait; use derive_builder::Builder; use futures::stream::BoxStream;