diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index b1f67db..d81a742 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -7,12 +7,14 @@ namespace tabby { struct InferenceContext; +typedef rust::Fn InferenceCallback; + class TextInferenceEngine { public: virtual ~TextInferenceEngine(); virtual rust::Vec inference( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, rust::Slice tokens, size_t max_decoding_length, float sampling_temperature diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index beb289c..a345423 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -17,7 +17,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine { public: rust::Vec inference( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, rust::Slice tokens, size_t max_decoding_length, float sampling_temperature @@ -26,7 +26,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine { std::vector input_tokens(tokens.begin(), tokens.end()); const auto output_tokens = process( std::move(context), - std::move(is_context_cancelled), + std::move(callback), input_tokens, Options{max_decoding_length, sampling_temperature} ); @@ -47,7 +47,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine { protected: virtual std::vector process( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, const std::vector& tokens, const Options& options) const = 0; std::unique_ptr model_; @@ -57,7 +57,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl process( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, const std::vector& tokens, const Options& options) const override { ctranslate2::TranslationOptions x; @@ 
-65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpltranslate_batch({ tokens }, x)[0]; return std::move(result.output()); @@ -76,7 +76,7 @@ class DecoderImpl : public TextInferenceEngineImpl process( rust::Box context, - rust::Fn is_context_cancelled, + InferenceCallback callback, const std::vector& tokens, const Options& options) const override { ctranslate2::GenerationOptions x; @@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImplgenerate_batch_async({ tokens }, x)[0].get(); return std::move(result.sequences[0]); diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 2861972..d8c8094 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -26,7 +26,13 @@ mod ffi { fn inference( &self, context: Box, - is_context_cancelled: fn(&InferenceContext) -> bool, + callback: fn( + &InferenceContext, + // step + usize, + // token + String, + ) -> bool, tokens: &[String], max_decoding_length: usize, sampling_temperature: f32, @@ -61,7 +67,7 @@ pub struct TextInferenceOptions { sampling_temperature: f32, } -struct InferenceContext(CancellationToken); +pub struct InferenceContext(CancellationToken); pub struct TextInferenceEngine { engine: cxx::SharedPtr, @@ -96,7 +102,7 @@ impl TextInferenceEngine { let context = Box::new(context); engine.inference( context, - |context| context.0.is_cancelled(), + inference_callback, encoding.get_tokens(), options.max_decoding_length, options.sampling_temperature, @@ -117,3 +123,8 @@ impl TextInferenceEngine { self.tokenizer.decode(output_ids, true).unwrap() } } + +// Bridge callback handed to the C++ inference loop; returning true tells the engine to stop decoding. +fn inference_callback(context: &InferenceContext, _step: usize, _token: String) -> bool { + context.0.is_cancelled() +}