pass token_ids from inference
parent
fa84e376f4
commit
301c86a985
|
|
@ -12,7 +12,7 @@ typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> Infere
|
|||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
virtual rust::Vec<rust::String> inference(
|
||||
virtual rust::Vec<uint32_t> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
};
|
||||
|
||||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Vec<uint32_t> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
|
|
@ -24,18 +24,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
) const {
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
const auto output_tokens = process(
|
||||
return process(
|
||||
std::move(context),
|
||||
std::move(callback),
|
||||
input_tokens,
|
||||
Options{max_decoding_length, sampling_temperature}
|
||||
);
|
||||
|
||||
// Convert to rust vec.
|
||||
rust::Vec<rust::String> output;
|
||||
output.reserve(output_tokens.size());
|
||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||
return output;
|
||||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
|
|
@ -45,7 +39,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
}
|
||||
|
||||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
|
|
@ -55,7 +49,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
|
||||
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
|
|
@ -64,17 +58,24 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
|||
x.max_decoding_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
rust::Vec<uint32_t> output_ids;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return callback(*context, result.step, result.token_id, result.token);
|
||||
bool stop = callback(*context, result.step, result.token_id, result.token);
|
||||
if (!stop) {
|
||||
output_ids.push_back(result.token_id);
|
||||
} else if (result.is_last) {
|
||||
output_ids.push_back(result.token_id);
|
||||
}
|
||||
return stop;
|
||||
};
|
||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||
return std::move(result.output());
|
||||
return output_ids;
|
||||
}
|
||||
};
|
||||
|
||||
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(
|
||||
virtual rust::Vec<uint32_t> process(
|
||||
rust::Box<InferenceContext> context,
|
||||
InferenceCallback callback,
|
||||
const std::vector<std::string>& tokens,
|
||||
|
|
@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
|||
x.max_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = 1;
|
||||
|
||||
rust::Vec<uint32_t> output_ids;
|
||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||
return callback(*context, result.step, result.token_id, result.token);
|
||||
bool stop = callback(*context, result.step, result.token_id, result.token);
|
||||
if (!stop) {
|
||||
output_ids.push_back(result.token_id);
|
||||
} else if (result.is_last) {
|
||||
output_ids.push_back(result.token_id);
|
||||
}
|
||||
return stop;
|
||||
};
|
||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||
return std::move(result.sequences[0]);
|
||||
return output_ids;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ mod ffi {
|
|||
tokens: &[String],
|
||||
max_decoding_length: usize,
|
||||
sampling_temperature: f32,
|
||||
) -> Vec<String>;
|
||||
) -> Vec<u32>;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ impl TextInferenceEngine {
|
|||
};
|
||||
|
||||
let context = InferenceContext::new(stop_re, cancel_for_inference);
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
let output_ids = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
|
|
@ -134,23 +134,7 @@ impl TextInferenceEngine {
|
|||
})
|
||||
.await
|
||||
.expect("Inference failed");
|
||||
let output_ids: Vec<u32> = output_tokens
|
||||
.iter()
|
||||
.filter_map(|x| match self.tokenizer.token_to_id(x) {
|
||||
Some(y) => Some(y),
|
||||
None => {
|
||||
println!("Warning: token ({}) missed in vocab", x);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let output_text = self.tokenizer.decode(output_ids, true).unwrap();
|
||||
for stop_word in options.stop_words {
|
||||
if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
|
||||
return stripped_text.to_string();
|
||||
}
|
||||
}
|
||||
output_text
|
||||
self.tokenizer.decode(output_ids, true).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue