refactor: pass step and string token to callback
parent
e8a1688fce
commit
d571248c7f
|
|
@ -7,12 +7,14 @@ namespace tabby {
|
|||
|
||||
struct InferenceContext;
|
||||
|
||||
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
virtual rust::Vec<rust::String> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature
|
||||
|
|
@ -26,7 +26,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
const auto output_tokens = process(
|
||||
std::move(context),
|
||||
std::move(is_context_cancelled),
|
||||
std::move(callback),
|
||||
input_tokens,
|
||||
Options{max_decoding_length, sampling_temperature}
|
||||
);
|
||||
|
|
@ -47,7 +47,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const = 0;
|
||||
std::unique_ptr<Model> model_;
|
||||
|
|
@ -57,7 +57,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
|||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const override {
|
||||
ctranslate2::TranslationOptions x;
|
||||
|
|
@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
|||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return is_context_cancelled(*context);
|
||||
return callback(*context, result.step, rust::String(result.token));
|
||||
};
|
||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||
return std::move(result.output());
|
||||
|
|
@ -76,7 +76,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
|||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(const InferenceContext&)> is_context_cancelled,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
const Options& options) const override {
|
||||
ctranslate2::GenerationOptions x;
|
||||
|
|
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
|||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return is_context_cancelled(*context);
|
||||
return callback(*context, result.step, rust::String(result.token));
|
||||
};
|
||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||
return std::move(result.sequences[0]);
|
||||
|
|
|
|||
|
|
@ -26,7 +26,13 @@ mod ffi {
|
|||
fn inference(
|
||||
&self,
|
||||
context: Box<InferenceContext>,
|
||||
is_context_cancelled: fn(&InferenceContext) -> bool,
|
||||
callback: fn(
|
||||
&InferenceContext,
|
||||
// step
|
||||
usize,
|
||||
// token
|
||||
String,
|
||||
) -> bool,
|
||||
tokens: &[String],
|
||||
max_decoding_length: usize,
|
||||
sampling_temperature: f32,
|
||||
|
|
@ -61,7 +67,7 @@ pub struct TextInferenceOptions {
|
|||
sampling_temperature: f32,
|
||||
}
|
||||
|
||||
struct InferenceContext(CancellationToken);
|
||||
pub struct InferenceContext(CancellationToken);
|
||||
|
||||
pub struct TextInferenceEngine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
|
|
@ -96,7 +102,7 @@ impl TextInferenceEngine {
|
|||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
|context| context.0.is_cancelled(),
|
||||
inference_callback,
|
||||
encoding.get_tokens(),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
|
|
@ -117,3 +123,11 @@ impl TextInferenceEngine {
|
|||
self.tokenizer.decode(output_ids, true).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
|
||||
if context.0.is_cancelled() {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue