From 2bf5bcd0cf0af4d2df5d44cf0b49db3ff0dcb2af Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 4 Jun 2023 15:28:39 -0700 Subject: [PATCH] refactor: extract TextInferenceEngineImpl to reduce duplications between EncoderDecoderImpl and DecoderImpl #189 --- Cargo.lock | 1 + Cargo.toml | 1 + crates/ctranslate2-bindings/Cargo.toml | 3 +- .../include/ctranslate2.h | 4 + .../ctranslate2-bindings/src/ctranslate2.cc | 89 ++++++++++--------- crates/ctranslate2-bindings/src/lib.rs | 18 ++++ 6 files changed, 72 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5ace271..c4d7632 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -491,6 +491,7 @@ dependencies = [ "rust-cxx-cmake-bridge", "tokenizers", "tokio", + "tokio-util", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 6488911..b8b77ff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,4 @@ lazy_static = "1.4.0" serde = { version = "1.0", features = ["derive"] } serdeconv = "0.4.1" tokio = "1.28" +tokio-util = "0.7" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 511c9f5..046e303 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -8,6 +8,7 @@ cxx = "1.0" derive_builder = "0.12.0" tokenizers = "0.13.3" tokio = { workspace = true, features = ["rt"] } +tokio-util = { workspace = true } [build-dependencies] cxx-build = "1.0" @@ -15,5 +16,5 @@ cmake = { version = "0.1", optional = true } rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true } [features] -default = [ "dep:cmake", "dep:rust-cxx-cmake-bridge" ] +default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"] link_shared = [] diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index 9e127d8..92f20a4 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -5,10 +5,14 @@ namespace tabby { +struct 
InferenceContext; + class TextInferenceEngine { public: virtual ~TextInferenceEngine(); virtual rust::Vec inference( + rust::Box context, + rust::Fn)> is_context_cancelled, rust::Slice tokens, size_t max_decoding_length, float sampling_temperature, diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index 1191b81..68e2cea 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -6,24 +6,32 @@ namespace tabby { TextInferenceEngine::~TextInferenceEngine() {} -class EncoderDecoderImpl: public TextInferenceEngine { +template +class TextInferenceEngineImpl : public TextInferenceEngine { + protected: + struct Options { + size_t max_decoding_length; + float sampling_temperature; + size_t beam_size; + }; + public: rust::Vec inference( + rust::Box context, + rust::Fn)> is_context_cancelled, rust::Slice tokens, size_t max_decoding_length, float sampling_temperature, size_t beam_size ) const { - // Create options. - ctranslate2::TranslationOptions options; - options.max_decoding_length = max_decoding_length; - options.sampling_temperature = sampling_temperature; - options.beam_size = beam_size; + // FIXME(meng): implement the cancellation. + if (is_context_cancelled(std::move(context))) { + return rust::Vec(); + } // Inference. std::vector input_tokens(tokens.begin(), tokens.end()); - ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0]; - const auto& output_tokens = result.output(); + const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size}); // Convert to rust vec. 
rust::Vec output; @@ -33,48 +41,43 @@ class EncoderDecoderImpl: public TextInferenceEngine { } static std::unique_ptr create(const ctranslate2::models::ModelLoader& loader) { - auto impl = std::make_unique(); - impl->translator_ = std::make_unique(loader); + auto impl = std::make_unique(); + impl->model_ = std::make_unique(loader); return impl; } - private: - std::unique_ptr translator_; + + protected: + virtual std::vector process(const std::vector& tokens, const Options& options) const = 0; + std::unique_ptr model_; }; -class DecoderImpl: public TextInferenceEngine { - public: - rust::Vec inference( - rust::Slice tokens, - size_t max_decoding_length, - float sampling_temperature, - size_t beam_size - ) const { - // Create options. - ctranslate2::GenerationOptions options; - options.include_prompt_in_result = false; - options.max_length = max_decoding_length; - options.sampling_temperature = sampling_temperature; - options.beam_size = beam_size; - - // Inference. - std::vector input_tokens(tokens.begin(), tokens.end()); - ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get(); - const auto& output_tokens = result.sequences[0]; - - // Convert to rust vec. 
- rust::Vec output; - output.reserve(output_tokens.size()); - std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output)); - return output; +class EncoderDecoderImpl : public TextInferenceEngineImpl { + protected: + virtual std::vector process(const std::vector& tokens, const Options& options) const override { + ctranslate2::TranslationOptions x; + x.max_decoding_length = options.max_decoding_length; + x.sampling_temperature = options.sampling_temperature; + x.beam_size = options.beam_size; + ctranslate2::TranslationResult result = model_->translate_batch( + { tokens }, + // Forward the populated options; a default-constructed TranslationOptions would discard the caller's settings. + x + )[0]; + return std::move(result.output()); } +}; - static std::unique_ptr create(const ctranslate2::models::ModelLoader& loader) { - auto impl = std::make_unique(); - impl->generator_ = std::make_unique(loader); - return impl; +class DecoderImpl : public TextInferenceEngineImpl { + protected: + virtual std::vector process(const std::vector& tokens, const Options& options) const override { + ctranslate2::GenerationOptions x; + x.include_prompt_in_result = false; + x.max_length = options.max_decoding_length; + x.sampling_temperature = options.sampling_temperature; + x.beam_size = options.beam_size; + ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); + return std::move(result.sequences[0]); } - private: - std::unique_ptr generator_; }; std::shared_ptr create_engine( diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 55e10e7..f40f8af 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,10 +1,15 @@ use tokenizers::tokenizer::Tokenizer; +use tokio_util::sync::CancellationToken; #[macro_use] extern crate derive_builder; #[cxx::bridge(namespace = "tabby")] mod ffi { + extern "Rust" { + type InferenceContext; + } + unsafe extern "C++" { include!("ctranslate2-bindings/include/ctranslate2.h"); @@ -20,6 +25,8
@@ mod ffi { fn inference( &self, + context: Box, + is_context_cancelled: fn(Box) -> bool, tokens: &[String], max_decoding_length: usize, sampling_temperature: f32, @@ -58,6 +65,8 @@ pub struct TextInferenceOptions { beam_size: usize, } +struct InferenceContext(CancellationToken); + pub struct TextInferenceEngine { engine: cxx::SharedPtr, tokenizer: Tokenizer, @@ -81,8 +90,17 @@ impl TextInferenceEngine { pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { let encoding = self.tokenizer.encode(prompt, true).unwrap(); let engine = self.engine.clone(); + + let cancel = CancellationToken::new(); + let cancel_for_inference = cancel.clone(); + let _guard = cancel.drop_guard(); + + let context = InferenceContext(cancel_for_inference); let output_tokens = tokio::task::spawn_blocking(move || { + let context = Box::new(context); engine.inference( + context, + |context| context.0.is_cancelled(), encoding.get_tokens(), options.max_decoding_length, options.sampling_temperature,