From 007a40c58227e4e4ba853ca61b90462dff4e4f08 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 6 Jun 2023 05:46:17 -0700 Subject: [PATCH] feat: support early stop [TAB-51] (#208) * bump ctranslate2 to v3.15.0 * enable early stop * support early stop --- Dockerfile | 4 +- crates/ctranslate2-bindings/CTranslate2 | 2 +- .../include/ctranslate2.h | 5 +- .../ctranslate2-bindings/src/ctranslate2.cc | 52 ++++++++++++------- crates/ctranslate2-bindings/src/lib.rs | 7 +-- 5 files changed, 38 insertions(+), 32 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0e922e8..c09d74a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM ghcr.io/opennmt/ctranslate2:3.14.0-ubuntu20.04-cuda11.2 as source +FROM ghcr.io/opennmt/ctranslate2:3.15.0-ubuntu20.04-cuda11.2 as source FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder ENV CTRANSLATE2_ROOT=/opt/ctranslate2 @@ -31,7 +31,7 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \ cargo build --features link_shared --release && \ cp target/release/tabby /opt/tabby/bin/ -FROM ghcr.io/opennmt/ctranslate2:3.14.0-ubuntu20.04-cuda11.2 +FROM ghcr.io/opennmt/ctranslate2:3.15.0-ubuntu20.04-cuda11.2 COPY --from=builder /opt/tabby /opt/tabby diff --git a/crates/ctranslate2-bindings/CTranslate2 b/crates/ctranslate2-bindings/CTranslate2 index 45af5eb..d4b6f38 160000 --- a/crates/ctranslate2-bindings/CTranslate2 +++ b/crates/ctranslate2-bindings/CTranslate2 @@ -1 +1 @@ -Subproject commit 45af5ebcb643f205a6709e0bf6c09157d1ecba52 +Subproject commit d4b6f3849ae1bd67d1de0a037be3d7a7833fac6c diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index 92f20a4..b1f67db 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -12,11 +12,10 @@ class TextInferenceEngine { virtual ~TextInferenceEngine(); virtual rust::Vec inference( rust::Box context, - rust::Fn)> is_context_cancelled, + rust::Fn is_context_cancelled, rust::Slice tokens, size_t max_decoding_length, - float sampling_temperature, - size_t beam_size + float sampling_temperature ) const = 0; }; diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index 68e2cea..beb289c 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -12,26 +12,24 @@ class TextInferenceEngineImpl : public TextInferenceEngine { struct Options { size_t max_decoding_length; float sampling_temperature; - size_t beam_size; }; public: rust::Vec inference( rust::Box context, - rust::Fn)> is_context_cancelled, + rust::Fn is_context_cancelled, rust::Slice tokens, size_t max_decoding_length, - float sampling_temperature, - size_t beam_size + float sampling_temperature ) const { - // FIXME(meng): implement the cancellation. - if (is_context_cancelled(std::move(context))) { - return rust::Vec(); - } - // Inference. std::vector input_tokens(tokens.begin(), tokens.end()); - const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size}); + const auto output_tokens = process( + std::move(context), + std::move(is_context_cancelled), + input_tokens, + Options{max_decoding_length, sampling_temperature} + ); // Convert to rust vec. rust::Vec output; @@ -47,34 +45,48 @@ class TextInferenceEngineImpl : public TextInferenceEngine { } protected: - virtual std::vector process(const std::vector& tokens, const Options& options) const = 0; + virtual std::vector process( + rust::Box context, + rust::Fn is_context_cancelled, + const std::vector& tokens, + const Options& options) const = 0; std::unique_ptr model_; }; class EncoderDecoderImpl : public TextInferenceEngineImpl { protected: - virtual std::vector process(const std::vector& tokens, const Options& options) const override { + virtual std::vector process( + rust::Box context, + rust::Fn is_context_cancelled, + const std::vector& tokens, + const Options& options) const override { ctranslate2::TranslationOptions x; x.max_decoding_length = options.max_decoding_length; x.sampling_temperature = options.sampling_temperature; - x.beam_size = options.beam_size; - ctranslate2::TranslationResult result = model_->translate_batch( - { tokens }, - ctranslate2::TranslationOptions{ - } - )[0]; + x.beam_size = 1; + x.callback = [&](ctranslate2::GenerationStepResult result) { + return is_context_cancelled(*context); + }; + ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; return std::move(result.output()); } }; class DecoderImpl : public TextInferenceEngineImpl { protected: - virtual std::vector process(const std::vector& tokens, const Options& options) const override { + virtual std::vector process( + rust::Box context, + rust::Fn is_context_cancelled, + const std::vector& tokens, + const Options& options) const override { ctranslate2::GenerationOptions x; x.include_prompt_in_result = false; x.max_length = options.max_decoding_length; x.sampling_temperature = options.sampling_temperature; - x.beam_size = options.beam_size; + x.beam_size = 1; + x.callback = [&](ctranslate2::GenerationStepResult result) { + return is_context_cancelled(*context); + }; ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); return std::move(result.sequences[0]); } diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index f40f8af..2861972 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -26,11 +26,10 @@ mod ffi { fn inference( &self, context: Box, - is_context_cancelled: fn(Box) -> bool, + is_context_cancelled: fn(&InferenceContext) -> bool, tokens: &[String], max_decoding_length: usize, sampling_temperature: f32, - beam_size: usize, ) -> Vec; } } @@ -60,9 +59,6 @@ pub struct TextInferenceOptions { #[builder(default = "1.0")] sampling_temperature: f32, - - #[builder(default = "2")] - beam_size: usize, } struct InferenceContext(CancellationToken); @@ -104,7 +100,6 @@ impl TextInferenceEngine { encoding.get_tokens(), options.max_decoding_length, options.sampling_temperature, - options.beam_size, ) }) .await