fix: correct Decoding behavior in incremental manner (#491)

* feat: implement IncrementalDecoding

* refactor: use IncrementalDecoding for ctranslate2

* refactor: rename StopWords to DecodingFactory

* refactor: move decoding logic to tabby-inference

* feat: optimize decoding range

* cleanup
release-0.2
Meng Zhang 2023-09-29 06:06:47 -07:00 committed by GitHub
parent 52c4ef38d3
commit 486e507079
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 191 additions and 161 deletions

14
Cargo.lock generated
View File

@ -700,7 +700,6 @@ dependencies = [
"derive_builder",
"futures",
"rust-cxx-cmake-bridge",
"stop-words",
"tabby-inference",
"tokenizers",
"tokio",
@ -1661,7 +1660,6 @@ dependencies = [
"cxx-build",
"derive_builder",
"futures",
"stop-words",
"tabby-inference",
"tokenizers",
"tokio",
@ -2940,15 +2938,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stop-words"
version = "0.1.0"
dependencies = [
"dashmap",
"regex",
"tokenizers",
]
[[package]]
name = "strfmt"
version = "0.2.4"
@ -3122,8 +3111,11 @@ version = "0.1.0"
dependencies = [
"async-stream",
"async-trait",
"dashmap",
"derive_builder",
"futures",
"regex",
"tokenizers",
]
[[package]]

View File

@ -9,7 +9,6 @@ members = [
"crates/ctranslate2-bindings",
"crates/rust-cxx-cmake-bridge",
"crates/llama-cpp-bindings",
"crates/stop-words",
"crates/http-api-bindings",
]

View File

@ -11,7 +11,6 @@ tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true }
tabby-inference = { path = "../tabby-inference" }
async-trait = { workspace = true }
stop-words = { path = "../stop-words" }
futures.workspace = true
async-stream.workspace = true

View File

@ -4,8 +4,10 @@ use async_stream::stream;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
use stop_words::{StopWords, StopWordsCondition};
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
use tabby_inference::{
decoding::{DecodingFactory, IncrementalDecoding},
helpers, TextGeneration, TextGenerationOptions,
};
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc::{channel, Sender};
use tokio_util::sync::CancellationToken;
@ -70,20 +72,20 @@ pub struct CTranslate2EngineOptions {
}
pub struct InferenceContext {
sender: Sender<u32>,
stop_condition: StopWordsCondition,
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
}
impl InferenceContext {
fn new(
sender: Sender<u32>,
stop_condition: StopWordsCondition,
sender: Sender<String>,
decoding: IncrementalDecoding,
cancel: CancellationToken,
) -> Self {
InferenceContext {
sender,
stop_condition,
decoding,
cancel,
}
}
@ -91,7 +93,7 @@ impl InferenceContext {
pub struct CTranslate2Engine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
stop_words: StopWords,
decoding_factory: DecodingFactory,
tokenizer: Arc<Tokenizer>,
}
@ -108,7 +110,7 @@ impl CTranslate2Engine {
return Self {
engine,
stop_words: StopWords::default(),
decoding_factory: DecodingFactory::default(),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
};
}
@ -133,12 +135,12 @@ impl TextGeneration for CTranslate2Engine {
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let decoding = self
.decoding_factory
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
let (sender, mut receiver) = channel::<u32>(8);
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
tokio::task::spawn(async move {
let context = Box::new(context);
engine.inference(
@ -150,8 +152,7 @@ impl TextGeneration for CTranslate2Engine {
);
});
while let Some(next_token_id) = receiver.recv().await {
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
while let Some(text) = receiver.recv().await {
yield text;
}
};
@ -159,7 +160,7 @@ impl TextGeneration for CTranslate2Engine {
}
}
fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
if max_length < tokens.len() {
let start = tokens.len() - max_length;
&tokens[start..]
@ -174,10 +175,12 @@ fn inference_callback(
token_id: u32,
_token: String,
) -> bool {
let _ = context.sender.blocking_send(token_id);
if context.cancel.is_cancelled() {
true
} else if let Some(new_text) = context.decoding.next_token(token_id) {
let _ = context.sender.blocking_send(new_text);
false
} else {
context.stop_condition.next_token(token_id)
true
}
}

View File

@ -14,7 +14,6 @@ tokio = { workspace = true, features = ["rt"] }
tabby-inference = { path = "../tabby-inference" }
derive_builder = { workspace = true }
tokenizers = { workspace = true }
stop-words = { version = "0.1.0", path = "../stop-words" }
tokio-util = { workspace = true }
futures.workspace = true
async-stream.workspace = true

View File

@ -9,8 +9,8 @@ class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0;
virtual uint32_t step(uint32_t next_token_id) const = 0;
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
virtual uint32_t step() const = 0;
virtual void end() const = 0;
virtual uint32_t eos_token() const = 0;

View File

@ -45,22 +45,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
ctx_(std::move(ctx)) {
}
uint32_t start(const rust::Str prompt, size_t max_input_length) const override {
void start(rust::Slice<const uint32_t> input_token_ids) const override {
auto* ctx = ctx_.get();
llama_reset_timings(ctx);
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ false);
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
}
return sample();
}
uint32_t step(uint32_t next_token_id) const override {
const llama_token id = next_token_id;
uint32_t step() const override {
const llama_token id = sample();
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
return sample();
return id;
}
void end() const override {

View File

@ -5,8 +5,7 @@ use async_trait::async_trait;
use derive_builder::Builder;
use ffi::create_engine;
use futures::{lock::Mutex, stream::BoxStream};
use stop_words::StopWords;
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer;
#[cxx::bridge(namespace = "llama")]
@ -18,8 +17,8 @@ mod ffi {
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
fn start(&self, prompt: &str, max_input_length: usize) -> u32;
fn step(&self, next_token_id: u32) -> u32;
fn start(&self, input_token_ids: &[u32]);
fn step(&self) -> u32;
fn end(&self);
fn eos_token(&self) -> u32;
@ -38,7 +37,7 @@ pub struct LlamaEngineOptions {
pub struct LlamaEngine {
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>,
stop_words: StopWords,
decoding_factory: DecodingFactory,
}
impl LlamaEngine {
@ -46,7 +45,7 @@ impl LlamaEngine {
LlamaEngine {
engine: Mutex::new(create_engine(&options.model_path)),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
stop_words: StopWords::default(),
decoding_factory: DecodingFactory::default(),
}
}
}
@ -63,35 +62,29 @@ impl TextGeneration for LlamaEngine {
prompt: &str,
options: TextGenerationOptions,
) -> BoxStream<String> {
let prompt = prompt.to_owned();
let mut stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let s = stream! {
let engine = self.engine.lock().await;
let eos_token = engine.eos_token();
let mut next_token_id = engine.start(&prompt, options.max_input_length);
if next_token_id == eos_token {
yield "".to_owned();
} else {
let mut n_remains = options.max_decoding_length - 1;
while n_remains > 0 {
next_token_id = engine.step(next_token_id);
if next_token_id == eos_token {
break;
}
if stop_condition.next_token(next_token_id) {
break;
}
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
yield text;
n_remains -= 1;
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let next_token_id = engine.step();
if next_token_id == eos_token {
break;
}
if let Some(new_text) = decoding.next_token(next_token_id) {
yield new_text;
} else {
break;
}
n_remains -= 1;
}
engine.end();
@ -100,3 +93,12 @@ impl TextGeneration for LlamaEngine {
Box::pin(s)
}
}
/// Returns at most the last `max_length` elements of `tokens`.
///
/// When the prompt exceeds the model's input window we keep the tail, since
/// the most recent context matters most for completion. Generic over the
/// element type for consistency with the identical helper in the
/// ctranslate2 bindings; existing `&[u32]` callers are unaffected.
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
    if max_length < tokens.len() {
        let start = tokens.len() - max_length;
        &tokens[start..]
    } else {
        tokens
    }
}

View File

@ -1,11 +0,0 @@
[package]
name = "stop-words"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
dashmap = "5.5.3"
regex = "1.9.5"
tokenizers.workspace = true

View File

@ -1,80 +0,0 @@
use std::sync::Arc;
use dashmap::DashMap;
use regex::Regex;
use tokenizers::tokenizer::Tokenizer;
/// Factory for stop-word conditions; caches compiled stop-word regexes so
/// each 'static stop-word list is compiled at most once per process.
pub struct StopWords {
    // Keyed by the 'static stop-word list; values are the compiled regexes
    // produced by create_stop_regex.
    stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
}
// Reverses the characters of the given string slice.
// (Takes &&str so it can be passed directly to iter().map over &&str items.)
fn reverse(s: &&str) -> String {
    let mut reversed = String::with_capacity(s.len());
    for ch in s.chars().rev() {
        reversed.push(ch);
    }
    reversed
}
impl Default for StopWords {
    /// Starts with an empty regex cache; entries are compiled lazily on
    /// first use of each stop-word list.
    fn default() -> Self {
        Self {
            stop_regex_cache: DashMap::new(),
        }
    }
}
impl StopWords {
    /// Builds a `StopWordsCondition` for the given stop-word list, compiling
    /// (and caching) the stop-word regex on first use. An empty list yields a
    /// condition that never fires.
    pub fn create_condition(
        &self,
        tokenizer: Arc<Tokenizer>,
        stop_words: &'static Vec<&'static str>,
    ) -> StopWordsCondition {
        let re = if stop_words.is_empty() {
            None
        } else {
            // Check-then-insert: a concurrent caller may insert between the
            // two lookups; the second get() re-reads whichever compilation
            // won the race (harmless, since both compile the same regex).
            let mut re = self.stop_regex_cache.get(stop_words);
            if re.is_none() {
                self.stop_regex_cache
                    .insert(stop_words, create_stop_regex(stop_words));
                re = self.stop_regex_cache.get(stop_words);
            }
            // Clone the compiled regex out of the map guard.
            re.map(|x| x.value().clone())
        };
        StopWordsCondition::new(tokenizer, re)
    }
}
// Builds the regex matched against the *reversed* output text, so a match at
// the string start means a stop word at the end of the real output.
// NOTE(review): stop words are not regex-escaped, so metacharacters in a stop
// word would corrupt the pattern; and the alternation is not grouped, so \A
// only anchors the first alternative — confirm before reusing this pattern.
fn create_stop_regex(stop_words: &[&str]) -> Regex {
    let tokens: Vec<String> = stop_words.iter().map(reverse).collect();
    // (?m) enables multi-line matching mode.
    // \A means absolute begins of string.
    let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
    Regex::new(&regex_string).unwrap()
}
/// Per-generation stop-word detector; feed it each generated token id and it
/// reports when the accumulated output ends with a stop word.
pub struct StopWordsCondition {
    tokenizer: Arc<Tokenizer>,
    // None disables stop-word checking entirely.
    stop_re: Option<Regex>,
    // All output text so far, reversed, so the \A-anchored regex can test the
    // tail of the output cheaply.
    reversed_output_text: String,
}
impl StopWordsCondition {
    /// `stop_re` is matched against the reversed output text; pass `None` to
    /// disable stop-word checking.
    pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>) -> Self {
        Self {
            tokenizer,
            stop_re,
            reversed_output_text: String::new(),
        }
    }

    /// Feeds one generated token id; returns `true` when the accumulated
    /// output now ends with a stop word and generation should halt.
    pub fn next_token(&mut self, token_id: u32) -> bool {
        if let Some(re) = &self.stop_re {
            // Decodes this token in isolation. NOTE(review): decoding tokens
            // one at a time can mis-render codepoints split across tokens.
            let token = self.tokenizer.decode(&[token_id], false).unwrap();
            // Prepend the reversed token so the newest text sits at the
            // front, where the \A anchor looks.
            let mut new_token = reverse(&token.as_str());
            new_token.push_str(&self.reversed_output_text);
            self.reversed_output_text = new_token;
            re.find(&self.reversed_output_text).is_some()
        } else {
            false
        }
    }
}

View File

@ -8,5 +8,8 @@ edition = "2021"
[dependencies]
async-stream = { workspace = true }
async-trait = { workspace = true }
dashmap = "5.5.3"
derive_builder = "0.12.0"
futures = { workspace = true }
regex = "1.9.5"
tokenizers.workspace = true

View File

@ -0,0 +1,123 @@
use std::sync::Arc;
use dashmap::DashMap;
use regex::Regex;
use tokenizers::tokenizer::Tokenizer;
/// Creates `IncrementalDecoding` instances, caching compiled stop-word
/// regexes so each 'static stop-word list is compiled at most once.
pub struct DecodingFactory {
    // Keyed by the 'static stop-word list; values are the compiled regexes
    // produced by create_stop_regex.
    stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
}
/// Returns the characters of `s` in reverse order. Accepts anything
/// convertible into an owned `String` (`&str`, `String`, ...).
fn reverse<T: Into<String>>(s: T) -> String {
    let text: String = s.into();
    text.chars().rev().collect()
}
impl Default for DecodingFactory {
    /// Starts with an empty regex cache; entries are compiled lazily on
    /// first use of each stop-word list.
    fn default() -> Self {
        Self {
            stop_regex_cache: DashMap::new(),
        }
    }
}
impl DecodingFactory {
    /// Creates an `IncrementalDecoding` seeded with the prompt tokens and the
    /// (cached) compiled regex for `stop_words`.
    pub fn create_incremental_decoding(
        &self,
        tokenizer: Arc<Tokenizer>,
        input_token_ids: &[u32],
        stop_words: &'static Vec<&'static str>,
    ) -> IncrementalDecoding {
        IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
    }

    /// Returns the compiled stop-word regex for `stop_words`, building and
    /// caching it on first use. `None` when there are no stop words.
    fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
        if stop_words.is_empty() {
            None
        } else {
            // entry() performs a single locked lookup, replacing the previous
            // get / insert / get sequence that looked up the key three times
            // and could race with a concurrent caller between the check and
            // the insert.
            Some(
                self.stop_regex_cache
                    .entry(stop_words)
                    .or_insert_with(|| create_stop_regex(stop_words))
                    .clone(),
            )
        }
    }
}
fn create_stop_regex(stop_words: &[&str]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
Regex::new(&regex_string).unwrap()
}
/// Incrementally detokenizes generated token ids into text chunks while
/// watching for stop words.
pub struct IncrementalDecoding {
    tokenizer: Arc<Tokenizer>,
    // Regex matched against `reversed_text`; None when there are no stop
    // words to check.
    stop_re: Option<Regex>,
    // Prompt tokens followed by every generated token so far.
    token_ids: Vec<u32>,
    // Start of the last fully-emitted decode window (see next_token).
    prefix_offset: usize,
    // End of the last fully-emitted decode window.
    read_offset: usize,
    // Prompt + generated text so far, reversed, for anchored stop-word
    // matching against the tail of the output.
    reversed_text: String,
}
impl IncrementalDecoding {
    /// `input_token_ids` is the prompt: it seeds the token buffer so later
    /// decodes have full context, and its decoded text (reversed) seeds the
    /// stop-word matching buffer.
    pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
        let text = tokenizer
            .decode(input_token_ids, /* skip_special_token = */ true)
            .expect("Cannot decode token from tokenizer.");
        Self {
            tokenizer,
            stop_re,
            token_ids: input_token_ids.to_owned(),
            prefix_offset: 0,
            read_offset: input_token_ids.len(),
            reversed_text: reverse(text),
        }
    }

    /// Feeds one generated token id. Returns the newly decoded text (possibly
    /// empty while a multi-byte codepoint is still incomplete), or `None`
    /// when a stop word has been produced and generation should halt.
    pub fn next_token(&mut self, token_id: u32) -> Option<String> {
        let skip_special_token = true;
        self.token_ids.push(token_id);

        // Decode the previously-emitted window, then the same window extended
        // with the new token; the difference between the two strings is the
        // newly generated text. Decoding a window (rather than single tokens)
        // lets the tokenizer apply its context-dependent cleanup correctly.
        let prefix_text = self
            .tokenizer
            .decode(
                &self.token_ids[self.prefix_offset..self.read_offset],
                skip_special_token,
            )
            .expect("Cannot decode token from tokenizer.");

        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], skip_special_token)
            .expect("Cannot decode token from tokenizer.");

        // Only emit once the suffix decodes to complete UTF-8: a trailing
        // U+FFFD replacement character means the newest token ends in the
        // middle of a multi-byte codepoint, so hold the text back until the
        // next token completes it. ('\u{FFFD}' restores the replacement
        // character that was garbled in transit.)
        let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('\u{FFFD}') {
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_ids.len();
            &new_text[prefix_text.len()..]
        } else {
            ""
        };

        if !new_text.is_empty() {
            // Stop-word matching runs on reversed text so the \A-anchored
            // regex only has to look at the (front of the) reversed buffer,
            // i.e. the tail of the real output.
            self.reversed_text = reverse(new_text) + &self.reversed_text;
            if let Some(re) = &self.stop_re {
                if re.find(&self.reversed_text).is_some() {
                    return None;
                }
            }
        }

        Some(new_text.to_owned())
    }
}

View File

@ -1,3 +1,5 @@
pub mod decoding;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;