add token to callback
parent
d571248c7f
commit
040af1a374
|
|
@ -7,7 +7,7 @@ namespace tabby {
|
||||||
|
|
||||||
struct InferenceContext;
|
struct InferenceContext;
|
||||||
|
|
||||||
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
|
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
|
||||||
|
|
||||||
class TextInferenceEngine {
|
class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
||||||
x.sampling_temperature = options.sampling_temperature;
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
x.beam_size = 1;
|
x.beam_size = 1;
|
||||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||||
return callback(*context, result.step, rust::String(result.token));
|
return callback(*context, result.step, result.token_id, result.token);
|
||||||
};
|
};
|
||||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||||
return std::move(result.output());
|
return std::move(result.output());
|
||||||
|
|
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
||||||
x.sampling_temperature = options.sampling_temperature;
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
x.beam_size = 1;
|
x.beam_size = 1;
|
||||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||||
return callback(*context, result.step, rust::String(result.token));
|
return callback(*context, result.step, result.token_id, result.token);
|
||||||
};
|
};
|
||||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||||
return std::move(result.sequences[0]);
|
return std::move(result.sequences[0]);
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,11 @@ mod ffi {
|
||||||
&self,
|
&self,
|
||||||
context: Box<InferenceContext>,
|
context: Box<InferenceContext>,
|
||||||
callback: fn(
|
callback: fn(
|
||||||
&InferenceContext,
|
&mut InferenceContext,
|
||||||
// step
|
// step
|
||||||
usize,
|
usize,
|
||||||
|
// token_id
|
||||||
|
u32,
|
||||||
// token
|
// token
|
||||||
String,
|
String,
|
||||||
) -> bool,
|
) -> bool,
|
||||||
|
|
@ -67,7 +69,16 @@ pub struct TextInferenceOptions {
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext(CancellationToken);
|
pub struct InferenceContext {
|
||||||
|
cancel: CancellationToken,
|
||||||
|
output_text: String
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InferenceContext {
|
||||||
|
fn new(cancel: CancellationToken) -> Self {
|
||||||
|
InferenceContext { cancel, output_text: "".to_owned() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct TextInferenceEngine {
|
pub struct TextInferenceEngine {
|
||||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||||
|
|
@ -97,7 +108,7 @@ impl TextInferenceEngine {
|
||||||
let cancel_for_inference = cancel.clone();
|
let cancel_for_inference = cancel.clone();
|
||||||
let _guard = cancel.drop_guard();
|
let _guard = cancel.drop_guard();
|
||||||
|
|
||||||
let context = InferenceContext(cancel_for_inference);
|
let context = InferenceContext::new(cancel_for_inference);
|
||||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||||
let context = Box::new(context);
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
|
|
@ -124,8 +135,9 @@ impl TextInferenceEngine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
|
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
||||||
if context.0.is_cancelled() {
|
context.output_text.push_str(&token);
|
||||||
|
if context.cancel.is_cancelled() {
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue