fix: correct Decoding behavior in incremental manner (#491)
* feat: implement IncrementalDecoding
* refactor: use IncrementalDecoding for ctranslate2
* refactor: rename StopWords to DecodingFactory
* refactor: move decoding logic to tabby-inference
* feat: optimize decoding range
* cleanup
release-0.2
parent
52c4ef38d3
commit
486e507079
|
|
@ -700,7 +700,6 @@ dependencies = [
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
"futures",
|
"futures",
|
||||||
"rust-cxx-cmake-bridge",
|
"rust-cxx-cmake-bridge",
|
||||||
"stop-words",
|
|
||||||
"tabby-inference",
|
"tabby-inference",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
@ -1661,7 +1660,6 @@ dependencies = [
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
"futures",
|
"futures",
|
||||||
"stop-words",
|
|
||||||
"tabby-inference",
|
"tabby-inference",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
@ -2940,15 +2938,6 @@ version = "1.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "stop-words"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"dashmap",
|
|
||||||
"regex",
|
|
||||||
"tokenizers",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "strfmt"
|
name = "strfmt"
|
||||||
version = "0.2.4"
|
version = "0.2.4"
|
||||||
|
|
@ -3122,8 +3111,11 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
"dashmap",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
"futures",
|
"futures",
|
||||||
|
"regex",
|
||||||
|
"tokenizers",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ members = [
|
||||||
"crates/ctranslate2-bindings",
|
"crates/ctranslate2-bindings",
|
||||||
"crates/rust-cxx-cmake-bridge",
|
"crates/rust-cxx-cmake-bridge",
|
||||||
"crates/llama-cpp-bindings",
|
"crates/llama-cpp-bindings",
|
||||||
"crates/stop-words",
|
|
||||||
"crates/http-api-bindings",
|
"crates/http-api-bindings",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ tokio = { workspace = true, features = ["rt"] }
|
||||||
tokio-util = { workspace = true }
|
tokio-util = { workspace = true }
|
||||||
tabby-inference = { path = "../tabby-inference" }
|
tabby-inference = { path = "../tabby-inference" }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
stop-words = { path = "../stop-words" }
|
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
async-stream.workspace = true
|
async-stream.workspace = true
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,10 @@ use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use stop_words::{StopWords, StopWordsCondition};
|
use tabby_inference::{
|
||||||
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
decoding::{DecodingFactory, IncrementalDecoding},
|
||||||
|
helpers, TextGeneration, TextGenerationOptions,
|
||||||
|
};
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio::sync::mpsc::{channel, Sender};
|
use tokio::sync::mpsc::{channel, Sender};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
@ -70,20 +72,20 @@ pub struct CTranslate2EngineOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
sender: Sender<u32>,
|
sender: Sender<String>,
|
||||||
stop_condition: StopWordsCondition,
|
decoding: IncrementalDecoding,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferenceContext {
|
impl InferenceContext {
|
||||||
fn new(
|
fn new(
|
||||||
sender: Sender<u32>,
|
sender: Sender<String>,
|
||||||
stop_condition: StopWordsCondition,
|
decoding: IncrementalDecoding,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
InferenceContext {
|
InferenceContext {
|
||||||
sender,
|
sender,
|
||||||
stop_condition,
|
decoding,
|
||||||
cancel,
|
cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -91,7 +93,7 @@ impl InferenceContext {
|
||||||
|
|
||||||
pub struct CTranslate2Engine {
|
pub struct CTranslate2Engine {
|
||||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||||
stop_words: StopWords,
|
decoding_factory: DecodingFactory,
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -108,7 +110,7 @@ impl CTranslate2Engine {
|
||||||
|
|
||||||
return Self {
|
return Self {
|
||||||
engine,
|
engine,
|
||||||
stop_words: StopWords::default(),
|
decoding_factory: DecodingFactory::default(),
|
||||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -133,12 +135,12 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
let cancel_for_inference = cancel.clone();
|
let cancel_for_inference = cancel.clone();
|
||||||
let _guard = cancel.drop_guard();
|
let _guard = cancel.drop_guard();
|
||||||
|
|
||||||
let stop_condition = self
|
let decoding = self
|
||||||
.stop_words
|
.decoding_factory
|
||||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);
|
||||||
|
|
||||||
let (sender, mut receiver) = channel::<u32>(8);
|
let (sender, mut receiver) = channel::<String>(8);
|
||||||
let context = InferenceContext::new(sender, stop_condition, cancel_for_inference);
|
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
let context = Box::new(context);
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
|
|
@ -150,8 +152,7 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
while let Some(next_token_id) = receiver.recv().await {
|
while let Some(text) = receiver.recv().await {
|
||||||
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
|
||||||
yield text;
|
yield text;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -159,7 +160,7 @@ impl TextGeneration for CTranslate2Engine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn truncate_tokens(tokens: &[String], max_length: usize) -> &[String] {
|
fn truncate_tokens<T>(tokens: &[T], max_length: usize) -> &[T] {
|
||||||
if max_length < tokens.len() {
|
if max_length < tokens.len() {
|
||||||
let start = tokens.len() - max_length;
|
let start = tokens.len() - max_length;
|
||||||
&tokens[start..]
|
&tokens[start..]
|
||||||
|
|
@ -174,10 +175,12 @@ fn inference_callback(
|
||||||
token_id: u32,
|
token_id: u32,
|
||||||
_token: String,
|
_token: String,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
let _ = context.sender.blocking_send(token_id);
|
|
||||||
if context.cancel.is_cancelled() {
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
|
} else if let Some(new_text) = context.decoding.next_token(token_id) {
|
||||||
|
let _ = context.sender.blocking_send(new_text);
|
||||||
|
false
|
||||||
} else {
|
} else {
|
||||||
context.stop_condition.next_token(token_id)
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ tokio = { workspace = true, features = ["rt"] }
|
||||||
tabby-inference = { path = "../tabby-inference" }
|
tabby-inference = { path = "../tabby-inference" }
|
||||||
derive_builder = { workspace = true }
|
derive_builder = { workspace = true }
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
stop-words = { version = "0.1.0", path = "../stop-words" }
|
|
||||||
tokio-util = { workspace = true }
|
tokio-util = { workspace = true }
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
async-stream.workspace = true
|
async-stream.workspace = true
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
|
|
||||||
virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0;
|
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
|
||||||
virtual uint32_t step(uint32_t next_token_id) const = 0;
|
virtual uint32_t step() const = 0;
|
||||||
virtual void end() const = 0;
|
virtual void end() const = 0;
|
||||||
|
|
||||||
virtual uint32_t eos_token() const = 0;
|
virtual uint32_t eos_token() const = 0;
|
||||||
|
|
|
||||||
|
|
@ -45,22 +45,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
ctx_(std::move(ctx)) {
|
ctx_(std::move(ctx)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t start(const rust::Str prompt, size_t max_input_length) const override {
|
void start(rust::Slice<const uint32_t> input_token_ids) const override {
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
llama_reset_timings(ctx);
|
llama_reset_timings(ctx);
|
||||||
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ false);
|
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
||||||
|
|
||||||
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
|
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
|
||||||
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
|
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
|
||||||
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
|
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
|
||||||
}
|
}
|
||||||
return sample();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t step(uint32_t next_token_id) const override {
|
uint32_t step() const override {
|
||||||
const llama_token id = next_token_id;
|
const llama_token id = sample();
|
||||||
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
||||||
return sample();
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void end() const override {
|
void end() const override {
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,7 @@ use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use ffi::create_engine;
|
use ffi::create_engine;
|
||||||
use futures::{lock::Mutex, stream::BoxStream};
|
use futures::{lock::Mutex, stream::BoxStream};
|
||||||
use stop_words::StopWords;
|
use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
|
||||||
use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
|
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "llama")]
|
#[cxx::bridge(namespace = "llama")]
|
||||||
|
|
@ -18,8 +17,8 @@ mod ffi {
|
||||||
|
|
||||||
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn start(&self, prompt: &str, max_input_length: usize) -> u32;
|
fn start(&self, input_token_ids: &[u32]);
|
||||||
fn step(&self, next_token_id: u32) -> u32;
|
fn step(&self) -> u32;
|
||||||
fn end(&self);
|
fn end(&self);
|
||||||
|
|
||||||
fn eos_token(&self) -> u32;
|
fn eos_token(&self) -> u32;
|
||||||
|
|
@ -38,7 +37,7 @@ pub struct LlamaEngineOptions {
|
||||||
pub struct LlamaEngine {
|
pub struct LlamaEngine {
|
||||||
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
stop_words: StopWords,
|
decoding_factory: DecodingFactory,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlamaEngine {
|
impl LlamaEngine {
|
||||||
|
|
@ -46,7 +45,7 @@ impl LlamaEngine {
|
||||||
LlamaEngine {
|
LlamaEngine {
|
||||||
engine: Mutex::new(create_engine(&options.model_path)),
|
engine: Mutex::new(create_engine(&options.model_path)),
|
||||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||||
stop_words: StopWords::default(),
|
decoding_factory: DecodingFactory::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -63,35 +62,29 @@ impl TextGeneration for LlamaEngine {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
options: TextGenerationOptions,
|
options: TextGenerationOptions,
|
||||||
) -> BoxStream<String> {
|
) -> BoxStream<String> {
|
||||||
let prompt = prompt.to_owned();
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
let mut stop_condition = self
|
|
||||||
.stop_words
|
|
||||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
|
||||||
|
|
||||||
let s = stream! {
|
let s = stream! {
|
||||||
let engine = self.engine.lock().await;
|
let engine = self.engine.lock().await;
|
||||||
let eos_token = engine.eos_token();
|
let eos_token = engine.eos_token();
|
||||||
|
|
||||||
let mut next_token_id = engine.start(&prompt, options.max_input_length);
|
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||||
if next_token_id == eos_token {
|
engine.start(input_token_ids);
|
||||||
yield "".to_owned();
|
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
|
||||||
} else {
|
let mut n_remains = options.max_decoding_length ;
|
||||||
let mut n_remains = options.max_decoding_length - 1;
|
while n_remains > 0 {
|
||||||
|
let next_token_id = engine.step();
|
||||||
while n_remains > 0 {
|
if next_token_id == eos_token {
|
||||||
next_token_id = engine.step(next_token_id);
|
break;
|
||||||
if next_token_id == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if stop_condition.next_token(next_token_id) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let text = self.tokenizer.decode(&[next_token_id], true).unwrap();
|
|
||||||
yield text;
|
|
||||||
n_remains -= 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(new_text) = decoding.next_token(next_token_id) {
|
||||||
|
yield new_text;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
n_remains -= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.end();
|
engine.end();
|
||||||
|
|
@ -100,3 +93,12 @@ impl TextGeneration for LlamaEngine {
|
||||||
Box::pin(s)
|
Box::pin(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn truncate_tokens(tokens: &[u32], max_length: usize) -> &[u32] {
|
||||||
|
if max_length < tokens.len() {
|
||||||
|
let start = tokens.len() - max_length;
|
||||||
|
&tokens[start..]
|
||||||
|
} else {
|
||||||
|
tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
[package]
|
|
||||||
name = "stop-words"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
dashmap = "5.5.3"
|
|
||||||
regex = "1.9.5"
|
|
||||||
tokenizers.workspace = true
|
|
||||||
|
|
@ -1,80 +0,0 @@
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use dashmap::DashMap;
|
|
||||||
use regex::Regex;
|
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
|
||||||
|
|
||||||
pub struct StopWords {
|
|
||||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reverse(s: &&str) -> String {
|
|
||||||
s.chars().rev().collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for StopWords {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
stop_regex_cache: DashMap::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StopWords {
|
|
||||||
pub fn create_condition(
|
|
||||||
&self,
|
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
stop_words: &'static Vec<&'static str>,
|
|
||||||
) -> StopWordsCondition {
|
|
||||||
let re = if stop_words.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let mut re = self.stop_regex_cache.get(stop_words);
|
|
||||||
if re.is_none() {
|
|
||||||
self.stop_regex_cache
|
|
||||||
.insert(stop_words, create_stop_regex(stop_words));
|
|
||||||
re = self.stop_regex_cache.get(stop_words);
|
|
||||||
}
|
|
||||||
re.map(|x| x.value().clone())
|
|
||||||
};
|
|
||||||
|
|
||||||
StopWordsCondition::new(tokenizer, re)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
|
||||||
let tokens: Vec<String> = stop_words.iter().map(reverse).collect();
|
|
||||||
|
|
||||||
// (?m) enables multi-line matching mode.
|
|
||||||
// \A means absolute begins of string.
|
|
||||||
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
|
||||||
Regex::new(&regex_string).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct StopWordsCondition {
|
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
stop_re: Option<Regex>,
|
|
||||||
reversed_output_text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StopWordsCondition {
|
|
||||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>) -> Self {
|
|
||||||
Self {
|
|
||||||
tokenizer,
|
|
||||||
stop_re,
|
|
||||||
reversed_output_text: String::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn next_token(&mut self, token_id: u32) -> bool {
|
|
||||||
if let Some(re) = &self.stop_re {
|
|
||||||
let token = self.tokenizer.decode(&[token_id], false).unwrap();
|
|
||||||
let mut new_token = reverse(&token.as_str());
|
|
||||||
new_token.push_str(&self.reversed_output_text);
|
|
||||||
self.reversed_output_text = new_token;
|
|
||||||
re.find(&self.reversed_output_text).is_some()
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -8,5 +8,8 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-stream = { workspace = true }
|
async-stream = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
|
dashmap = "5.5.3"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
futures = { workspace = true }
|
futures = { workspace = true }
|
||||||
|
regex = "1.9.5"
|
||||||
|
tokenizers.workspace = true
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,123 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use regex::Regex;
|
||||||
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
|
||||||
|
pub struct DecodingFactory {
|
||||||
|
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reverse<T>(s: T) -> String
|
||||||
|
where
|
||||||
|
T: Into<String>,
|
||||||
|
{
|
||||||
|
s.into().chars().rev().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DecodingFactory {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
stop_regex_cache: DashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecodingFactory {
|
||||||
|
pub fn create_incremental_decoding(
|
||||||
|
&self,
|
||||||
|
tokenizer: Arc<Tokenizer>,
|
||||||
|
input_token_ids: &[u32],
|
||||||
|
stop_words: &'static Vec<&'static str>,
|
||||||
|
) -> IncrementalDecoding {
|
||||||
|
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
|
||||||
|
if stop_words.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut re = self.stop_regex_cache.get(stop_words);
|
||||||
|
if re.is_none() {
|
||||||
|
self.stop_regex_cache
|
||||||
|
.insert(stop_words, create_stop_regex(stop_words));
|
||||||
|
re = self.stop_regex_cache.get(stop_words);
|
||||||
|
}
|
||||||
|
re.map(|x| x.value().clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||||
|
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();
|
||||||
|
|
||||||
|
// (?m) enables multi-line matching mode.
|
||||||
|
// \A means absolute begins of string.
|
||||||
|
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
||||||
|
Regex::new(&regex_string).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct IncrementalDecoding {
|
||||||
|
tokenizer: Arc<Tokenizer>,
|
||||||
|
stop_re: Option<Regex>,
|
||||||
|
|
||||||
|
token_ids: Vec<u32>,
|
||||||
|
prefix_offset: usize,
|
||||||
|
read_offset: usize,
|
||||||
|
|
||||||
|
reversed_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IncrementalDecoding {
|
||||||
|
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
|
||||||
|
let text = tokenizer
|
||||||
|
.decode(input_token_ids, /* skip_special_token = */ true)
|
||||||
|
.expect("Cannot decode token from tokenizer.");
|
||||||
|
Self {
|
||||||
|
tokenizer,
|
||||||
|
stop_re,
|
||||||
|
token_ids: input_token_ids.to_owned(),
|
||||||
|
prefix_offset: 0,
|
||||||
|
read_offset: input_token_ids.len(),
|
||||||
|
reversed_text: reverse(text),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next_token(&mut self, token_id: u32) -> Option<String> {
|
||||||
|
let skip_special_token = true;
|
||||||
|
self.token_ids.push(token_id);
|
||||||
|
|
||||||
|
let prefix_text = self
|
||||||
|
.tokenizer
|
||||||
|
.decode(
|
||||||
|
&self.token_ids[self.prefix_offset..self.read_offset],
|
||||||
|
skip_special_token,
|
||||||
|
)
|
||||||
|
.expect("Cannot decode token from tokenizer.");
|
||||||
|
|
||||||
|
let new_text = self
|
||||||
|
.tokenizer
|
||||||
|
.decode(&self.token_ids[self.prefix_offset..], skip_special_token)
|
||||||
|
.expect("Cannot decode token from tokenizer.");
|
||||||
|
|
||||||
|
let new_text = if new_text.len() > prefix_text.len() && !new_text.ends_with('<27>') {
|
||||||
|
self.prefix_offset = self.read_offset;
|
||||||
|
self.read_offset = self.token_ids.len();
|
||||||
|
&new_text[prefix_text.len()..]
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
|
||||||
|
if !new_text.is_empty() {
|
||||||
|
self.reversed_text = reverse(new_text) + &self.reversed_text;
|
||||||
|
|
||||||
|
if let Some(re) = &self.stop_re {
|
||||||
|
if re.find(&self.reversed_text).is_some() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(new_text.to_owned())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
pub mod decoding;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue