refactor: pass step and string token to callback
parent
e8a1688fce
commit
d571248c7f
|
|
@ -7,12 +7,14 @@ namespace tabby {
|
||||||
|
|
||||||
struct InferenceContext;
|
struct InferenceContext;
|
||||||
|
|
||||||
|
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
|
||||||
|
|
||||||
class TextInferenceEngine {
|
class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
virtual rust::Vec<rust::String> inference(
|
virtual rust::Vec<rust::String> inference(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
InferenceCallback callback,
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
size_t max_decoding_length,
|
size_t max_decoding_length,
|
||||||
float sampling_temperature
|
float sampling_temperature
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
rust::Vec<rust::String> inference(
|
rust::Vec<rust::String> inference(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
InferenceCallback callback,
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
size_t max_decoding_length,
|
size_t max_decoding_length,
|
||||||
float sampling_temperature
|
float sampling_temperature
|
||||||
|
|
@ -26,7 +26,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||||
const auto output_tokens = process(
|
const auto output_tokens = process(
|
||||||
std::move(context),
|
std::move(context),
|
||||||
std::move(is_context_cancelled),
|
std::move(callback),
|
||||||
input_tokens,
|
input_tokens,
|
||||||
Options{max_decoding_length, sampling_temperature}
|
Options{max_decoding_length, sampling_temperature}
|
||||||
);
|
);
|
||||||
|
|
@ -47,7 +47,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
protected:
|
protected:
|
||||||
virtual std::vector<std::string> process(
|
virtual std::vector<std::string> process(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
InferenceCallback callback,
|
||||||
const std::vector<std::string>& tokens,
|
const std::vector<std::string>& tokens,
|
||||||
const Options& options) const = 0;
|
const Options& options) const = 0;
|
||||||
std::unique_ptr<Model> model_;
|
std::unique_ptr<Model> model_;
|
||||||
|
|
@ -57,7 +57,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
||||||
protected:
|
protected:
|
||||||
virtual std::vector<std::string> process(
|
virtual std::vector<std::string> process(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
InferenceCallback callback,
|
||||||
const std::vector<std::string>& tokens,
|
const std::vector<std::string>& tokens,
|
||||||
const Options& options) const override {
|
const Options& options) const override {
|
||||||
ctranslate2::TranslationOptions x;
|
ctranslate2::TranslationOptions x;
|
||||||
|
|
@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
||||||
x.sampling_temperature = options.sampling_temperature;
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
x.beam_size = 1;
|
x.beam_size = 1;
|
||||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||||
return is_context_cancelled(*context);
|
return callback(*context, result.step, rust::String(result.token));
|
||||||
};
|
};
|
||||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||||
return std::move(result.output());
|
return std::move(result.output());
|
||||||
|
|
@ -76,7 +76,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
||||||
protected:
|
protected:
|
||||||
virtual std::vector<std::string> process(
|
virtual std::vector<std::string> process(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
InferenceCallback callback,
|
||||||
const std::vector<std::string>& tokens,
|
const std::vector<std::string>& tokens,
|
||||||
const Options& options) const override {
|
const Options& options) const override {
|
||||||
ctranslate2::GenerationOptions x;
|
ctranslate2::GenerationOptions x;
|
||||||
|
|
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
||||||
x.sampling_temperature = options.sampling_temperature;
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
x.beam_size = 1;
|
x.beam_size = 1;
|
||||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||||
return is_context_cancelled(*context);
|
return callback(*context, result.step, rust::String(result.token));
|
||||||
};
|
};
|
||||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||||
return std::move(result.sequences[0]);
|
return std::move(result.sequences[0]);
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,13 @@ mod ffi {
|
||||||
fn inference(
|
fn inference(
|
||||||
&self,
|
&self,
|
||||||
context: Box<InferenceContext>,
|
context: Box<InferenceContext>,
|
||||||
is_context_cancelled: fn(&InferenceContext) -> bool,
|
callback: fn(
|
||||||
|
&InferenceContext,
|
||||||
|
// step
|
||||||
|
usize,
|
||||||
|
// token
|
||||||
|
String,
|
||||||
|
) -> bool,
|
||||||
tokens: &[String],
|
tokens: &[String],
|
||||||
max_decoding_length: usize,
|
max_decoding_length: usize,
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
|
|
@ -61,7 +67,7 @@ pub struct TextInferenceOptions {
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InferenceContext(CancellationToken);
|
pub struct InferenceContext(CancellationToken);
|
||||||
|
|
||||||
pub struct TextInferenceEngine {
|
pub struct TextInferenceEngine {
|
||||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||||
|
|
@ -96,7 +102,7 @@ impl TextInferenceEngine {
|
||||||
let context = Box::new(context);
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
context,
|
context,
|
||||||
|context| context.0.is_cancelled(),
|
inference_callback,
|
||||||
encoding.get_tokens(),
|
encoding.get_tokens(),
|
||||||
options.max_decoding_length,
|
options.max_decoding_length,
|
||||||
options.sampling_temperature,
|
options.sampling_temperature,
|
||||||
|
|
@ -117,3 +123,11 @@ impl TextInferenceEngine {
|
||||||
self.tokenizer.decode(output_ids, true).unwrap()
|
self.tokenizer.decode(output_ids, true).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
|
||||||
|
if context.0.is_cancelled() {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue