From 301c86a9853bf67be7f6f01f99647b741db6bfd3 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Tue, 6 Jun 2023 15:19:20 -0700
Subject: [PATCH] pass token_ids from inference

---
 .../include/ctranslate2.h                   |  2 +-
 .../ctranslate2-bindings/src/ctranslate2.cc | 39 ++++++++++++-------
 crates/ctranslate2-bindings/src/lib.rs      | 22 ++---------
 3 files changed, 28 insertions(+), 35 deletions(-)

diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h
index b923c6a..aa67db2 100644
--- a/crates/ctranslate2-bindings/include/ctranslate2.h
+++ b/crates/ctranslate2-bindings/include/ctranslate2.h
@@ -12,7 +12,7 @@ typedef rust::Fn<bool(InferenceContext&, size_t, uint32_t, rust::String)> Infere
 class TextInferenceEngine {
  public:
   virtual ~TextInferenceEngine();
-  virtual rust::Vec<rust::String> inference(
+  virtual rust::Vec<uint32_t> inference(
       rust::Box<InferenceContext> context,
       InferenceCallback callback,
       rust::Slice<const rust::String> tokens,
diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index d45c2d8..3b71b19 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -15,7 +15,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   };

  public:
-  rust::Vec<rust::String> inference(
+  rust::Vec<uint32_t> inference(
       rust::Box<InferenceContext> context,
       InferenceCallback callback,
       rust::Slice<const rust::String> tokens,
@@ -24,18 +24,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   ) const {
     // Inference.
     std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
-    const auto output_tokens = process(
+    return process(
         std::move(context),
         std::move(callback),
         input_tokens,
         Options{max_decoding_length, sampling_temperature}
     );
-
-    // Convert to rust vec.
-    rust::Vec<rust::String> output;
-    output.reserve(output_tokens.size());
-    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
-    return output;
   }

   static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
@@ -45,7 +39,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   }

  protected:
-  virtual std::vector<std::string> process(
+  virtual rust::Vec<uint32_t> process(
       rust::Box<InferenceContext> context,
       InferenceCallback callback,
       const std::vector<std::string>& tokens,
@@ -55,7 +49,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {

 class EncoderDecoderImpl : public TextInferenceEngineImpl {
  protected:
-  virtual std::vector<std::string> process(
+  virtual rust::Vec<uint32_t> process(
       rust::Box<InferenceContext> context,
       InferenceCallback callback,
       const std::vector<std::string>& tokens,
@@ -64,17 +58,24 @@ class EncoderDecoderImpl : public TextInferenceEngineImpl
+    rust::Vec<uint32_t> output_ids;
     x.callback = [&](ctranslate2::GenerationStepResult result) {
-      return callback(*context, result.step, result.token_id, result.token);
+      bool stop = callback(*context, result.step, result.token_id, result.token);
+      if (!stop) {
+        output_ids.push_back(result.token_id);
+      } else if (result.is_last) {
+        output_ids.push_back(result.token_id);
+      }
+      return stop;
     };
     ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
-    return std::move(result.output());
+    return output_ids;
   }
 };
@@ -84,11 +85,19 @@ class DecoderImpl : public TextInferenceEngineImpl
+    rust::Vec<uint32_t> output_ids;
     x.callback = [&](ctranslate2::GenerationStepResult result) {
-      return callback(*context, result.step, result.token_id, result.token);
+      bool stop = callback(*context, result.step, result.token_id, result.token);
+      if (!stop) {
+        output_ids.push_back(result.token_id);
+      } else if (result.is_last) {
+        output_ids.push_back(result.token_id);
+      }
+      return stop;
     };
     ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
-    return std::move(result.sequences[0]);
+    return output_ids;
   }
 };
diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index 3dd9613..ffeece7 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -39,7 +39,7 @@ mod ffi {
             tokens: &[String],
             max_decoding_length: usize,
             sampling_temperature: f32,
-        ) -> Vec<String>;
+        ) -> Vec<u32>;
     }
 }
@@ -122,7 +122,7 @@ impl TextInferenceEngine {
         };

         let context = InferenceContext::new(stop_re, cancel_for_inference);
-        let output_tokens = tokio::task::spawn_blocking(move || {
+        let output_ids = tokio::task::spawn_blocking(move || {
             let context = Box::new(context);
             engine.inference(
                 context,
@@ -134,23 +134,7 @@ impl TextInferenceEngine {
         })
         .await
         .expect("Inference failed");
-        let output_ids: Vec<u32> = output_tokens
-            .iter()
-            .filter_map(|x| match self.tokenizer.token_to_id(x) {
-                Some(y) => Some(y),
-                None => {
-                    println!("Warning: token ({}) missed in vocab", x);
-                    None
-                }
-            })
-            .collect();
-        let output_text = self.tokenizer.decode(output_ids, true).unwrap();
-        for stop_word in options.stop_words {
-            if let Some(stripped_text) = output_text.strip_suffix(stop_word) {
-                return stripped_text.to_string();
-            }
-        }
-        output_text
+        self.tokenizer.decode(output_ids, true).unwrap()
     }
 }
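
Note: with this patch, inference() returns the vocabulary ids (Vec<u32>) collected from CTranslate2's step callback, so the Rust side no longer re-derives ids from token strings via token_to_id (dropping the "missed in vocab" warning path), and the trailing strip_suffix(stop_word) pass is removed; stop handling presumably remains with the stop_re regex carried by InferenceContext. A minimal sketch of the resulting decode path, assuming the tokenizers 0.13 API that lib.rs already uses (the decode_output helper name is hypothetical, not part of the patch):

    use tokenizers::Tokenizer;

    // Illustrative only: the engine already returns vocabulary ids, so
    // turning them into text is a single decode call instead of a
    // per-token string -> id lookup followed by decode.
    fn decode_output(tokenizer: &Tokenizer, output_ids: Vec<u32>) -> String {
        // tokenizers 0.13: decode(ids: Vec<u32>, skip_special_tokens: bool) -> Result<String>
        tokenizer.decode(output_ids, true).unwrap()
    }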