add token to callback
parent
d571248c7f
commit
040af1a374
|
|
@ -7,7 +7,7 @@ namespace tabby {
|
|||
|
||||
struct InferenceContext;
|
||||
|
||||
typedef rust::Fn<bool(const InferenceContext&, size_t, rust::String)> InferenceCallback;
|
||||
// Callback handed across the Rust/C++ boundary; arguments are
// (context, step, token_id, token) in that order, and the result is a bool.
typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
|||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return callback(*context, result.step, rust::String(result.token));
|
||||
return callback(*context, result.step, result.token_id, result.token);
|
||||
};
|
||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||
return std::move(result.output());
|
||||
|
|
@ -85,7 +85,7 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
|||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return callback(*context, result.step, rust::String(result.token));
|
||||
return callback(*context, result.step, result.token_id, result.token);
|
||||
};
|
||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||
return std::move(result.sequences[0]);
|
||||
|
|
|
|||
|
|
@ -27,9 +27,11 @@ mod ffi {
|
|||
&self,
|
||||
context: Box<InferenceContext>,
|
||||
callback: fn(
|
||||
&InferenceContext,
|
||||
&mut InferenceContext,
|
||||
// step
|
||||
usize,
|
||||
// token_id
|
||||
u32,
|
||||
// token
|
||||
String,
|
||||
) -> bool,
|
||||
|
|
@ -67,7 +69,16 @@ pub struct TextInferenceOptions {
|
|||
sampling_temperature: f32,
|
||||
}
|
||||
|
||||
pub struct InferenceContext(CancellationToken);
|
||||
/// State passed by mutable reference to `inference_callback`.
pub struct InferenceContext {
|
||||
/// Cancellation token; `inference_callback` checks it with `is_cancelled()`.
cancel: CancellationToken,
|
||||
/// Text accumulated so far; `inference_callback` appends each token to it.
output_text: String
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(cancel: CancellationToken) -> Self {
|
||||
InferenceContext { cancel, output_text: "".to_owned() }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TextInferenceEngine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
|
|
@ -97,7 +108,7 @@ impl TextInferenceEngine {
|
|||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let context = InferenceContext(cancel_for_inference);
|
||||
let context = InferenceContext::new(cancel_for_inference);
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
|
|
@ -124,8 +135,9 @@ impl TextInferenceEngine {
|
|||
}
|
||||
}
|
||||
|
||||
fn inference_callback(context: &InferenceContext, step: usize, token: String) -> bool {
|
||||
if context.0.is_cancelled() {
|
||||
fn inference_callback(context: &mut InferenceContext, _step: usize, _token_id: u32, token: String) -> bool {
|
||||
context.output_text.push_str(&token);
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
|
|
|
|||
Loading…
Reference in New Issue