feat: support early stop [TAB-51] (#208)
* bump ctranslate2 to v3.15.0 * enable early stop * support early stopsupport-stop-sequences
parent
e44dfc1c04
commit
007a40c582
|
|
@ -1,4 +1,4 @@
|
|||
FROM ghcr.io/opennmt/ctranslate2:3.14.0-ubuntu20.04-cuda11.2 as source
|
||||
FROM ghcr.io/opennmt/ctranslate2:3.15.0-ubuntu20.04-cuda11.2 as source
|
||||
FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
|
||||
|
||||
ENV CTRANSLATE2_ROOT=/opt/ctranslate2
|
||||
|
|
@ -31,7 +31,7 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
|||
cargo build --features link_shared --release && \
|
||||
cp target/release/tabby /opt/tabby/bin/
|
||||
|
||||
FROM ghcr.io/opennmt/ctranslate2:3.14.0-ubuntu20.04-cuda11.2
|
||||
FROM ghcr.io/opennmt/ctranslate2:3.15.0-ubuntu20.04-cuda11.2
|
||||
|
||||
COPY --from=builder /opt/tabby /opt/tabby
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit 45af5ebcb643f205a6709e0bf6c09157d1ecba52
|
||||
Subproject commit d4b6f3849ae1bd67d1de0a037be3d7a7833fac6c
|
||||
|
|
@ -12,11 +12,10 @@ class TextInferenceEngine {
|
|||
virtual ~TextInferenceEngine();
|
||||
virtual rust::Vec<rust::String> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature,
|
||||
size_t beam_size
|
||||
float sampling_temperature
|
||||
) const = 0;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -12,26 +12,24 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
struct Options {
|
||||
size_t max_decoding_length;
|
||||
float sampling_temperature;
|
||||
size_t beam_size;
|
||||
};
|
||||
|
||||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature,
|
||||
size_t beam_size
|
||||
float sampling_temperature
|
||||
) const {
|
||||
// FIXME(meng): implement the cancellation.
|
||||
if (is_context_cancelled(std::move(context))) {
|
||||
return rust::Vec<rust::String>();
|
||||
}
|
||||
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size});
|
||||
const auto output_tokens = process(
|
||||
std::move(context),
|
||||
std::move(is_context_cancelled),
|
||||
input_tokens,
|
||||
Options{max_decoding_length, sampling_temperature}
|
||||
);
|
||||
|
||||
// Convert to rust vec.
|
||||
rust::Vec<rust::String> output;
|
||||
|
|
@ -47,34 +45,48 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
}
|
||||
|
||||
protected:
|
||||
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const = 0;
|
||||
virtual std::vector<std::string> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const = 0;
|
||||
std::unique_ptr<Model> model_;
|
||||
};
|
||||
|
||||
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
|
||||
virtual std::vector<std::string> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const override {
|
||||
ctranslate2::TranslationOptions x;
|
||||
x.max_decoding_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = options.beam_size;
|
||||
ctranslate2::TranslationResult result = model_->translate_batch(
|
||||
{ tokens },
|
||||
ctranslate2::TranslationOptions{
|
||||
}
|
||||
)[0];
|
||||
x.beam_size = 1;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return is_context_cancelled(*context);
|
||||
};
|
||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||
return std::move(result.output());
|
||||
}
|
||||
};
|
||||
|
||||
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
|
||||
virtual std::vector<std::string> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const override {
|
||||
ctranslate2::GenerationOptions x;
|
||||
x.include_prompt_in_result = false;
|
||||
x.max_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = options.beam_size;
|
||||
x.beam_size = 1;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return is_context_cancelled(*context);
|
||||
};
|
||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||
return std::move(result.sequences[0]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,11 +26,10 @@ mod ffi {
|
|||
fn inference(
|
||||
&self,
|
||||
context: Box<InferenceContext>,
|
||||
is_context_cancelled: fn(Box<InferenceContext>) -> bool,
|
||||
is_context_cancelled: fn(&InferenceContext) -> bool,
|
||||
tokens: &[String],
|
||||
max_decoding_length: usize,
|
||||
sampling_temperature: f32,
|
||||
beam_size: usize,
|
||||
) -> Vec<String>;
|
||||
}
|
||||
}
|
||||
|
|
@ -60,9 +59,6 @@ pub struct TextInferenceOptions {
|
|||
|
||||
#[builder(default = "1.0")]
|
||||
sampling_temperature: f32,
|
||||
|
||||
#[builder(default = "2")]
|
||||
beam_size: usize,
|
||||
}
|
||||
|
||||
struct InferenceContext(CancellationToken);
|
||||
|
|
@ -104,7 +100,6 @@ impl TextInferenceEngine {
|
|||
encoding.get_tokens(),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
options.beam_size,
|
||||
)
|
||||
})
|
||||
.await
|
||||
|
|
|
|||
Loading…
Reference in New Issue