add stop regexp
parent
040af1a374
commit
fea645248e
|
|
@ -579,6 +579,7 @@ dependencies = [
|
||||||
"cxx",
|
"cxx",
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
"derive_builder",
|
"derive_builder",
|
||||||
|
"regex",
|
||||||
"rust-cxx-cmake-bridge",
|
"rust-cxx-cmake-bridge",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
|
@ -2022,9 +2023,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "1.8.3"
|
version = "1.8.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390"
|
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick 1.0.1",
|
"aho-corasick 1.0.1",
|
||||||
"memchr",
|
"memchr",
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
cxx = "1.0"
|
cxx = "1.0"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
|
regex = "1.8.4"
|
||||||
tokenizers = "0.13.3"
|
tokenizers = "0.13.3"
|
||||||
tokio = { workspace = true, features = ["rt"] }
|
tokio = { workspace = true, features = ["rt"] }
|
||||||
tokio-util = { workspace = true }
|
tokio-util = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use regex::Regex;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
|
@ -67,16 +68,21 @@ pub struct TextInferenceOptions {
|
||||||
|
|
||||||
#[builder(default = "1.0")]
|
#[builder(default = "1.0")]
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
|
|
||||||
|
#[builder(default = "vec!()")]
|
||||||
|
stop_words: Vec<String>
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
|
stop_regexp: Regex,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
output_text: String
|
output_text: String
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferenceContext {
|
impl InferenceContext {
|
||||||
fn new(cancel: CancellationToken) -> Self {
|
fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self {
|
||||||
InferenceContext { cancel, output_text: "".to_owned() }
|
let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap();
|
||||||
|
InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -108,7 +114,7 @@ impl TextInferenceEngine {
|
||||||
let cancel_for_inference = cancel.clone();
|
let cancel_for_inference = cancel.clone();
|
||||||
let _guard = cancel.drop_guard();
|
let _guard = cancel.drop_guard();
|
||||||
|
|
||||||
let context = InferenceContext::new(cancel_for_inference);
|
let context = InferenceContext::new(options.stop_words, cancel_for_inference);
|
||||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||||
let context = Box::new(context);
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
|
|
@ -136,10 +142,14 @@ impl TextInferenceEngine {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
||||||
context.output_text.push_str(&token);
|
|
||||||
if context.cancel.is_cancelled() {
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
false
|
context.output_text.push_str(&token);
|
||||||
|
if let Some(_) = context.stop_regexp.find(&context.output_text) {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue