support-stop-sequences
Meng Zhang 2023-06-06 15:47:54 -07:00
parent c3e57147cf
commit 54901515ef
3 changed files with 26 additions and 22 deletions

View File

@ -69,18 +69,22 @@ pub struct TextInferenceOptions {
#[builder(default = "1.0")]
sampling_temperature: f32,
stop_words: &'static Vec<&'static str>
stop_words: &'static Vec<&'static str>,
}
pub struct InferenceContext {
stop_re: Option<Regex>,
cancel: CancellationToken,
reversed_output_text: String
reversed_output_text: String,
}
impl InferenceContext {
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
InferenceContext { stop_re, cancel, reversed_output_text: "".to_owned() }
InferenceContext {
stop_re,
cancel,
reversed_output_text: "".to_owned(),
}
}
}
@ -116,8 +120,11 @@ impl TextInferenceEngine {
None
} else {
// FIXME(meng): consider cache the regexp.
let encodings = self.tokenizer.encode_batch(options.stop_words.clone(), false).unwrap();
let stop_tokens : Vec<String> = encodings
let encodings = self
.tokenizer
.encode_batch(options.stop_words.clone(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| x.get_tokens().join(""))
// Reverse for efficient suffix matching.
@ -146,18 +153,21 @@ impl TextInferenceEngine {
}
}
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
fn inference_callback(
context: &mut InferenceContext,
_step: usize,
_token_id: u32,
token: String,
) -> bool {
if context.cancel.is_cancelled() {
true
} else if let Some(re) = &context.stop_re {
let mut new_token = reverse(token);
new_token.push_str(&context.reversed_output_text);
context.reversed_output_text = new_token;
re.find(&context.reversed_output_text).is_some()
} else {
if let Some(re) = &context.stop_re {
let mut new_token = reverse(token);
new_token.push_str(&context.reversed_output_text);
context.reversed_output_text = new_token;
re.find(&context.reversed_output_text).is_some()
} else {
false
}
false
}
}

View File

@ -100,10 +100,7 @@ pub async fn completion(
Json(CompletionResponse {
id: completion_id,
choices: vec![Choice {
index: 0,
text,
}],
choices: vec![Choice { index: 0, text }],
})
}

View File

@ -6,10 +6,7 @@ lazy_static! {
static ref DEFAULT: Vec<&'static str> = vec!("\n\n");
static ref LANGUAGES: HashMap<&'static str, Vec<&'static str>> = {
let mut map = HashMap::new();
map.insert(
"python",
vec!("\n\n", "\ndef", "\n#", "\nfrom", "\nclass")
);
map.insert("python", vec!["\n\n", "\ndef", "\n#", "\nfrom", "\nclass"]);
map
};
}