pass token_ids from inference
parent
fa84e376f4
commit
301c86a985
|
|
@ -12,7 +12,7 @@ typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> Infere
|
||||||
class TextInferenceEngine {
|
class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
virtual rust::Vec<rust::String> inference(
|
virtual rust::Vec<uint32_t> inference(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
InferenceCallback callback,
|
InferenceCallback callback,
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
rust::Vec<rust::String> inference(
|
rust::Vec<uint32_t> inference(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
InferenceCallback callback,
|
InferenceCallback callback,
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
|
|
@ -24,18 +24,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
) const {
|
) const {
|
||||||
// Inference.
|
// Inference.
|
||||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||||
const auto output_tokens = process(
|
return process(
|
||||||
std::move(context),
|
std::move(context),
|
||||||
std::move(callback),
|
std::move(callback),
|
||||||
input_tokens,
|
input_tokens,
|
||||||
Options{max_decoding_length, sampling_temperature}
|
Options{max_decoding_length, sampling_temperature}
|
||||||
);
|
);
|
||||||
|
|
||||||
// Convert to rust vec.
|
|
||||||
rust::Vec<rust::String> output;
|
|
||||||
output.reserve(output_tokens.size());
|
|
||||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
|
||||||
return output;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||||
|
|
@ -45,7 +39,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual std::vector<std::string> process(
|
virtual rust::Vec<uint32_t> process(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
InferenceCallback callback,
|
InferenceCallback callback,
|
||||||
const std::vector<std::string>& tokens,
|
const std::vector<std::string>& tokens,
|
||||||
|
|
@ -55,7 +49,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
|
|
||||||
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
||||||
protected:
|
protected:
|
||||||
virtual std::vector<std::string> process(
|
virtual rust::Vec<uint32_t> process(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
InferenceCallback callback,
|
InferenceCallback callback,
|
||||||
const std::vector<std::string>& tokens,
|
const std::vector<std::string>& tokens,
|
||||||
|
|
@ -64,17 +58,24 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translato
|
||||||
x.max_decoding_length = options.max_decoding_length;
|
x.max_decoding_length = options.max_decoding_length;
|
||||||
x.sampling_temperature = options.sampling_temperature;
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
x.beam_size = 1;
|
x.beam_size = 1;
|
||||||
|
rust::Vec<uint32_t> output_ids;
|
||||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||||
return callback(*context, result.step, result.token_id, result.token);
|
bool stop = callback(*context, result.step, result.token_id, result.token);
|
||||||
|
if (!stop) {
|
||||||
|
output_ids.push_back(result.token_id);
|
||||||
|
} else if (result.is_last) {
|
||||||
|
output_ids.push_back(result.token_id);
|
||||||
|
}
|
||||||
|
return stop;
|
||||||
};
|
};
|
||||||
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
|
||||||
return std::move(result.output());
|
return output_ids;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
||||||
protected:
|
protected:
|
||||||
virtual std::vector<std::string> process(
|
virtual rust::Vec<uint32_t> process(
|
||||||
rust::Box<InferenceContext> context,
|
rust::Box<InferenceContext> context,
|
||||||
InferenceCallback callback,
|
InferenceCallback callback,
|
||||||
const std::vector<std::string>& tokens,
|
const std::vector<std::string>& tokens,
|
||||||
|
|
@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, Decod
|
||||||
x.max_length = options.max_decoding_length;
|
x.max_length = options.max_decoding_length;
|
||||||
x.sampling_temperature = options.sampling_temperature;
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
x.beam_size = 1;
|
x.beam_size = 1;
|
||||||
|
|
||||||
|
rust::Vec<uint32_t> output_ids;
|
||||||
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
x.callback = [&](ctranslate2::GenerationStepResult result) {
|
||||||
return callback(*context, result.step, result.token_id, result.token);
|
bool stop = callback(*context, result.step, result.token_id, result.token);
|
||||||
|
if (!stop) {
|
||||||
|
output_ids.push_back(result.token_id);
|
||||||
|
} else if (result.is_last) {
|
||||||
|
output_ids.push_back(result.token_id);
|
||||||
|
}
|
||||||
|
return stop;
|
||||||
};
|
};
|
||||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||||
return std::move(result.sequences[0]);
|
return output_ids;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ mod ffi {
|
||||||
tokens: &[String],
|
tokens: &[String],
|
||||||
max_decoding_length: usize,
|
max_decoding_length: usize,
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
) -> Vec<String>;
|
) -> Vec<u32>;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -122,7 +122,7 @@ impl TextInferenceEngine {
|
||||||
};
|
};
|
||||||
|
|
||||||
let context = InferenceContext::new(stop_re, cancel_for_inference);
|
let context = InferenceContext::new(stop_re, cancel_for_inference);
|
||||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
let output_ids = tokio::task::spawn_blocking(move || {
|
||||||
let context = Box::new(context);
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
context,
|
context,
|
||||||
|
|
@ -134,23 +134,7 @@ impl TextInferenceEngine {
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.expect("Inference failed");
|
.expect("Inference failed");
|
||||||
let output_ids: Vec<u32> = output_tokens
|
self.tokenizer.decode(output_ids, true).unwrap()
|
||||||
.iter()
|
|
||||||
.filter_map(|x| match self.tokenizer.token_to_id(x) {
|
|
||||||
Some(y) => Some(y),
|
|
||||||
None => {
|
|
||||||
println!("Warning: token ({}) missed in vocab", x);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let output_text = self.tokenizer.decode(output_ids, true).unwrap();
|
|
||||||
for stop_word in options.stop_words {
|
|
||||||
if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
|
|
||||||
return stripped_text.to_string();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output_text
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue