pass token_ids from inference

Branch: support-stop-sequences
Meng Zhang 2023-06-06 15:19:20 -07:00
parent fa84e376f4
commit 301c86a985
3 changed files with 28 additions and 35 deletions
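
In brief: the inference FFI now returns the sampled token ids rather than token strings. The C++ generation callbacks collect ids as they stream in, and the Rust wrapper decodes them once instead of re-encoding strings through the vocabulary. A minimal sketch of the contract change, with hypothetical Engine traits standing in for the cxx bridge shown in the hunks below:

    // Sketch only: the `Engine*` traits are stand-ins, not the crate's real API.
    trait EngineBefore {
        // Old contract: token strings, re-encoded to ids by the caller.
        fn inference(&self, tokens: &[String]) -> Vec<String>;
    }
    trait EngineAfter {
        // New contract: the ids the model actually sampled, ready to decode.
        fn inference(&self, tokens: &[String]) -> Vec<u32>;
    }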

View File

@@ -12,7 +12,7 @@ typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> InferenceCallback;
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
- virtual rust::Vec<rust::String> inference(
+ virtual rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context,
InferenceCallback callback,
rust::Slice<const rust::String> tokens,

View File

@@ -15,7 +15,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
};
public:
- rust::Vec<rust::String> inference(
+ rust::Vec<uint32_t> inference(
rust::Box<InferenceContext> context,
InferenceCallback callback,
rust::Slice<const rust::String> tokens,
@@ -24,18 +24,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
) const {
// Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
- const auto output_tokens = process(
+ return process(
std::move(context),
std::move(callback),
input_tokens,
Options{max_decoding_length, sampling_temperature}
);
- // Convert to rust vec.
- rust::Vec<rust::String> output;
- output.reserve(output_tokens.size());
- std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
- return output;
}
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
@@ -45,7 +39,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}
protected:
- virtual std::vector<std::string> process(
+ virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context,
InferenceCallback callback,
const std::vector<std::string>& tokens,
@@ -55,7 +49,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
protected:
- virtual std::vector<std::string> process(
+ virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context,
InferenceCallback callback,
const std::vector<std::string>& tokens,
@@ -64,17 +58,24 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
x.max_decoding_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
+ rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) {
- return callback(*context, result.step, result.token_id, result.token);
+ bool stop = callback(*context, result.step, result.token_id, result.token);
+ if (!stop) {
+ output_ids.push_back(result.token_id);
+ } else if (result.is_last) {
+ output_ids.push_back(result.token_id);
+ }
+ return stop;
};
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
- return std::move(result.output());
+ return output_ids;
}
};
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
protected:
- virtual std::vector<std::string> process(
+ virtual rust::Vec<uint32_t> process(
rust::Box<InferenceContext> context,
InferenceCallback callback,
const std::vector<std::string>& tokens,
@@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
x.max_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature;
x.beam_size = 1;
+ rust::Vec<uint32_t> output_ids;
x.callback = [&](ctranslate2::GenerationStepResult result) {
- return callback(*context, result.step, result.token_id, result.token);
+ bool stop = callback(*context, result.step, result.token_id, result.token);
+ if (!stop) {
+ output_ids.push_back(result.token_id);
+ } else if (result.is_last) {
+ output_ids.push_back(result.token_id);
+ }
+ return stop;
};
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
- return std::move(result.sequences[0]);
+ return output_ids;
}
};
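
Both engines apply the same rule inside the generation callback: keep a token id whenever the Rust callback has not signalled stop, and still keep the final id when stop arrives on the last step. A hedged Rust rendering of that rule, with an iterator of (token_id, is_last) pairs standing in for ctranslate2's step results:

    // Sketch of the accumulation rule in the C++ above; `should_stop`
    // plays the role of the Rust-side InferenceCallback.
    fn collect_ids(
        steps: impl IntoIterator<Item = (u32, bool)>,
        mut should_stop: impl FnMut(u32) -> bool,
    ) -> Vec<u32> {
        let mut output_ids = Vec::new();
        for (token_id, is_last) in steps {
            let stop = should_stop(token_id);
            // Mirrors: if (!stop) push; else if (result.is_last) push;
            if !stop || is_last {
                output_ids.push(token_id);
            }
            if stop {
                break; // generation ends once the callback returns true
            }
        }
        output_ids
    }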

View File

@@ -39,7 +39,7 @@ mod ffi {
tokens: &[String],
max_decoding_length: usize,
sampling_temperature: f32,
- ) -> Vec<String>;
+ ) -> Vec<u32>;
}
}
@@ -122,7 +122,7 @@ impl TextInferenceEngine {
};
let context = InferenceContext::new(stop_re, cancel_for_inference);
- let output_tokens = tokio::task::spawn_blocking(move || {
+ let output_ids = tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference(
context,
@@ -134,23 +134,7 @@ impl TextInferenceEngine {
})
.await
.expect("Inference failed");
- let output_ids: Vec<u32> = output_tokens
- .iter()
- .filter_map(|x| match self.tokenizer.token_to_id(x) {
- Some(y) => Some(y),
- None => {
- println!("Warning: token ({}) missed in vocab", x);
- None
- }
- })
- .collect();
- let output_text = self.tokenizer.decode(output_ids, true).unwrap();
- for stop_word in options.stop_words {
- if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
- return stripped_text.to_string();
- }
- }
- output_text
+ self.tokenizer.decode(output_ids, true).unwrap()
}
}
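
The deleted Rust block shows what returning ids buys: no per-token token_to_id lookup (which could miss vocabulary entries, hence the old warning), and no suffix-stripping of stop words, since stop sequences are already enforced through the stop_re passed to InferenceContext. All that remains is a single decode. A minimal sketch against the tokenizers crate, assuming the 0.13-era decode(Vec<u32>, bool) signature that the diff itself calls:

    use tokenizers::Tokenizer;

    // Sketch only: render the ids returned across the FFI in one step.
    // `true` skips special tokens (e.g. EOS) in the output text.
    fn decode_output(tokenizer: &Tokenizer, output_ids: Vec<u32>) -> String {
        tokenizer.decode(output_ids, true).unwrap()
    }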