feat: support stop sequences [TAB-52] (#212)

* refactor: pass step and string token to callback

* add token to callback

* add stop regexp

* implement stop words logic

* pass token_ids from inference

* improve efficiency of regexp match with reversed regex

* fmt

* add typescript and javascript stop words

* add cache for stop words regexp
improve-workflow
Meng Zhang 2023-06-06 16:28:58 -07:00 committed by GitHub
parent 787e195359
commit fd1baff8d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 169 additions and 70 deletions

20
Cargo.lock generated
View File

@ -578,7 +578,9 @@ dependencies = [
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
"dashmap",
"derive_builder", "derive_builder",
"regex",
"rust-cxx-cmake-bridge", "rust-cxx-cmake-bridge",
"tokenizers", "tokenizers",
"tokio", "tokio",
@ -664,6 +666,19 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "derive_builder" name = "derive_builder"
version = "0.12.0" version = "0.12.0"
@ -2022,9 +2037,9 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.8.3" version = "1.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390" checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
dependencies = [ dependencies = [
"aho-corasick 1.0.1", "aho-corasick 1.0.1",
"memchr", "memchr",
@ -2503,7 +2518,6 @@ dependencies = [
"hyper", "hyper",
"lazy_static", "lazy_static",
"mime_guess", "mime_guess",
"regex",
"rust-embed", "rust-embed",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -5,7 +5,9 @@ edition = "2021"
[dependencies] [dependencies]
cxx = "1.0" cxx = "1.0"
dashmap = "5.4.0"
derive_builder = "0.12.0" derive_builder = "0.12.0"
regex = "1.8.4"
tokenizers = "0.13.3" tokenizers = "0.13.3"
tokio = { workspace = true, features = ["rt"] } tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true } tokio-util = { workspace = true }

View File

@ -7,12 +7,14 @@ namespace tabby {
struct InferenceContext; struct InferenceContext;
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
class TextInferenceEngine { class TextInferenceEngine {
public: public:
virtual ~TextInferenceEngine(); virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference( virtual rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature float sampling_temperature

View File

@ -15,27 +15,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}; };
public: public:
rust::Vec<rust::String> inference( rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature float sampling_temperature
) const { ) const {
// Inference. // Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end()); std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
const auto output_tokens = process( return process(
std::move(context), std::move(context),
std::move(is_context_cancelled), std::move(callback),
input_tokens, input_tokens,
Options{max_decoding_length, sampling_temperature} Options{max_decoding_length, sampling_temperature}
); );
// Convert to rust vec.
rust::Vec<rust::String> output;
output.reserve(output_tokens.size());
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
return output;
} }
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) { static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
@ -45,9 +39,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
} }
protected: protected:
virtual std::vector<std::string> process( virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
const Options& options) const = 0; const Options& options) const = 0;
std::unique_ptr<Model> model_; std::unique_ptr<Model> model_;
@ -55,28 +49,35 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> { class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
protected: protected:
virtual std::vector<std::string> process( virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
const Options& options) const override { const Options& options) const override {
ctranslate2::TranslationOptions x; ctranslate2::TranslationOptions x;
x.max_decoding_length = options.max_decoding_length; x.max_decoding_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context); bool stop = callback(*context, result.step, result.token_id, result.token);
if (!stop) {
output_ids.push_back(result.token_id);
} else if (result.is_last) {
output_ids.push_back(result.token_id);
}
return stop;
}; };
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return std::move(result.output()); return output_ids;
} }
}; };
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> { class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
protected: protected:
virtual std::vector<std::string> process( virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
const Options& options) const override { const Options& options) const override {
ctranslate2::GenerationOptions x; ctranslate2::GenerationOptions x;
@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
x.max_length = options.max_decoding_length; x.max_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context); bool stop = callback(*context, result.step, result.token_id, result.token);
if (!stop) {
output_ids.push_back(result.token_id);
} else if (result.is_last) {
output_ids.push_back(result.token_id);
}
return stop;
}; };
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]); return output_ids;
} }
}; };

View File

