add stop regexp

support-stop-sequences
Meng Zhang 2023-06-06 14:12:04 -07:00
parent 040af1a374
commit fea645248e
3 changed files with 19 additions and 7 deletions

5
Cargo.lock generated
View File

@@ -579,6 +579,7 @@ dependencies = [
 "cxx",
 "cxx-build",
 "derive_builder",
+"regex",
 "rust-cxx-cmake-bridge",
 "tokenizers",
 "tokio",
@@ -2022,9 +2023,9 @@ dependencies = [
 [[package]]
 name = "regex"
-version = "1.8.3"
+version = "1.8.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390"
+checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
 dependencies = [
 "aho-corasick 1.0.1",
 "memchr",

View File

@@ -6,6 +6,7 @@ edition = "2021"
 [dependencies]
 cxx = "1.0"
 derive_builder = "0.12.0"
+regex = "1.8.4"
 tokenizers = "0.13.3"
 tokio = { workspace = true, features = ["rt"] }
 tokio-util = { workspace = true }

View File

@@ -1,3 +1,4 @@
+use regex::Regex;
 use tokenizers::tokenizer::Tokenizer;
 use tokio_util::sync::CancellationToken;
@@ -67,16 +68,21 @@ pub struct TextInferenceOptions {
     #[builder(default = "1.0")]
     sampling_temperature: f32,
+
+    #[builder(default = "vec!()")]
+    stop_words: Vec<String>
 }

 pub struct InferenceContext {
+    stop_regexp: Regex,
     cancel: CancellationToken,
     output_text: String
 }

 impl InferenceContext {
-    fn new(cancel: CancellationToken) -> Self {
-        InferenceContext { cancel, output_text: "".to_owned() }
+    fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self {
+        let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap();
+        InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
     }
 }
@@ -108,7 +114,7 @@ impl TextInferenceEngine {
         let cancel_for_inference = cancel.clone();
         let _guard = cancel.drop_guard();
-        let context = InferenceContext::new(cancel_for_inference);
+        let context = InferenceContext::new(options.stop_words, cancel_for_inference);
         let output_tokens = tokio::task::spawn_blocking(move || {
             let context = Box::new(context);
             engine.inference(
@@ -136,10 +142,14 @@ impl TextInferenceEngine {
 }

 fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
-    context.output_text.push_str(&token);
     if context.cancel.is_cancelled() {
         true
     } else {
-        false
+        context.output_text.push_str(&token);
+        if let Some(_) = context.stop_regexp.find(&context.output_text) {
+            true
+        } else {
+            false
+        }
     }
 }