refactor: extract TextInferenceEngineImpl to reduce duplication between EncoderDecoderImpl and DecoderImpl #189

docs-add-demo
Meng Zhang 2023-06-04 15:28:39 -07:00 committed by GitHub
parent 6de61f45bb
commit 2bf5bcd0cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 44 deletions

1
Cargo.lock generated
View File

@@ -491,6 +491,7 @@ dependencies = [
"rust-cxx-cmake-bridge", "rust-cxx-cmake-bridge",
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-util",
] ]
[[package]] [[package]]

View File

@@ -17,3 +17,4 @@ lazy_static = "1.4.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serdeconv = "0.4.1" serdeconv = "0.4.1"
tokio = "1.28" tokio = "1.28"
tokio-util = "0.7"

View File

@@ -8,6 +8,7 @@ cxx = "1.0"
derive_builder = "0.12.0" derive_builder = "0.12.0"
tokenizers = "0.13.3" tokenizers = "0.13.3"
tokio = { workspace = true, features = ["rt"] } tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true }
[build-dependencies] [build-dependencies]
cxx-build = "1.0" cxx-build = "1.0"
@@ -15,5 +16,5 @@ cmake = { version = "0.1", optional = true }
rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true } rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
[features] [features]
default = [ "dep:cmake", "dep:rust-cxx-cmake-bridge" ] default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"]
link_shared = [] link_shared = []

View File

@@ -5,10 +5,14 @@
namespace tabby { namespace tabby {
struct InferenceContext;
class TextInferenceEngine { class TextInferenceEngine {
public: public:
virtual ~TextInferenceEngine(); virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference( virtual rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context,
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature, float sampling_temperature,

View File

@@ -6,24 +6,32 @@
namespace tabby { namespace tabby {
TextInferenceEngine::~TextInferenceEngine() {} TextInferenceEngine::~TextInferenceEngine() {}
class EncoderDecoderImpl: public TextInferenceEngine { template <class Model, class Child>
class TextInferenceEngineImpl : public TextInferenceEngine {
protected:
struct Options {
size_t max_decoding_length;
float sampling_temperature;
size_t beam_size;
};
public: public:
rust::Vec<rust::String> inference( rust::Vec<rust::String> inference(
rust::Box<InferenceContext> context,
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
float sampling_temperature, float sampling_temperature,
size_t beam_size size_t beam_size
) const { ) const {
// Create options. // FIXME(meng): implement the cancellation.
ctranslate2::TranslationOptions options; if (is_context_cancelled(std::move(context))) {
options.max_decoding_length = max_decoding_length; return rust::Vec<rust::String>();
options.sampling_temperature = sampling_temperature; }
options.beam_size = beam_size;
// Inference. // Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end()); std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0]; const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size});
const auto& output_tokens = result.output();
// Convert to rust vec. // Convert to rust vec.
rust::Vec<rust::String> output; rust::Vec<rust::String> output;
@@ -33,48 +41,43 @@ class EncoderDecoderImpl: public TextInferenceEngine {
} }
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) { static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
auto impl = std::make_unique<EncoderDecoderImpl>(); auto impl = std::make_unique<Child>();
impl->translator_ = std::make_unique<ctranslate2::Translator>(loader); impl->model_ = std::make_unique<Model>(loader);
return impl; return impl;
} }
private:
std::unique_ptr<ctranslate2::Translator> translator_; protected:
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const = 0;
std::unique_ptr<Model> model_;
}; };
class DecoderImpl: public TextInferenceEngine { class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
public: protected:
rust::Vec<rust::String> inference( virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
rust::Slice<const rust::String> tokens, ctranslate2::TranslationOptions x;
size_t max_decoding_length, x.max_decoding_length = options.max_decoding_length;
float sampling_temperature, x.sampling_temperature = options.sampling_temperature;
size_t beam_size x.beam_size = options.beam_size;
) const { ctranslate2::TranslationResult result = model_->translate_batch(
// Create options. { tokens },
ctranslate2::GenerationOptions options; ctranslate2::TranslationOptions{
options.include_prompt_in_result = false; }
options.max_length = max_decoding_length; )[0];
options.sampling_temperature = sampling_temperature; return std::move(result.output());
options.beam_size = beam_size;
// Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
const auto& output_tokens = result.sequences[0];
// Convert to rust vec.
rust::Vec<rust::String> output;
output.reserve(output_tokens.size());
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
return output;
} }
};
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) { class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
auto impl = std::make_unique<DecoderImpl>(); protected:
impl->generator_ = std::make_unique<ctranslate2::Generator>(loader); virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
return impl; ctranslate2::GenerationOptions x;
x.include_prompt_in_result = false;
x.max_length = options.max_decoding_length;
x.sampling_temperature = options.sampling_temperature;
x.beam_size = options.beam_size;
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
return std::move(result.sequences[0]);
} }
private:
std::unique_ptr<ctranslate2::Generator> generator_;
}; };
std::shared_ptr<TextInferenceEngine> create_engine( std::shared_ptr<TextInferenceEngine> create_engine(

View File

@@ -1,10 +1,15 @@
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken;
#[macro_use] #[macro_use]
extern crate derive_builder; extern crate derive_builder;
#[cxx::bridge(namespace = "tabby")] #[cxx::bridge(namespace = "tabby")]
mod ffi { mod ffi {
extern "Rust" {
type InferenceContext;
}
unsafe extern "C++" { unsafe extern "C++" {
include!("ctranslate2-bindings/include/ctranslate2.h"); include!("ctranslate2-bindings/include/ctranslate2.h");
@@ -20,6 +25,8 @@ mod ffi {
fn inference( fn inference(
&self, &self,
context: Box<InferenceContext>,
is_context_cancelled: fn(Box<InferenceContext>) -> bool,
tokens: &[String], tokens: &[String],
max_decoding_length: usize, max_decoding_length: usize,
sampling_temperature: f32, sampling_temperature: f32,
@@ -58,6 +65,8 @@ pub struct TextInferenceOptions {
beam_size: usize, beam_size: usize,
} }
struct InferenceContext(CancellationToken);
pub struct TextInferenceEngine { pub struct TextInferenceEngine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer, tokenizer: Tokenizer,
@@ -81,8 +90,17 @@ impl TextInferenceEngine {
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
let encoding = self.tokenizer.encode(prompt, true).unwrap(); let encoding = self.tokenizer.encode(prompt, true).unwrap();
let engine = self.engine.clone(); let engine = self.engine.clone();
let cancel = CancellationToken::new();
let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard();
let context = InferenceContext(cancel_for_inference);
let output_tokens = tokio::task::spawn_blocking(move || { let output_tokens = tokio::task::spawn_blocking(move || {
let context = Box::new(context);
engine.inference( engine.inference(
context,
|context| context.0.is_cancelled(),
encoding.get_tokens(), encoding.get_tokens(),
options.max_decoding_length, options.max_decoding_length,
options.sampling_temperature, options.sampling_temperature,