@ -1,3 +1,5 @@
use dashmap::DashMap;
use regex::Regex;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@ -26,11 +28,19 @@ mod ffi {
fn inference( fn inference(
&self, &self,
context: Box<InferenceContext>, context: Box<InferenceContext>,
is_context_cancelled: fn(&InferenceContext) -> bool, callback: fn(
&mut InferenceContext,
// step
usize,
// token_id
u32,
// token
String,
) -> bool,
tokens: &[String], tokens: &[String],
max_decoding_length: usize, max_decoding_length: usize,
sampling_temperature: f32, sampling_temperature: f32,
) -> Vec<String>; ) -> Vec<u32>;
} }
} }
@ -59,13 +69,30 @@ pub struct TextInferenceOptions {
#[builder(default = "1.0")] #[builder(default = "1.0")]
sampling_temperature: f32, sampling_temperature: f32,
stop_words: &'static Vec<&'static str>,
} }
struct InferenceContext(CancellationToken); pub struct InferenceContext {
stop_re: Option<Regex>,
cancel: CancellationToken,
reversed_output_text: String,
}
impl InferenceContext {
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
InferenceContext {
stop_re,
cancel,
reversed_output_text: "".to_owned(),
}
}
}
pub struct TextInferenceEngine { pub struct TextInferenceEngine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer, tokenizer: Tokenizer,
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
} }
impl TextInferenceEngine { impl TextInferenceEngine {
@ -79,6 +106,7 @@ impl TextInferenceEngine {
); );
return TextInferenceEngine { return TextInferenceEngine {
engine, engine,
stop_regex_cache: DashMap::new(),
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
}; };
} }
@ -91,12 +119,26 @@ impl TextInferenceEngine {
let cancel_for_inference = cancel.clone(); let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard(); let _guard = cancel.drop_guard();
let context = InferenceContext(cancel_for_inference); let stop_re: Option<Regex> = if options.stop_words.is_empty() {
let output_tokens = tokio::task::spawn_blocking(move || { None
} else {
let mut re = self.stop_regex_cache.get(options.stop_words);
if re.is_none() {
self.stop_regex_cache.insert(
options.stop_words,
create_stop_regex(&self.tokenizer, options.stop_words),
);
re = self.stop_regex_cache.get(options.stop_words);
}
re.map(|x| x.value().clone())
};
let context = InferenceContext::new(stop_re, cancel_for_inference);
let output_ids = tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
context, context,
|context| context.0.is_cancelled(), inference_callback,
encoding.get_tokens(), encoding.get_tokens(),
options.max_decoding_length, options.max_decoding_length,
options.sampling_temperature, options.sampling_temperature,
@ -104,16 +146,43 @@ impl TextInferenceEngine {
}) })
.await .await
.expect("Inference failed"); .expect("Inference failed");
let output_ids: Vec<u32> = output_tokens
.iter()
.filter_map(|x| match self.tokenizer.token_to_id(x) {
Some(y) => Some(y),
None => {
println!("Warning: token ({}) missed in vocab", x);
None
}
})
.collect();
self.tokenizer.decode(output_ids, true).unwrap() self.tokenizer.decode(output_ids, true).unwrap()
} }
} }
/// Per-token callback invoked from the C++ inference loop for every generated
/// token. Returns `true` to request that decoding stop.
fn inference_callback(
    context: &mut InferenceContext,
    _step: usize,
    _token_id: u32,
    token: String,
) -> bool {
    if context.cancel.is_cancelled() {
        // Request was cancelled (client dropped): stop decoding immediately.
        true
    } else if let Some(re) = &context.stop_re {
        // Prepend the reversed token so `reversed_output_text` remains the
        // character-wise reverse of the full generated text. The stop regex
        // is built over reversed stop words and anchored at the start, which
        // turns "generated text ends with a stop word" into an efficient
        // prefix match on the reversed buffer.
        let mut new_token = reverse(token);
        new_token.push_str(&context.reversed_output_text);
        context.reversed_output_text = new_token;
        re.find(&context.reversed_output_text).is_some()
    } else {
        // No stop words configured for this request: never stop early.
        false
    }
}
/// Returns `s` with its characters (Unicode scalar values) in reverse order.
fn reverse(s: String) -> String {
    let mut reversed = String::with_capacity(s.len());
    for ch in s.chars().rev() {
        reversed.push(ch);
    }
    reversed
}
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex {
let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
.map(reverse)
.collect();
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
Regex::new(&regex_string).unwrap()
}

