add stop regexp
parent
040af1a374
commit
fea645248e
|
|
@ -579,6 +579,7 @@ dependencies = [
|
|||
"cxx",
|
||||
"cxx-build",
|
||||
"derive_builder",
|
||||
"regex",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
|
@ -2022,9 +2023,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.8.3"
|
||||
version = "1.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390"
|
||||
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
|
||||
dependencies = [
|
||||
"aho-corasick 1.0.1",
|
||||
"memchr",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
cxx = "1.0"
|
||||
derive_builder = "0.12.0"
|
||||
regex = "1.8.4"
|
||||
tokenizers = "0.13.3"
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tokio-util = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use regex::Regex;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
|
|
@ -67,16 +68,21 @@ pub struct TextInferenceOptions {
|
|||
|
||||
#[builder(default = "1.0")]
|
||||
sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "vec!()")]
|
||||
stop_words: Vec<String>
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
stop_regexp: Regex,
|
||||
cancel: CancellationToken,
|
||||
output_text: String
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(cancel: CancellationToken) -> Self {
|
||||
InferenceContext { cancel, output_text: "".to_owned() }
|
||||
fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self {
|
||||
let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap();
|
||||
InferenceContext { stop_regexp, cancel, output_text: "".to_owned() }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -108,7 +114,7 @@ impl TextInferenceEngine {
|
|||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let context = InferenceContext::new(cancel_for_inference);
|
||||
let context = InferenceContext::new(options.stop_words, cancel_for_inference);
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
|
|
@ -136,10 +142,14 @@ impl TextInferenceEngine {
|
|||
}
|
||||
|
||||
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
||||
context.output_text.push_str(&token);
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
context.output_text.push_str(&token);
|
||||
if let Some(_) = context.stop_regexp.find(&context.output_text) {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue