add stop regexp

support-stop-sequences
Meng Zhang 2023-06-06 14:12:04 -07:00
parent 040af1a374
commit fea645248e
3 changed files with 19 additions and 7 deletions

5
Cargo.lock generated
View File

@ -579,6 +579,7 @@ dependencies = [
"cxx",
"cxx-build",
"derive_builder",
"regex",
"rust-cxx-cmake-bridge",
"tokenizers",
"tokio",
@ -2022,9 +2023,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.8.3"
version = "1.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390"
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
dependencies = [
"aho-corasick 1.0.1",
"memchr",

View File

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
cxx = "1.0"
derive_builder = "0.12.0"
regex = "1.8.4"
tokenizers = "0.13.3"
tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true }

View File

@ -1,3 +1,4 @@
use regex::Regex;
use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken;
@ -67,16 +68,21 @@ pub struct TextInferenceOptions {
#[builder(default = "1.0")]
sampling_temperature: f32,
#[builder(default = "vec!()")]
stop_words: Vec<String>
}
/// Per-request state handed to `inference_callback` for every generated token.
pub struct InferenceContext {
// Compiled stop pattern; the callback searches `output_text` with it after each token.
stop_regexp: Regex,
// Cancellation token polled by the callback via `is_cancelled()`.
cancel: CancellationToken,
// Text accumulated from the tokens the callback has received so far.
output_text: String
}
impl InferenceContext {
fn new(cancel: CancellationToken) -> Self {
InferenceContext { cancel, output_text: "".to_owned() }
fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self {
let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap();
InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
}
}
@ -108,7 +114,7 @@ impl TextInferenceEngine {
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let context = InferenceContext::new(cancel_for_inference);
let context = InferenceContext::new(options.stop_words, cancel_for_inference);
let output_tokens = tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference(
@ -136,10 +142,14 @@ impl TextInferenceEngine {
}
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
context.output_text.push_str(&token);
if context.cancel.is_cancelled() {
true
} else {
false
context.output_text.push_str(&token);
if let Some(_) = context.stop_regexp.find(&context.output_text) {
true
} else {
false
}
}
}