feat: support early stop [TAB-51] (#208)

* bump ctranslate2 to v3.15.0

* enable early stop

* support early stop
support-stop-sequences
Meng Zhang 2023-06-06 05:46:17 -07:00 committed by GitHub
parent e44dfc1c04
commit 007a40c582
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 32 deletions

View File

@@ -1,4 +1,4 @@
FROM ghcr.io/opennmt/ctranslate2:3.14.0-ubuntu20.04-cuda11.2 as source FROM ghcr.io/opennmt/ctranslate2:3.15.0-ubuntu20.04-cuda11.2 as source
FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
ENV CTRANSLATE2_ROOT=/opt/ctranslate2 ENV CTRANSLATE2_ROOT=/opt/ctranslate2
@@ -31,7 +31,7 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
cargo build --features link_shared --release && \ cargo build --features link_shared --release && \
cp target/release/tabby /opt/tabby/bin/ cp target/release/tabby /opt/tabby/bin/
FROM ghcr.io/opennmt/ctranslate2:3.14.0-ubuntu20.04-cuda11.2 FROM ghcr.io/opennmt/ctranslate2:3.15.0-ubuntu20.04-cuda11.2
COPY --from=builder /opt/tabby /opt/tabby COPY --from=builder /opt/tabby /opt/tabby

@@ -1 +1 @@
Subproject commit 45af5ebcb643f205a6709e0bf6c09157d1ecba52 Subproject commit d4b6f3849ae1bd67d1de0a037be3d7a7833fac6c

View File

@@ -12,11 +12,10 @@ class TextInferenceEngine {
virtual ~TextInferenceEngine(); virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference( virtual rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled, rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature, float sampling_temperature
size_t beam_size
) const = 0; ) const = 0;
}; };

View File

@@ -12,26 +12,24 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
struct Options { struct Options {
size_t max_decoding_length; size_t max_decoding_length;
float sampling_temperature; float sampling_temperature;
size_t beam_size;
}; };
public: public:
rust::Vec<rust::String> inference( rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled, rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature, float sampling_temperature
size_t beam_size
) const { ) const {
// FIXME(meng): implement the cancellation.
if (is_context_cancelled(std::move(context))) {
return rust::Vec<rust::String>();
}
// Inference. // Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end()); std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size}); const auto output_tokens = process(
std::move(context),
std::move(is_context_cancelled),
input_tokens,
Options{max_decoding_length, sampling_temperature}
);
// Convert to rust vec. // Convert to rust vec.
rust::Vec<rust::String> output; rust::Vec<rust::String> output;
@@ -47,34 +45,48 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
} }
protected: protected:
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const = 0; virtual std::vector<std::string> process(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
const std::vector<std::string>& tokens,
const Options& options) const = 0;
std::unique_ptr<Model> model_; std::unique_ptr<Model> model_;
}; };
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> { class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
protected: protected:
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override { virtual std::vector<std::string> process(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
const std::vector<std::string>& tokens,
const Options& options) const override {
ctranslate2::TranslationOptions x; ctranslate2::TranslationOptions x;
x.max_decoding_length = options.max_decoding_length; x.max_decoding_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = options.beam_size; x.beam_size = 1;
ctranslate2::TranslationResult result = model_->translate_batch( x.callback = [&](ctranslate2::GenerationStepResult result) {
{ tokens }, return is_context_cancelled(*context);
ctranslate2::TranslationOptions{ };
} ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
)[0];
return std::move(result.output()); return std::move(result.output());
} }
}; };
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> { class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
protected: protected:
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override { virtual std::vector<std::string> process(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
const std::vector<std::string>& tokens,
const Options& options) const override {
ctranslate2::GenerationOptions x; ctranslate2::GenerationOptions x;
x.include_prompt_in_result = false; x.include_prompt_in_result = false;
x.max_length = options.max_decoding_length; x.max_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = options.beam_size; x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context);
};
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]); return std::move(result.sequences[0]);
} }

View File

@@ -26,11 +26,10 @@ mod ffi {
fn inference( fn inference(
&self, &self,
context: Box<InferenceContext>, context: Box<InferenceContext>,
is_context_cancelled: fn(Box<InferenceContext>) -> bool, is_context_cancelled: fn(&InferenceContext) -> bool,
tokens: &[String], tokens: &[String],
max_decoding_length: usize, max_decoding_length: usize,
sampling_temperature: f32, sampling_temperature: f32,
beam_size: usize,
) -> Vec<String>; ) -> Vec<String>;
} }
} }
@@ -60,9 +59,6 @@ pub struct TextInferenceOptions {
#[builder(default = "1.0")] #[builder(default = "1.0")]
sampling_temperature: f32, sampling_temperature: f32,
#[builder(default = "2")]
beam_size: usize,
} }
struct InferenceContext(CancellationToken); struct InferenceContext(CancellationToken);
@@ -104,7 +100,6 @@ impl TextInferenceEngine {
encoding.get_tokens(), encoding.get_tokens(),
options.max_decoding_length, options.max_decoding_length,
options.sampling_temperature, options.sampling_temperature,
options.beam_size,
) )
}) })
.await .await