implement stop words logic

support-stop-sequences
Meng Zhang 2023-06-06 15:01:40 -07:00
parent fea645248e
commit fa84e376f4
5 changed files with 45 additions and 40 deletions

1
Cargo.lock generated
View File

@ -2504,7 +2504,6 @@ dependencies = [
"hyper", "hyper",
"lazy_static", "lazy_static",
"mime_guess", "mime_guess",
"regex",
"rust-embed", "rust-embed",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -69,20 +69,18 @@ pub struct TextInferenceOptions {
#[builder(default = "1.0")] #[builder(default = "1.0")]
sampling_temperature: f32, sampling_temperature: f32,
#[builder(default = "vec!()")] stop_words: &'static Vec<&'static str>
stop_words: Vec<String>
} }
pub struct InferenceContext { pub struct InferenceContext {
stop_regexp: Regex, stop_re: Option<Regex>,
cancel: CancellationToken, cancel: CancellationToken,
output_text: String output_text: String
} }
impl InferenceContext { impl InferenceContext {
fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self { fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap(); InferenceContext { stop_re, cancel, output_text: "".to_owned() }
InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
} }
} }
@ -114,7 +112,16 @@ impl TextInferenceEngine {
let cancel_for_inference = cancel.clone(); let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard(); let _guard = cancel.drop_guard();
let context = InferenceContext::new(options.stop_words, cancel_for_inference); let stop_re = if options.stop_words.is_empty() {
None
} else {
let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
let stop_tokens : Vec<String> = encodings.iter().map(|x| x.get_tokens().join("")).collect();
let regex_string = r"(?m)".to_owned() + &stop_tokens.join("|");
Some(Regex::new(&regex_string).unwrap())
};
let context = InferenceContext::new(stop_re, cancel_for_inference);
let output_tokens = tokio::task::spawn_blocking(move || { let output_tokens = tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
@ -137,7 +144,13 @@ impl TextInferenceEngine {
} }
}) })
.collect(); .collect();
self.tokenizer.decode(output_ids, true).unwrap() let output_text = self.tokenizer.decode(output_ids, true).unwrap();
for stop_word in options.stop_words {
if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
return stripped_text.to_string();
}
}
output_text
} }
} }
@ -145,9 +158,9 @@ fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u
if context.cancel.is_cancelled() { if context.cancel.is_cancelled() {
true true
} else { } else {
context.output_text.push_str(&token); if let Some(re) = &context.stop_re {
if let Some(_) = context.stop_regexp.find(&context.output_text) { context.output_text.push_str(&token);
true re.find(&context.output_text).is_some()
} else { } else {
false false
} }

View File

@ -19,7 +19,6 @@ serdeconv = { workspace = true }
serde_json = "1.0" serde_json = "1.0"
tower-http = { version = "0.4.0", features = ["cors"] } tower-http = { version = "0.4.0", features = ["cors"] }
clap = { version = "4.3.0", features = ["derive"] } clap = { version = "4.3.0", features = ["derive"] }
regex = "1.8.3"
lazy_static = { workspace = true } lazy_static = { workspace = true }
rust-embed = "6.6.1" rust-embed = "6.6.1"
mime_guess = "2.0.4" mime_guess = "2.0.4"

View File

@ -9,6 +9,8 @@ use strfmt::{strfmt, strfmt_builder};
use tabby_common::{events, path::ModelDir}; use tabby_common::{events, path::ModelDir};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::languages::get_stop_words;
mod languages; mod languages;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -57,9 +59,11 @@ pub async fn completion(
State(state): State<Arc<CompletionState>>, State(state): State<Arc<CompletionState>>,
Json(request): Json<CompletionRequest>, Json(request): Json<CompletionRequest>,
) -> Json<CompletionResponse> { ) -> Json<CompletionResponse> {
let language = request.language.unwrap_or("unknown".into());
let options = TextInferenceOptionsBuilder::default() let options = TextInferenceOptionsBuilder::default()
.max_decoding_length(64) .max_decoding_length(128)
.sampling_temperature(0.2) .sampling_temperature(0.1)
.stop_words(get_stop_words(&language))
.build() .build()
.expect("Invalid TextInferenceOptions"); .expect("Invalid TextInferenceOptions");
@ -80,30 +84,27 @@ pub async fn completion(
request.prompt.expect("No prompt is set") request.prompt.expect("No prompt is set")
}; };
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.inference(&prompt, options).await; let text = state.engine.inference(&prompt, options).await;
let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text);
let response = CompletionResponse {
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
choices: vec![Choice {
index: 0,
text: filtered_text.to_string(),
}],
};
events::Event::Completion { events::Event::Completion {
completion_id: &response.id, completion_id: &completion_id,
language: &language, language: &language,
prompt: &prompt, prompt: &prompt,
choices: vec![events::Choice { choices: vec![events::Choice {
index: 0, index: 0,
text: filtered_text, text: &text,
}], }],
} }
.log(); .log();
Json(response) Json(CompletionResponse {
id: completion_id,
choices: vec![Choice {
index: 0,
text,
}],
})
} }
pub struct CompletionState { pub struct CompletionState {

View File

@ -1,26 +1,19 @@
use std::collections::HashMap; use std::collections::HashMap;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex;
lazy_static! { lazy_static! {
static ref DEFAULT: Regex = Regex::new(r"(?m)\n\n").unwrap(); static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
static ref LANGUAGES: HashMap<&'static str, Regex> = { static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
let mut map = HashMap::new(); let mut map = HashMap::new();
map.insert( map.insert(
"python", "python",
Regex::new(r"(?m)(\n\n|^def|^#|^from|^class)").unwrap(), vec!("\n\n", "\ndef", "\n#", "\nfrom", "\nclass")
); );
map map
}; };
} }
pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str { pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
let re = LANGUAGES.get(language).unwrap_or(&DEFAULT); LANGUAGES.get(language).unwrap_or(&DEFAULT)
let position = re.find_iter(text).next();
if let Some(m) = position {
&text[..m.start()]
} else {
text
}
} }