View File

@ -19,7 +19,6 @@ serdeconv = { workspace = true }
serde_json = "1.0" serde_json = "1.0"
tower-http = { version = "0.4.0", features = ["cors"] } tower-http = { version = "0.4.0", features = ["cors"] }
clap = { version = "4.3.0", features = ["derive"] } clap = { version = "4.3.0", features = ["derive"] }
regex = "1.8.3"
lazy_static = { workspace = true } lazy_static = { workspace = true }
rust-embed = "6.6.1" rust-embed = "6.6.1"
mime_guess = "2.0.4" mime_guess = "2.0.4"

View File

@ -9,6 +9,8 @@ use strfmt::{strfmt, strfmt_builder};
use tabby_common::{events, path::ModelDir}; use tabby_common::{events, path::ModelDir};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::languages::get_stop_words;
mod languages; mod languages;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -57,9 +59,11 @@ pub async fn completion(
State(state): State<Arc<CompletionState>>, State(state): State<Arc<CompletionState>>,
Json(request): Json<CompletionRequest>, Json(request): Json<CompletionRequest>,
) -> Json<CompletionResponse> { ) -> Json<CompletionResponse> {
let language = request.language.unwrap_or("unknown".into());
let options = TextInferenceOptionsBuilder::default() let options = TextInferenceOptionsBuilder::default()
.max_decoding_length(64) .max_decoding_length(128)
.sampling_temperature(0.2) .sampling_temperature(0.1)
.stop_words(get_stop_words(&language))
.build() .build()
.expect("Invalid TextInferenceOptions"); .expect("Invalid TextInferenceOptions");
@ -80,30 +84,24 @@ pub async fn completion(
request.prompt.expect("No prompt is set") request.prompt.expect("No prompt is set")
}; };
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.inference(&prompt, options).await; let text = state.engine.inference(&prompt, options).await;
let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text);
let response = CompletionResponse {
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
choices: vec![Choice {
index: 0,
text: filtered_text.to_string(),
}],
};
events::Event::Completion { events::Event::Completion {
completion_id: &response.id, completion_id: &completion_id,
language: &language, language: &language,
prompt: &prompt, prompt: &prompt,
choices: vec![events::Choice { choices: vec![events::Choice {
index: 0, index: 0,
text: filtered_text, text: &text,
}], }],
} }
.log(); .log();
Json(response) Json(CompletionResponse {
id: completion_id,
choices: vec![Choice { index: 0, text }],
})
} }
pub struct CompletionState { pub struct CompletionState {

View File

@ -1,26 +1,32 @@
use std::collections::HashMap; use std::collections::HashMap;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex;
lazy_static! { lazy_static! {
static ref DEFAULT: Regex = Regex::new(r"(?m)\n\n").unwrap(); static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
static ref LANGUAGES: HashMap<&'static str, Regex> = { static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
let mut map = HashMap::new(); let mut map = HashMap::new();
map.insert("python", vec!["\n\n", "\ndef", "\n#", "\nfrom", "\nclass"]);
map.insert( map.insert(
"python", "javascript",
Regex::new(r"(?m)(\n\n|^def|^#|^from|^class)").unwrap(), vec!["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"],
);
map.insert(
"typescript",
vec![
"\n\n",
"\nfunction",
"\n//",
"\nimport",
"\nclass",
"\ninterface",
"\ntype",
],
); );
map map
}; };
} }
pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str { pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
let re = LANGUAGES.get(language).unwrap_or(&DEFAULT); LANGUAGES.get(language).unwrap_or(&DEFAULT)
let position = re.find_iter(text).next();
if let Some(m) = position {
&text[..m.start()]
} else {
text
}
} }