implement stop words logic
parent
fea645248e
commit
fa84e376f4
|
|
@ -2504,7 +2504,6 @@ dependencies = [
|
|||
"hyper",
|
||||
"lazy_static",
|
||||
"mime_guess",
|
||||
"regex",
|
||||
"rust-embed",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
|||
|
|
@ -69,20 +69,18 @@ pub struct TextInferenceOptions {
|
|||
#[builder(default = "1.0")]
|
||||
sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "vec!()")]
|
||||
stop_words: Vec<String>
|
||||
stop_words: &'static Vec<&'static str>
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
stop_regexp: Regex,
|
||||
stop_re: Option<Regex>,
|
||||
cancel: CancellationToken,
|
||||
output_text: String
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self {
|
||||
let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap();
|
||||
InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
|
||||
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
||||
InferenceContext { stop_re, cancel, output_text: "".to_owned() }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -114,7 +112,16 @@ impl TextInferenceEngine {
|
|||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let context = InferenceContext::new(options.stop_words, cancel_for_inference);
|
||||
let stop_re = if options.stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
|
||||
let stop_tokens : Vec<String> = encodings.iter().map(|x| x.get_tokens().join("")).collect();
|
||||
let regex_string = r"(?m)".to_owned() + &stop_tokens.join("|");
|
||||
Some(Regex::new(®ex_string).unwrap())
|
||||
};
|
||||
|
||||
let context = InferenceContext::new(stop_re, cancel_for_inference);
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
|
|
@ -137,7 +144,13 @@ impl TextInferenceEngine {
|
|||
}
|
||||
})
|
||||
.collect();
|
||||
self.tokenizer.decode(output_ids, true).unwrap()
|
||||
let output_text = self.tokenizer.decode(output_ids, true).unwrap();
|
||||
for stop_word in options.stop_words {
|
||||
if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
|
||||
return stripped_text.to_string();
|
||||
}
|
||||
}
|
||||
output_text
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -145,9 +158,9 @@ fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u
|
|||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else {
|
||||
context.output_text.push_str(&token);
|
||||
if let Some(_) = context.stop_regexp.find(&context.output_text) {
|
||||
true
|
||||
if let Some(re) = &context.stop_re {
|
||||
context.output_text.push_str(&token);
|
||||
re.find(&context.output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ serdeconv = { workspace = true }
|
|||
serde_json = "1.0"
|
||||
tower-http = { version = "0.4.0", features = ["cors"] }
|
||||
clap = { version = "4.3.0", features = ["derive"] }
|
||||
regex = "1.8.3"
|
||||
lazy_static = { workspace = true }
|
||||
rust-embed = "6.6.1"
|
||||
mime_guess = "2.0.4"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ use strfmt::{strfmt, strfmt_builder};
|
|||
use tabby_common::{events, path::ModelDir};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use self::languages::get_stop_words;
|
||||
|
||||
mod languages;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
|
|
@ -57,9 +59,11 @@ pub async fn completion(
|
|||
State(state): State<Arc<CompletionState>>,
|
||||
Json(request): Json<CompletionRequest>,
|
||||
) -> Json<CompletionResponse> {
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let options = TextInferenceOptionsBuilder::default()
|
||||
.max_decoding_length(64)
|
||||
.sampling_temperature(0.2)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(get_stop_words(&language))
|
||||
.build()
|
||||
.expect("Invalid TextInferenceOptions");
|
||||
|
||||
|
|
@ -80,30 +84,27 @@ pub async fn completion(
|
|||
request.prompt.expect("No prompt is set")
|
||||
};
|
||||
|
||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||
let text = state.engine.inference(&prompt, options).await;
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||
|
||||
let response = CompletionResponse {
|
||||
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
text: filtered_text.to_string(),
|
||||
}],
|
||||
};
|
||||
|
||||
events::Event::Completion {
|
||||
completion_id: &response.id,
|
||||
completion_id: &completion_id,
|
||||
language: &language,
|
||||
prompt: &prompt,
|
||||
choices: vec![events::Choice {
|
||||
index: 0,
|
||||
text: filtered_text,
|
||||
text: &text,
|
||||
}],
|
||||
}
|
||||
.log();
|
||||
|
||||
Json(response)
|
||||
Json(CompletionResponse {
|
||||
id: completion_id,
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
text,
|
||||
}],
|
||||
})
|
||||
}
|
||||
|
||||
pub struct CompletionState {
|
||||
|
|
|
|||
|
|
@ -1,26 +1,19 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
|
||||
lazy_static! {
|
||||
static ref DEFAULT: Regex = Regex::new(r"(?m)\n\n").unwrap();
|
||||
static ref LANGUAGES: HashMap<&'static str, Regex> = {
|
||||
static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
|
||||
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"python",
|
||||
Regex::new(r"(?m)(\n\n|^def|^#|^from|^class)").unwrap(),
|
||||
vec!("\n\n", "\ndef", "\n#", "\nfrom", "\nclass")
|
||||
);
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str {
|
||||
let re = LANGUAGES.get(language).unwrap_or(&DEFAULT);
|
||||
let position = re.find_iter(text).next();
|
||||
if let Some(m) = position {
|
||||
&text[..m.start()]
|
||||
} else {
|
||||
text
|
||||
}
|
||||
pub fn get_stop_words(language: &str) -> &'static Vec<&'static str> {
|
||||
LANGUAGES.get(language).unwrap_or(&DEFAULT)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue