fix: correct Decoding behavior in incremental manner (#491)
* feat: implement IncrementalDecoding * refactor: use IncrementalDecoding for ctranslate2 * refactor: rename StopWords to DecodingFactory * refactor: move decoding logic to tabby-inference * feat: optimize decoding range * cleanuprelease-0.2
parent
52c4ef38d3
commit
486e507079
|
|
@ -700,7 +700,6 @@ dependencies = [
|
|||
"derive_builder",
|
||||
"futures",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"stop-words",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
|
@ -1661,7 +1660,6 @@ dependencies = [
|
|||
"cxx-build",
|
||||
"derive_builder",
|
||||
"futures",
|
||||
"stop-words",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
|
@ -2940,15 +2938,6 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "stop-words"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"dashmap",
|
||||
"regex",
|
||||
"tokenizers",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strfmt"
|
||||
version = "0.2.4"
|
||||
|
|
@ -3122,8 +3111,11 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"dashmap",
|
||||
"derive_builder",
|
||||
"futures",
|
||||
"regex",
|
||||
"tokenizers",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ members = [
|
|||
"crates/ctranslate2-bindings",
|
||||
"crates/rust-cxx-cmake-bridge",
|
||||
"crates/llama-cpp-bindings",
|
||||
"crates/stop-words",
|
||||
"crates/http-api-bindings",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ tokio = { workspace = true, features = ["rt"] }
|
|||
tokio-util = { workspace = true }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
async-trait = { workspace = true }
|
||||
stop-words = { path = "../stop-words" }
|
||||
futures.workspace = true
|
||||
async-stream.workspace = true
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@ use async_stream::stream;
|
|||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
use futures::stream::BoxStream;
|
||||
use stop_words::{StopWords, StopWordsCondition};
|
||||
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
||||
use tabby_inference::{
|
||||
decoding::{DecodingFactory, IncrementalDecoding},
|
||||
helpers, TextGeneration, TextGenerationOptions,
|
||||
};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc::{channel, Sender};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
|
@ -70,20 +72,20 @@ pub struct CTranslate2EngineOptions {
|
|||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
sender: Sender<u32>,
|
||||
stop_condition: StopWordsCondition,
|
||||
sender: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(
|
||||
sender: Sender<u32>,
|
||||
stop_condition: StopWordsCondition,
|
||||
sender: Sender<String>,
|
||||
decoding: IncrementalDecoding,
|
||||
cancel: CancellationToken,
|
||||
) -> Self {
|
||||
InferenceContext {
|
||||
sender,
|
||||
stop_condition,
|
||||
decoding,
|
||||
cancel,
|
||||
}
|
||||
}
|
||||
|
|
@ -91,7 +93,7 @@ impl InferenceContext {
|
|||
|
||||
pub struct CTranslate2Engine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
stop_words: StopWords,
|
||||
decoding_factory: DecodingFactory,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
}
|
||||
|
||||
|
|
@ -108,7 +110,7 @@ impl CTranslate2Engine {
|
|||
|
||||
return Self {
|
||||
engine,
|
||||
stop_words: StopWords::default(),
|
||||
decoding_factory: DecodingFactory::default(),
|
||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||
};
|
||||
}
|
||||
|
|
@ -133,12 +135,12 @@ impl TextGeneration for CTranslate2Engine {
|
|||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let stop_condition = self
|
||||
.stop_words
|
||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
||||
let decoding = self
|
||||
.decoding_factory
|
||||
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
||||
|
||||
let (sender, mut receiver) = channel::<u32>(8);
|
||||
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
|
||||
let (sender, mut receiver) = channel::<String>(8);
|
||||
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
||||
tokio::task::spawn(async move {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
|
|
@ -150,8 +152,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
);
|
||||
});
|
||||
|
||||
while let Some(next_token_id) = receiver.recv().await {
|
||||
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
||||
while let Some(text) = receiver.recv().await {
|
||||
yield text;
|
||||
}
|
||||
};
|
||||
|
|
@ -159,7 +160,7 @@ impl TextGeneration for CTranslate2Engine {
|
|||
}
|
||||
}
|
||||
|
||||
fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
|
||||
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
|
||||
if max_length < tokens.len() {
|
||||
let start = tokens.len() - max_length;
|
||||
&tokens[start..]
|
||||
|
|
@ -174,10 +175,12 @@ fn inference_callback(
|
|||
token_id: u32,
|
||||
_token: String,
|
||||
) -> bool {
|
||||
let _ = context.sender.blocking_send(token_id);
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(new_text) = context.decoding.next_token(token_id) {
|
||||
let _ = context.sender.blocking_send(new_text);
|
||||
false
|
||||
} else {
|
||||
context.stop_condition.next_token(token_id)
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ tokio = { workspace = true, features = ["rt"] }
|
|||
tabby-inference = { path = "../tabby-inference" }
|
||||
derive_builder = { workspace = true }
|
||||
tokenizers = { workspace = true }
|
||||
stop-words = { version = "0.1.0", path = "../stop-words" }
|
||||
tokio-util = { workspace = true }
|
||||
futures.workspace = true
|
||||
async-stream.workspace = true
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ class TextInferenceEngine {
|
|||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
|
||||
virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0;
|
||||
virtual uint32_t step(uint32_t next_token_id) const = 0;
|
||||
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
|
||||
virtual uint32_t step() const = 0;
|
||||
virtual void end() const = 0;
|
||||
|
||||
virtual uint32_t eos_token() const = 0;
|
||||
|
|
|
|||
|
|
@ -45,22 +45,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
ctx_(std::move(ctx)) {
|
||||
}
|
||||
|
||||
uint32_t start(const rust::Str prompt, size_t max_input_length) const override {
|
||||
void start(rust::Slice<const uint32_t> input_token_ids) const override {
|
||||
auto* ctx = ctx_.get();
|
||||
llama_reset_timings(ctx);
|
||||
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ false);
|
||||
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
||||
|
||||
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
|
||||
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
|
||||
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
|
||||
}
|
||||
return sample();
|
||||
}
|
||||
|
||||
uint32_t step(uint32_t next_token_id) const override {
|
||||
const llama_token id = next_token_id;
|
||||
uint32_t step() const override {
|
||||
const llama_token id = sample();
|
||||
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
||||
return sample();
|
||||
return id;
|
||||
}
|
||||
|
||||
void end() const override {
|
||||
|
|
|
|||
|
|
@ -5,8 +5,7 @@ use async_trait::async_trait;
|
|||
use derive_builder::Builder;
|
||||
use ffi::create_engine;
|
||||
use futures::{lock::Mutex, stream::BoxStream};
|
||||
use stop_words::StopWords;
|
||||
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
||||
use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
#[cxx::bridge(namespace = "llama")]
|
||||
|
|
@ -18,8 +17,8 @@ mod ffi {
|
|||
|
||||
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
||||
|
||||
fn start(&self, prompt: &str, max_input_length: usize) -> u32;
|
||||
fn step(&self, next_token_id: u32) -> u32;
|
||||
fn start(&self, input_token_ids: &[u32]);
|
||||
fn step(&self) -> u32;
|
||||
fn end(&self);
|
||||
|
||||
fn eos_token(&self) -> u32;
|
||||
|
|
@ -38,7 +37,7 @@ pub struct LlamaEngineOptions {
|
|||
pub struct LlamaEngine {
|
||||
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_words: StopWords,
|
||||
decoding_factory: DecodingFactory,
|
||||
}
|
||||
|
||||
impl LlamaEngine {
|
||||
|
|
@ -46,7 +45,7 @@ impl LlamaEngine {
|
|||
LlamaEngine {
|
||||
engine: Mutex::new(create_engine(&options.model_path)),
|
||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||
stop_words: StopWords::default(),
|
||||
decoding_factory: DecodingFactory::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -63,35 +62,29 @@ impl TextGeneration for LlamaEngine {
|
|||
prompt: &str,
|
||||
options: TextGenerationOptions,
|
||||
) -> BoxStream<String> {
|
||||
let prompt = prompt.to_owned();
|
||||
let mut stop_condition = self
|
||||
.stop_words
|
||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
|
||||
let s = stream! {
|
||||
let engine = self.engine.lock().await;
|
||||
let eos_token = engine.eos_token();
|
||||
|
||||
let mut next_token_id = engine.start(&prompt, options.max_input_length);
|
||||
if next_token_id == eos_token {
|
||||
yield "".to_owned();
|
||||
} else {
|
||||
let mut n_remains = options.max_decoding_length - 1;
|
||||
|
||||
while n_remains > 0 {
|
||||
next_token_id = engine.step(next_token_id);
|
||||
if next_token_id == eos_token {
|
||||
break;
|
||||
}
|
||||
|
||||
if stop_condition.next_token(next_token_id) {
|
||||
break;
|
||||
}
|
||||
|
||||
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
||||
yield text;
|
||||
n_remains -= 1;
|
||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||
engine.start(input_token_ids);
|
||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
|
||||
let mut n_remains = options.max_decoding_length ;
|
||||
while n_remains > 0 {
|
||||
let next_token_id = engine.step();
|
||||
if next_token_id == eos_token {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(new_text) = decoding.next_token(next_token_id) {
|
||||
yield new_text;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
||||
n_remains -= 1;
|
||||
}
|
||||
|
||||
engine.end();
|
||||
|
|
@ -100,3 +93,12 @@ impl TextGeneration for LlamaEngine {
|
|||
Box::pin(s)
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
|
||||
if max_length < tokens.len() {
|
||||
let start = tokens.len() - max_length;
|
||||
&tokens[start..]
|
||||
} else {
|
||||
tokens
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +0,0 @@
|
|||
[package]
|
||||
name = "stop-words"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
dashmap = "5.5.3"
|
||||
regex = "1.9.5"
|
||||
tokenizers.workspace = true
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use regex::Regex;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
pub struct StopWords {
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
}
|
||||
|
||||
fn reverse(s: &&str) -> String {
|
||||
s.chars().rev().collect()
|
||||
}
|
||||
|
||||
impl Default for StopWords {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stop_regex_cache: DashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StopWords {
|
||||
pub fn create_condition(
|
||||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
) -> StopWordsCondition {
|
||||
let re = if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut re = self.stop_regex_cache.get(stop_words);
|
||||
if re.is_none() {
|
||||
self.stop_regex_cache
|
||||
.insert(stop_words, create_stop_regex(stop_words));
|
||||
re = self.stop_regex_cache.get(stop_words);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
};
|
||||
|
||||
StopWordsCondition::new(tokenizer, re)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(reverse).collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
||||
Regex::new(®ex_string).unwrap()
|
||||
}
|
||||
|
||||
pub struct StopWordsCondition {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_re: Option<Regex>,
|
||||
reversed_output_text: String,
|
||||
}
|
||||
|
||||
impl StopWordsCondition {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
stop_re,
|
||||
reversed_output_text: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_token(&mut self, token_id: u32) -> bool {
|
||||
if let Some(re) = &self.stop_re {
|
||||
let token = self.tokenizer.decode(&[token_id], false).unwrap();
|
||||
let mut new_token = reverse(&token.as_str());
|
||||
new_token.push_str(&self.reversed_output_text);
|
||||
self.reversed_output_text = new_token;
|
||||
re.find(&self.reversed_output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8,5 +8,8 @@ edition = "2021"
|
|||
[dependencies]
|
||||
async-stream = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
dashmap = "5.5.3"
|
||||
derive_builder = "0.12.0"
|
||||
futures = { workspace = true }
|
||||
regex = "1.9.5"
|
||||
tokenizers.workspace = true
|
||||
|
|
|
|||
|
|
@ -0,0 +1,123 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use regex::Regex;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
pub struct DecodingFactory {
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
}
|
||||
|
||||
fn reverse<T>(s: T) -> String
|
||||
where
|
||||
T: Into<String>,
|
||||
{
|
||||
s.into().chars().rev().collect()
|
||||
}
|
||||
|
||||
impl Default for DecodingFactory {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stop_regex_cache: DashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DecodingFactory {
|
||||
pub fn create_incremental_decoding(
|
||||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
input_token_ids: &[u32],
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
) -> IncrementalDecoding {
|
||||
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||
}
|
||||
|
||||
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||
if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut re = self.stop_regex_cache.get(stop_words);
|
||||
if re.is_none() {
|
||||
self.stop_regex_cache
|
||||
.insert(stop_words, create_stop_regex(stop_words));
|
||||
re = self.stop_regex_cache.get(stop_words);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
||||
Regex::new(®ex_string).unwrap()
|
||||
}
|
||||
|
||||
pub struct IncrementalDecoding {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_re: Option<Regex>,
|
||||
|
||||
token_ids: Vec<u32>,
|
||||
prefix_offset: usize,
|
||||
read_offset: usize,
|
||||
|
||||
reversed_text: String,
|
||||
}
|
||||
|
||||
impl IncrementalDecoding {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
||||
let text = tokenizer
|
||||
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
Self {
|
||||
tokenizer,
|
||||
stop_re,
|
||||
token_ids: input_token_ids.to_owned(),
|
||||
prefix_offset: 0,
|
||||
read_offset: input_token_ids.len(),
|
||||
reversed_text: reverse(text),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_token(&mut self, token_id: u32) -> Option<String> {
|
||||
let skip_special_token = true;
|
||||
self.token_ids.push(token_id);
|
||||
|
||||
let prefix_text = self
|
||||
.tokenizer
|
||||
.decode(
|
||||
&self.token_ids[self.prefix_offset..self.read_offset],
|
||||
skip_special_token,
|
||||
)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
|
||||
let new_text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids[self.prefix_offset..], skip_special_token)
|
||||
.expect("Cannot decode token from tokenizer.");
|
||||
|
||||
let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('<27>') {
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_ids.len();
|
||||
&new_text[prefix_text.len()..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
if !new_text.is_empty() {
|
||||
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||
|
||||
if let Some(re) = &self.stop_re {
|
||||
if re.find(&self.reversed_text).is_some() {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(new_text.to_owned())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
pub mod decoding;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
use futures::stream::BoxStream;
|
||||
|
|
|
|||
Loading…
Reference in New Issue