add token to callback

support-stop-sequences
Meng Zhang 2023-06-06 13:58:04 -07:00
parent d571248c7f
commit 040af1a374
3 changed files with 20 additions and 8 deletions

View File

@ -7,7 +7,7 @@ namespace tabby {
struct InferenceContext;
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
class TextInferenceEngine {
public:

View File

@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) {
return callback(*context, result.step, rust::String(result.token));
return callback(*context, result.step, result.token_id, result.token);
};
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return std::move(result.output());
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) {
return callback(*context, result.step, rust::String(result.token));
return callback(*context, result.step, result.token_id, result.token);
};
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]);

View File

@ -27,9 +27,11 @@ mod ffi {
&self,
context: Box<InferenceContext>,
callback: fn(
&InferenceContext,
&mut InferenceContext,
// step
usize,
// token_id
u32,
// token
String,
) -> bool,
@ -67,7 +69,16 @@ pub struct TextInferenceOptions {
sampling_temperature: f32,
}
pub struct InferenceContext(CancellationToken);
/// Per-request state threaded through the FFI inference callback:
/// a token used to signal cancellation, plus the text decoded so far.
pub struct InferenceContext {
    cancel: CancellationToken,
    output_text: String
}

impl InferenceContext {
    /// Build a context around `cancel`, starting with no accumulated output.
    fn new(cancel: CancellationToken) -> Self {
        Self {
            cancel,
            output_text: String::new(),
        }
    }
}
pub struct TextInferenceEngine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
@ -97,7 +108,7 @@ impl TextInferenceEngine {
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let context = InferenceContext(cancel_for_inference);
let context = InferenceContext::new(cancel_for_inference);
let output_tokens = tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference(
@ -124,8 +135,9 @@ impl TextInferenceEngine {
}
}
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
if context.0.is_cancelled() {
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
context.output_text.push_str(&token);
if context.cancel.is_cancelled() {
true
} else {
false