diff --git a/Cargo.lock b/Cargo.lock index ebd2492..051d658 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -579,6 +579,7 @@ dependencies = [ "cxx", "cxx-build", "derive_builder", + "regex", "rust-cxx-cmake-bridge", "tokenizers", "tokio", @@ -2022,9 +2023,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.3" +version = "1.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81ca098a9821bd52d6b24fd8b10bd081f47d39c22778cafaa75a2857a62c6390" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" dependencies = [ "aho-corasick 1.0.1", "memchr", diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 046e303..329ab12 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] cxx = "1.0" derive_builder = "0.12.0" +regex = "1.8.4" tokenizers = "0.13.3" tokio = { workspace = true, features = ["rt"] } tokio-util = { workspace = true } diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index f21ccf4..5202a65 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,3 +1,4 @@ +use regex::Regex; use tokenizers::tokenizer::Tokenizer; use tokio_util::sync::CancellationToken; @@ -67,16 +68,21 @@ pub struct TextInferenceOptions { #[builder(default = "1.0")] sampling_temperature: f32, + + #[builder(default = "vec!()")] + stop_words: Vec<String> } pub struct InferenceContext { + stop_regexp: Regex, cancel: CancellationToken, output_text: String } impl InferenceContext { - fn new(cancel: CancellationToken) -> Self { - InferenceContext { cancel, output_text: "".to_owned() } + fn new(stop_words: Vec<String>, cancel: CancellationToken) -> Self { + let stop_regexp = Regex::new(stop_words.join("|").as_ref()).unwrap(); + InferenceContext { stop_regexp, cancel, output_text: "".to_owned() } } } @@ -108,7 +114,7 @@ impl 
TextInferenceEngine { let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - let context = InferenceContext::new(cancel_for_inference); + let context = InferenceContext::new(options.stop_words, cancel_for_inference); let output_tokens = tokio::task::spawn_blocking(move || { let context = Box::new(context); engine.inference( @@ -136,10 +142,14 @@ impl TextInferenceEngine { } fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool { - context.output_text.push_str(&token); if context.cancel.is_cancelled() { true } else { - false + context.output_text.push_str(&token); + if let Some(_) = context.stop_regexp.find(&context.output_text) { + true + } else { + false + } } }