feat: support stop sequences [TAB-52] (#212)
* refactor: pass step and string token to callback * add token to callback * add stop regexp * implement stop words logic * pass token_ids from inference * improve efficiency of regexp match with reversed regex * fmt * add typescript and javascript stop words * add cache for stop words regexp
improve-workflow
parent
787e195359
commit
fd1baff8d5
|
|
@ -578,7 +578,9 @@ dependencies = [
|
|||
"cmake",
|
||||
"cxx",
|
||||
"cxx-build",
|
||||
"dashmap",
|
||||
"derive_builder",
|
||||
"regex",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
|
@ -664,6 +666,19 @@ dependencies = [
|
|||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder"
|
||||
version = "0.12.0"
|
||||
|
|
@ -2022,9 +2037,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.8.3"
|
||||
version = "1.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390"
|
||||
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
|
||||
dependencies = [
|
||||
"aho-corasick 1.0.1",
|
||||
"memchr",
|
||||
|
|
@ -2503,7 +2518,6 @@ dependencies = [
|
|||
"hyper",
|
||||
"lazy_static",
|
||||
"mime_guess",
|
||||
"regex",
|
||||
"rust-embed",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
cxx = "1.0"
|
||||
dashmap = "5.4.0"
|
||||
derive_builder = "0.12.0"
|
||||
regex = "1.8.4"
|
||||
tokenizers = "0.13.3"
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tokio-util = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@ namespace tabby {
|
|||
|
||||
struct InferenceContext;
|
||||
|
||||
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
virtual rust::Vec<rust::String> inference(
|
||||
virtual rust::Vec<uint32_t> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature
|
||||
|
|
|
|||
|
|
@ -15,27 +15,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
};
|
||||
|
||||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Vec<uint32_t> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature
|
||||
) const {
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
const auto output_tokens = process(
|
||||
return process(
|
||||
std::move(context),
|
||||
std::move(is_context_cancelled),
|
||||
std::move(callback),
|
||||
input_tokens,
|
||||
Options{max_decoding_length, sampling_temperature}
|
||||
);
|
||||
|
||||
// Convert to rust vec.
|
||||
rust::Vec<rust::String> output;
|
||||
output.reserve(output_tokens.size());
|
||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||
return output;
|
||||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
|
|
@ -45,9 +39,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
}
|
||||
|
||||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const = 0;
|
||||
std::unique_ptr<Model> model_;
|
||||
|
|
@ -55,28 +49,35 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
|
||||
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const override {
|
||||
ctranslate2::TranslationOptions x;
|
||||
x.max_decoding_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
rust::Vec<uint32_t> output_ids;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return is_context_cancelled(*context);
|
||||
bool stop = callback(*context, result.step, result.token_id, result.token);
|
||||
if (!stop) {
|
||||
output_ids.push_back(result.token_id);
|
||||
} else if (result.is_last) {
|
||||
output_ids.push_back(result.token_id);
|
||||
}
|
||||
return stop;
|
||||
};
|
||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||
return std::move(result.output());
|
||||
return output_ids;
|
||||
}
|
||||
};
|
||||
|
||||
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const override {
|
||||
ctranslate2::GenerationOptions x;
|
||||
|
|
@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
|||
x.max_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
|
||||
rust::Vec<uint32_t> output_ids;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return is_context_cancelled(*context);
|
||||
bool stop = callback(*context, result.step, result.token_id, result.token);
|
||||
if (!stop) {
|
||||
output_ids.push_back(result.token_id);
|
||||
} else if (result.is_last) {
|
||||
output_ids.push_back(result.token_id);
|
||||
}
|
||||
return stop;
|
||||
};
|
||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||
return std::move(result.sequences[0]);
|
||||
return output_ids;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use dashmap::DashMap;
|
||||
use regex::Regex;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
|
|
@ -26,11 +28,19 @@ mod ffi {
|
|||
fn inference(
|
||||
&self,
|
||||
context: Box<InferenceContext>,
|
||||
is_context_cancelled: fn(&InferenceContext) -> bool,
|
||||
callback: fn(
|
||||
&mut InferenceContext,
|
||||
// step
|
||||
usize,
|
||||
// token_id
|
||||
u32,
|
||||
// token
|
||||
String,
|
||||
) -> bool,
|
||||
tokens: &[String],
|
||||
max_decoding_length: usize,
|
||||
sampling_temperature: f32,
|
||||
) -> Vec<String>;
|
||||
) -> Vec<u32>;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -59,13 +69,30 @@ pub struct TextInferenceOptions {
|
|||
|
||||
#[builder(default = "1.0")]
|
||||
sampling_temperature: f32,
|
||||
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
}
|
||||
|
||||
struct InferenceContext(CancellationToken);
|
||||
pub struct InferenceContext {
|
||||
stop_re: Option<Regex>,
|
||||
cancel: CancellationToken,
|
||||
reversed_output_text: String,
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
||||
InferenceContext {
|
||||
stop_re,
|
||||
cancel,
|
||||
reversed_output_text: "".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TextInferenceEngine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
tokenizer: Tokenizer,
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
}
|
||||
|
||||
impl TextInferenceEngine {
|
||||
|
|
@ -79,6 +106,7 @@ impl TextInferenceEngine {
|
|||
);
|
||||
return TextInferenceEngine {
|
||||
engine,
|
||||
stop_regex_cache: DashMap::new(),
|
||||
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||
};
|
||||
}
|
||||
|
|
@ -91,12 +119,26 @@ impl TextInferenceEngine {
|
|||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let context = InferenceContext(cancel_for_inference);
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
let stop_re: Option<Regex> = if options.stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut re = self.stop_regex_cache.get(options.stop_words);
|
||||
if re.is_none() {
|
||||
self.stop_regex_cache.insert(
|
||||
options.stop_words,
|
||||
create_stop_regex(&self.tokenizer, options.stop_words),
|
||||
);
|
||||
re = self.stop_regex_cache.get(options.stop_words);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
};
|
||||
|
||||
let context = InferenceContext::new(stop_re, cancel_for_inference);
|
||||
let output_ids = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
|context| context.0.is_cancelled(),
|
||||
inference_callback,
|
||||
encoding.get_tokens(),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
|
|
@ -104,16 +146,43 @@ impl TextInferenceEngine {
|
|||
})
|
||||
.await
|
||||
.expect("Inference failed");
|
||||
let output_ids: Vec<u32> = output_tokens
|
||||
.iter()
|
||||
.filter_map(|x| match self.tokenizer.token_to_id(x) {
|
||||
Some(y) => Some(y),
|
||||
None => {
|
||||
println!("Warning: token ({}) missed in vocab", x);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
self.tokenizer.decode(output_ids, true).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn inference_callback(
|
||||
context: &mut InferenceContext,
|
||||
_step: usize,
|
||||
_token_id: u32,
|
||||
token: String,
|
||||
) -> bool {
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(re) = &context.stop_re {
|
||||
let mut new_token = reverse(token);
|
||||
new_token.push_str(&context.reversed_output_text);
|
||||
context.reversed_output_text = new_token;
|
||||
re.find(&context.reversed_output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new `String` holding the `char`s of `s` in reverse order.
///
/// Accepts any string-like argument (`String`, `&String`, `&str`, ...), so a
/// caller that only holds a borrowed string does not need to allocate an
/// owned copy just to reverse it. Existing call sites that pass an owned
/// `String`, or use `reverse` directly as a `map` function, keep working.
///
/// Reversal is done per `char` (Unicode scalar value), not per grapheme
/// cluster.
fn reverse(s: impl AsRef<str>) -> String {
    s.as_ref().chars().rev().collect()
}
|
||||
|
||||
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex {
|
||||
let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap();
|
||||
let stop_tokens: Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| x.get_tokens().join(""))
|
||||
// Reverse for efficient suffix matching.
|
||||
.map(reverse)
|
||||
.collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
|
||||
Regex::new(®ex_string).unwrap()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ serdeconv = { workspace = true }
|
|||
serde_json = "1.0"
|
||||
tower-http = { version = "0.4.0", features = ["cors"] }
|
||||
clap = { version = "4.3.0", features = ["derive"] }
|
||||
regex = "1.8.3"
|
||||
lazy_static = { workspace = true }
|
||||
rust-embed = "6.6.1"
|
||||
mime_guess = "2.0.4"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ use strfmt::{strfmt, strfmt_builder};
|
|||
use tabby_common::{events, path::ModelDir};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use self::languages::get_stop_words;
|
||||
|
||||
mod languages;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
|
|
@ -57,9 +59,11 @@ pub async fn completion(
|
|||
State(state): State<Arc<CompletionState>>,
|
||||
Json(request): Json<CompletionRequest>,
|
||||
) -> Json<CompletionResponse> {
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let options = TextInferenceOptionsBuilder::default()
|
||||
.max_decoding_length(64)
|
||||
.sampling_temperature(0.2)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_stop_words(&language))
|
||||
.build()
|
||||
.expect("Invalid TextInferenceOptions");
|
||||
|
||||
|
|
@ -80,30 +84,24 @@ pub async fn completion(
|
|||
request.prompt.expect("No prompt is set")
|
||||
};
|
||||
|
||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||
let text = state.engine.inference(&prompt, options).await;
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||
|
||||
let response = CompletionResponse {
|
||||
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
text: filtered_text.to_string(),
|
||||
}],
|
||||
};
|
||||
|
||||
events::Event::Completion {
|
||||
completion_id: &response.id,
|
||||
completion_id: &completion_id,
|
||||
language: &language,
|
||||
prompt: &prompt,
|
||||
choices: vec![events::Choice {
|
||||
index: 0,
|
||||
text: filtered_text,
|
||||
text: &text,
|
||||
}],
|
||||
}
|
||||
.log();
|
||||
|
||||
Json(response)
|
||||
Json(CompletionResponse {
|
||||
id: completion_id,
|
||||
choices: vec![Choice { index: 0, text }],
|
||||
})
|
||||
}
|
||||
|
||||
pub struct CompletionState {
|
||||
|
|
|
|||
|
|
@ -1,26 +1,32 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
|
||||
lazy_static! {
|
||||
static ref DEFAULT: Regex = Regex::new(r"(?m)\n\n").unwrap();
|
||||
static ref LANGUAGES: HashMap<&'static str, Regex> = {
|
||||
static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
|
||||
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("python", vec!["\n\n", "\ndef", "\n#", "\nfrom", "\nclass"]);
|
||||
map.insert(
|
||||
"python",
|
||||
Regex::new(r"(?m)(\n\n|^def|^#|^from|^class)").unwrap(),
|
||||
"javascript",
|
||||
vec!["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"],
|
||||
);
|
||||
map.insert(
|
||||
"typescript",
|
||||
vec![
|
||||
"\n\n",
|
||||
"\nfunction",
|
||||
"\n//",
|
||||
"\nimport",
|
||||
"\nclass",
|
||||
"\ninterface",
|
||||
"\ntype",
|
||||
],
|
||||
);
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str {
|
||||
let re = LANGUAGES.get(language).unwrap_or(&DEFAULT);
|
||||
let position = re.find_iter(text).next();
|
||||
if let Some(m) = position {
|
||||
&text[..m.start()]
|
||||
} else {
|
||||
text
|
||||
}
|
||||
pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
|
||||
LANGUAGES.get(language).unwrap_or(&DEFAULT)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue