refactor: pass step and string token to callback

support-stop-sequences
Meng Zhang 2023-06-06 13:24:24 -07:00
parent e8a1688fce
commit d571248c7f
3 changed files with 27 additions and 11 deletions

View File

@ -7,12 +7,14 @@ namespace tabby {
struct InferenceContext; struct InferenceContext;
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
class TextInferenceEngine { class TextInferenceEngine {
public: public:
virtual ~TextInferenceEngine(); virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference( virtual rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature float sampling_temperature

View File

@ -17,7 +17,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
public: public:
rust::Vec<rust::String> inference( rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature float sampling_temperature
@ -26,7 +26,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::vector<std::string> input_tokens(tokens.begin(), tokens.end()); std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
const auto output_tokens = process( const auto output_tokens = process(
std::move(context), std::move(context),
std::move(is_context_cancelled), std::move(callback),
input_tokens, input_tokens,
Options{max_decoding_length, sampling_temperature} Options{max_decoding_length, sampling_temperature}
); );
@ -47,7 +47,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
protected: protected:
virtual std::vector<std::string> process( virtual std::vector<std::string> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
const Options& options) const = 0; const Options& options) const = 0;
std::unique_ptr<Model> model_; std::unique_ptr<Model> model_;
@ -57,7 +57,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
protected: protected:
virtual std::vector<std::string> process( virtual std::vector<std::string> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
const Options& options) const override { const Options& options) const override {
ctranslate2::TranslationOptions x; ctranslate2::TranslationOptions x;
@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context); return callback(*context, result.step, rust::String(result.token));
}; };
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return std::move(result.output()); return std::move(result.output());
@ -76,7 +76,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
protected: protected:
virtual std::vector<std::string> process( virtual std::vector<std::string> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
rust::Fn<bool(const InferenceContext&)> is_context_cancelled, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
const Options& options) const override { const Options& options) const override {
ctranslate2::GenerationOptions x; ctranslate2::GenerationOptions x;
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return is_context_cancelled(*context); return callback(*context, result.step, rust::String(result.token));
}; };
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]); return std::move(result.sequences[0]);

View File

@ -26,7 +26,13 @@ mod ffi {
fn inference( fn inference(
&self, &self,
context: Box<InferenceContext>, context: Box<InferenceContext>,
is_context_cancelled: fn(&InferenceContext) -> bool, callback: fn(
&InferenceContext,
// step
usize,
// token
String,
) -> bool,
tokens: &[String], tokens: &[String],
max_decoding_length: usize, max_decoding_length: usize,
sampling_temperature: f32, sampling_temperature: f32,
@ -61,7 +67,7 @@ pub struct TextInferenceOptions {
sampling_temperature: f32, sampling_temperature: f32,
} }
struct InferenceContext(CancellationToken); pub struct InferenceContext(CancellationToken);
pub struct TextInferenceEngine { pub struct TextInferenceEngine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
@ -96,7 +102,7 @@ impl TextInferenceEngine {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
context, context,
|context| context.0.is_cancelled(), inference_callback,
encoding.get_tokens(), encoding.get_tokens(),
options.max_decoding_length, options.max_decoding_length,
options.sampling_temperature, options.sampling_temperature,
@ -117,3 +123,11 @@ impl TextInferenceEngine {
self.tokenizer.decode(output_ids, true).unwrap() self.tokenizer.decode(output_ids, true).unwrap()
} }
} }
/// Per-step callback passed to `engine.inference(...)`.
///
/// * `context` - the inference context; its wrapped token (`context.0`) is
///   queried for cancellation.
/// * `step` - decoding step index supplied by the caller; not used here yet.
/// * `token` - string token supplied by the caller for this step; not used
///   here yet.
///
/// Returns `true` when the context's token reports that it has been
/// cancelled, and `false` otherwise.
// NOTE(review): presumably a `true` result tells the engine to stop decoding —
// confirm against the consumer of this callback on the C++ side.
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
    context.0.is_cancelled()
}