pass token_ids from inference

support-stop-sequences
Meng Zhang 2023-06-06 15:19:20 -07:00
parent fa84e376f4
commit 301c86a985
3 changed files with 28 additions and 35 deletions

View File

@ -12,7 +12,7 @@ typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> Infere
class TextInferenceEngine { class TextInferenceEngine {
public: public:
virtual ~TextInferenceEngine(); virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference( virtual rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
InferenceCallback callback, InferenceCallback callback,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,

View File

@ -15,7 +15,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}; };
public: public:
rust::Vec<rust::String> inference( rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
InferenceCallback callback, InferenceCallback callback,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
@ -24,18 +24,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
) const { ) const {
// Inference. // Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end()); std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
const auto output_tokens = process( return process(
std::move(context), std::move(context),
std::move(callback), std::move(callback),
input_tokens, input_tokens,
Options{max_decoding_length, sampling_temperature} Options{max_decoding_length, sampling_temperature}
); );
// Convert to rust vec.
rust::Vec<rust::String> output;
output.reserve(output_tokens.size());
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
return output;
} }
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) { static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
@ -45,7 +39,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
} }
protected: protected:
virtual std::vector<std::string> process( virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
InferenceCallback callback, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
@ -55,7 +49,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> { class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
protected: protected:
virtual std::vector<std::string> process( virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
InferenceCallback callback, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
@ -64,17 +58,24 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
x.max_decoding_length = options.max_decoding_length; x.max_decoding_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return callback(*context, result.step, result.token_id, result.token); bool stop = callback(*context, result.step, result.token_id, result.token);
if (!stop) {
output_ids.push_back(result.token_id);
} else if (result.is_last) {
output_ids.push_back(result.token_id);
}
return stop;
}; };
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0]; ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
return std::move(result.output()); return output_ids;
} }
}; };
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> { class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
protected: protected:
virtual std::vector<std::string> process( virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context, rust::Box<InferenceContext> context,
InferenceCallback callback, InferenceCallback callback,
const std::vector<std::string>& tokens, const std::vector<std::string>& tokens,
@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
x.max_length = options.max_decoding_length; x.max_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature; x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1; x.beam_size = 1;
rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) { x.callback = [&](ctranslate2::GenerationStepResult result) {
return callback(*context, result.step, result.token_id, result.token); bool stop = callback(*context, result.step, result.token_id, result.token);
if (!stop) {
output_ids.push_back(result.token_id);
} else if (result.is_last) {
output_ids.push_back(result.token_id);
}
return stop;
}; };
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get(); ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]); return output_ids;
} }
}; };

View File

@ -39,7 +39,7 @@ mod ffi {
tokens: &[String], tokens: &[String],
max_decoding_length: usize, max_decoding_length: usize,
sampling_temperature: f32, sampling_temperature: f32,
) -> Vec<String>; ) -> Vec<u32>;
} }
} }
@ -122,7 +122,7 @@ impl TextInferenceEngine {
}; };
let context = InferenceContext::new(stop_re, cancel_for_inference); let context = InferenceContext::new(stop_re, cancel_for_inference);
let output_tokens = tokio::task::spawn_blocking(move || { let output_ids = tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
context, context,
@ -134,23 +134,7 @@ impl TextInferenceEngine {
}) })
.await .await
.expect("Inference failed"); .expect("Inference failed");
let output_ids: Vec<u32> = output_tokens self.tokenizer.decode(output_ids, true).unwrap()
.iter()
.filter_map(|x| match self.tokenizer.token_to_id(x) {
Some(y) => Some(y),
None => {
println!("Warning: token ({}) missed in vocab", x);
None
}
})
.collect();
let output_text = self.tokenizer.decode(output_ids, true).unwrap();
for stop_word in options.stop_words {
if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
return stripped_text.to_string();
}
}
output_text
} }
} }