refactor: pass step and string token to callback

support-stop-sequences
Meng Zhang 2023-06-06 13:24:24 -07:00
parent e8a1688fce
commit d571248c7f
3 changed files with 27 additions and 11 deletions

View File

@ -7,12 +7,14 @@ namespace tabby {
struct InferenceContext;
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
InferenceCallback callback,
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature

View File

@ -17,7 +17,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
public:
rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
InferenceCallback callback,
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature
@ -26,7 +26,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
const auto output_tokens = process(
std::move(context),
std::move(is_context_cancelled),
std::move(callback),
input_tokens,
Options{max_decoding_length, sampling_temperature}
);
@ -47,7 +47,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
protected:
virtual std::vector<std::string> process(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
InferenceCallback callback,
const std::vector<std::string>& tokens,
const Options& options) const = 0;
std::unique_ptr<Model> model_;
@ -57,7 +57,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
protected:
virtual std::vector<std::string> process(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
InferenceCallback callback,
const std::vector<std::string>& tokens,
const Options& options) const override {
ctranslate2::TranslationOptions x;
@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context);
return callback(*context, result.step, rust::String(result.token));
};
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return std::move(result.output());
@ -76,7 +76,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
protected:
virtual std::vector<std::string> process(
rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
InferenceCallback callback,
const std::vector<std::string>& tokens,
const Options& options) const override {
ctranslate2::GenerationOptions x;
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context);
return callback(*context, result.step, rust::String(result.token));
};
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]);

View File

@ -26,7 +26,13 @@ mod ffi {
fn inference(
&self,
context: Box<InferenceContext>,
is_context_cancelled: fn(&InferenceContext) -> bool,
callback: fn(
&InferenceContext,
// step
usize,
// token
String,
) -> bool,
tokens: &[String],
max_decoding_length: usize,
sampling_temperature: f32,
@ -61,7 +67,7 @@ pub struct TextInferenceOptions {
sampling_temperature: f32,
}
struct InferenceContext(CancellationToken);
pub struct InferenceContext(CancellationToken);
pub struct TextInferenceEngine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
@ -96,7 +102,7 @@ impl TextInferenceEngine {
let context = Box::new(context);
engine.inference(
context,
|context| context.0.is_cancelled(),
inference_callback,
encoding.get_tokens(),
options.max_decoding_length,
options.sampling_temperature,
@ -117,3 +123,11 @@ impl TextInferenceEngine {
self.tokenizer.decode(output_ids, true).unwrap()
}
}
/// Per-step callback invoked by the C++ inference engine after each
/// decoding step (see the `callback` parameter of `ffi::inference`).
///
/// * `context` — per-request context wrapping a `CancellationToken`.
/// * `_step`   — decoding step index; not yet used (reserved for
///   stop-sequence handling — TODO confirm against follow-up commits).
/// * `_token`  — token string produced at this step; not yet used.
///
/// Returns `true` to stop generation (the request was cancelled),
/// `false` to keep decoding.
fn inference_callback(context: &InferenceContext, _step: usize, _token: String) -> bool {
    // `if cond { true } else { false }` is just `cond`; unused params are
    // underscore-prefixed so the crate builds clean under `-D warnings`.
    context.0.is_cancelled()
}