add token_id to callback (and accumulate output text in InferenceContext)

support-stop-sequences
Meng Zhang 2023-06-06 13:58:04 -07:00
parent d571248c7f
commit 040af1a374
3 changed files with 20 additions and 8 deletions

View File

@ -7,7 +7,7 @@ namespace tabby {
struct InferenceContext; struct InferenceContext;
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback; typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
class TextInferenceEngine { class TextInferenceEngine {
public: public:

View File

@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return callback(*context, result.step, rust::String(result.token)); return callback(*context, result.step, result.token_id, result.token);
}; };
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return std::move(result.output()); return std::move(result.output());
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return callback(*context, result.step, rust::String(result.token)); return callback(*context, result.step, result.token_id, result.token);
}; };
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]); return std::move(result.sequences[0]);

View File

@ -27,9 +27,11 @@ mod ffi {
&self, &self,
context: Box<InferenceContext>, context: Box<InferenceContext>,
callback: fn( callback: fn(
&InferenceContext, &mut InferenceContext,
// step // step
usize, usize,
// token_id
u32,
// token // token
String, String,
) -> bool, ) -> bool,
@ -67,7 +69,16 @@ pub struct TextInferenceOptions {
sampling_temperature: f32, sampling_temperature: f32,
} }
pub struct InferenceContext(CancellationToken); pub struct InferenceContext {
cancel: CancellationToken,
output_text: String
}
impl InferenceContext {
fn new(cancel: CancellationToken) -> Self {
InferenceContext { cancel, output_text: "".to_owned() }
}
}
pub struct TextInferenceEngine { pub struct TextInferenceEngine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
@ -97,7 +108,7 @@ impl TextInferenceEngine {
let cancel_for_inference = cancel.clone(); let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard(); let _guard = cancel.drop_guard();
let context = InferenceContext(cancel_for_inference); let context = InferenceContext::new(cancel_for_inference);
let output_tokens = tokio::task::spawn_blocking(move || { let output_tokens = tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
@ -124,8 +135,9 @@ impl TextInferenceEngine {
} }
} }
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool { fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
if context.0.is_cancelled() { context.output_text.push_str(&token);
if context.cancel.is_cancelled() {
true true
} else { } else {
false false