diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h
index d81a742..b923c6a 100644
--- a/crates/ctranslate2-bindings/include/ctranslate2.h
+++ b/crates/ctranslate2-bindings/include/ctranslate2.h
@@ -7,7 +7,7 @@ namespace tabby {
 struct InferenceContext;
 
-typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
+typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
 
 class TextInferenceEngine {
  public:
diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index a345423..d45c2d8 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator> {
     x.callback = [&](ctranslate2::GenerationStepResult result) {
-      return callback(*context, result.step, rust::String(result.token));
+      return callback(*context, result.step, result.token_id, rust::String(result.token));
     };
     ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
     return std::move(result.output());
   }
@@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator> {
     x.callback = [&](ctranslate2::GenerationStepResult result) {
-      return callback(*context, result.step, rust::String(result.token));
+      return callback(*context, result.step, result.token_id, rust::String(result.token));
     };
     ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
     return std::move(result.sequences[0]);
   }
diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index d8c8094..f21ccf4 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -27,9 +27,11 @@ mod ffi {
            &self,
            context: Box<InferenceContext>,
            callback: fn(
-               &InferenceContext,
+               &mut InferenceContext,
                // step
                usize,
+               // token_id
+               u32,
                // token
                String,
            ) -> bool,
@@ -67,7 +69,16 @@ pub struct TextInferenceOptions {
     sampling_temperature: f32,
 }
 
-pub struct InferenceContext(CancellationToken);
+pub struct InferenceContext {
+    cancel: CancellationToken,
+    output_text: String
+}
+
+impl InferenceContext {
+    fn new(cancel: CancellationToken) -> Self {
+        InferenceContext { cancel, output_text: "".to_owned() }
+    }
+}
 
 pub struct TextInferenceEngine {
     engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
@@ -97,7 +108,7 @@ impl TextInferenceEngine {
         let cancel_for_inference = cancel.clone();
         let _guard = cancel.drop_guard();
 
-        let context = InferenceContext(cancel_for_inference);
+        let context = InferenceContext::new(cancel_for_inference);
         let output_tokens = tokio::task::spawn_blocking(move || {
             let context = Box::new(context);
             engine.inference(
@@ -124,8 +135,9 @@ impl TextInferenceEngine {
     }
 }
 
-fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
-    if context.0.is_cancelled() {
+fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
+    context.output_text.push_str(&token);
+    if context.cancel.is_cancelled() {
         true
     } else {
         false