refactor: extract TextInferenceEngineImpl to reduce duplication between EncoderDecoderImpl and DecoderImpl #189

Meng Zhang 2023-06-04 15:28:39 -07:00 committed by GitHub
parent 6de61f45bb
commit 2bf5bcd0cf
6 changed files with 72 additions and 44 deletions

Cargo.lock (generated)

@@ -491,6 +491,7 @@ dependencies = [
  "rust-cxx-cmake-bridge",
  "tokenizers",
  "tokio",
+ "tokio-util",
 ]
 
 [[package]]

Cargo.toml

@@ -17,3 +17,4 @@ lazy_static = "1.4.0"
 serde = { version = "1.0", features = ["derive"] }
 serdeconv = "0.4.1"
 tokio = "1.28"
+tokio-util = "0.7"

crates/ctranslate2-bindings/Cargo.toml

@@ -8,6 +8,7 @@ cxx = "1.0"
 derive_builder = "0.12.0"
 tokenizers = "0.13.3"
 tokio = { workspace = true, features = ["rt"] }
+tokio-util = { workspace = true }
 
 [build-dependencies]
 cxx-build = "1.0"
@@ -15,5 +16,5 @@ cmake = { version = "0.1", optional = true }
 rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
 
 [features]
-default = [ "dep:cmake", "dep:rust-cxx-cmake-bridge" ]
+default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"]
 link_shared = []
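Note: `tokio-util = { workspace = true }` above inherits the version pinned in the workspace root manifest (`tokio-util = "0.7"`), keeping the bindings crate in lockstep with the rest of the workspace.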

crates/ctranslate2-bindings/include/ctranslate2.h

@@ -5,10 +5,14 @@
 namespace tabby {
+struct InferenceContext;
+
 class TextInferenceEngine {
  public:
   virtual ~TextInferenceEngine();
 
   virtual rust::Vec<rust::String> inference(
+      rust::Box<InferenceContext> context,
+      rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
       rust::Slice<const rust::String> tokens,
       size_t max_decoding_length,
       float sampling_temperature,
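A note on the new signature: in cxx, rust::Box<T> is an owning, move-only pointer, and is_context_cancelled takes it by value, so the callback consumes the context and can be invoked at most once; that is presumably why the implementation below only checks cancellation on entry (see the FIXME). A minimal stand-alone sketch of the constraint, with std::unique_ptr and a lambda as stand-ins for rust::Box and rust::Fn (all names illustrative):

#include <iostream>
#include <memory>
#include <utility>

struct InferenceContext {
  bool cancelled = false;
};

int main() {
  // Stand-in for rust::Fn<bool(rust::Box<InferenceContext>)>: the callback
  // takes ownership, so the context is destroyed when it returns.
  auto is_context_cancelled = [](std::unique_ptr<InferenceContext> ctx) {
    return ctx->cancelled;
  };

  auto context = std::make_unique<InferenceContext>();
  bool cancelled = is_context_cancelled(std::move(context));
  std::cout << "cancelled: " << std::boolalpha << cancelled << "\n";

  // `context` is now empty; polling again between decoding steps would need
  // a non-consuming handle (e.g. a shared cancellation token) instead.
  return 0;
}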

crates/ctranslate2-bindings/src/ctranslate2.cc

@@ -6,24 +6,32 @@
 namespace tabby {
 TextInferenceEngine::~TextInferenceEngine() {}
 
-class EncoderDecoderImpl: public TextInferenceEngine {
+template <class Model, class Child>
+class TextInferenceEngineImpl : public TextInferenceEngine {
+ protected:
+  struct Options {
+    size_t max_decoding_length;
+    float sampling_temperature;
+    size_t beam_size;
+  };
+
  public:
   rust::Vec<rust::String> inference(
+      rust::Box<InferenceContext> context,
+      rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
       rust::Slice<const rust::String> tokens,
       size_t max_decoding_length,
       float sampling_temperature,
       size_t beam_size
   ) const {
-    // Create options.
-    ctranslate2::TranslationOptions options;
-    options.max_decoding_length = max_decoding_length;
-    options.sampling_temperature = sampling_temperature;
-    options.beam_size = beam_size;
+    // FIXME(meng): implement the cancellation.
+    if (is_context_cancelled(std::move(context))) {
+      return rust::Vec<rust::String>();
+    }
 
     // Inference.
     std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
-    ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
-    const auto& output_tokens = result.output();
+    const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size});
 
     // Convert to rust vec.
     rust::Vec<rust::String> output;
@@ -33,48 +41,43 @@ class EncoderDecoderImpl: public TextInferenceEngine {
   }
 
   static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
-    auto impl = std::make_unique<EncoderDecoderImpl>();
-    impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
+    auto impl = std::make_unique<Child>();
+    impl->model_ = std::make_unique<Model>(loader);
     return impl;
   }
 
- private:
-  std::unique_ptr<ctranslate2::Translator> translator_;
+ protected:
+  virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const = 0;
+
+  std::unique_ptr<Model> model_;
 };
 
-class DecoderImpl: public TextInferenceEngine {
- public:
-  rust::Vec<rust::String> inference(
-      rust::Slice<const rust::String> tokens,
-      size_t max_decoding_length,
-      float sampling_temperature,
-      size_t beam_size
-  ) const {
-    // Create options.
-    ctranslate2::GenerationOptions options;
-    options.include_prompt_in_result = false;
-    options.max_length = max_decoding_length;
-    options.sampling_temperature = sampling_temperature;
-    options.beam_size = beam_size;
-
-    // Inference.
-    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
-    ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
-    const auto& output_tokens = result.sequences[0];
-
-    // Convert to rust vec.
-    rust::Vec<rust::String> output;
-    output.reserve(output_tokens.size());
-    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
-    return output;
-  }
-
-  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
-    auto impl = std::make_unique<DecoderImpl>();
-    impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
-    return impl;
-  }
-
- private:
-  std::unique_ptr<ctranslate2::Generator> generator_;
+class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
+ protected:
+  virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
+    ctranslate2::TranslationOptions x;
+    x.max_decoding_length = options.max_decoding_length;
+    x.sampling_temperature = options.sampling_temperature;
+    x.beam_size = options.beam_size;
+    ctranslate2::TranslationResult result = model_->translate_batch({ tokens }, x)[0];
+    return result.output();
+  }
+};
+
+class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
+ protected:
+  virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
+    ctranslate2::GenerationOptions x;
+    x.include_prompt_in_result = false;
+    x.max_length = options.max_decoding_length;
+    x.sampling_temperature = options.sampling_temperature;
+    x.beam_size = options.beam_size;
+    ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
+    return std::move(result.sequences[0]);
+  }
 };
 
 std::shared_ptr<TextInferenceEngine> create_engine(
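The shape of the refactor, reduced to a stand-alone sketch (all names below are illustrative stand-ins, not the real API): the base template is parameterized on the model type it owns and on the concrete subclass, so a single shared create() can instantiate either engine, and the per-engine differences collapse into one process() override.

#include <iostream>
#include <memory>
#include <string>

struct FakeModel {  // stand-in for ctranslate2::Translator / Generator
  std::string run(const std::string& s) const { return s + "!"; }
};

class Engine {  // stand-in for TextInferenceEngine
 public:
  virtual ~Engine() = default;
  virtual std::string infer(const std::string& s) const = 0;
};

template <class Model, class Child>
class EngineImpl : public Engine {  // stand-in for TextInferenceEngineImpl
 public:
  // Shared factory: constructs the concrete Child and injects the Model.
  static std::unique_ptr<Engine> create() {
    auto impl = std::make_unique<Child>();
    impl->model_ = std::make_unique<Model>();
    return impl;
  }

  // Shared plumbing lives in the base; subclasses only supply process().
  std::string infer(const std::string& s) const override { return process(s); }

 protected:
  virtual std::string process(const std::string& s) const = 0;
  std::unique_ptr<Model> model_;
};

class EchoEngine : public EngineImpl<FakeModel, EchoEngine> {
 protected:
  std::string process(const std::string& s) const override {
    return model_->run(s);
  }
};

int main() {
  auto engine = EchoEngine::create();
  std::cout << engine->infer("hello") << "\n";  // prints "hello!"
}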

crates/ctranslate2-bindings/src/lib.rs

@@ -1,10 +1,15 @@
 use tokenizers::tokenizer::Tokenizer;
+use tokio_util::sync::CancellationToken;
 
 #[macro_use]
 extern crate derive_builder;
 
 #[cxx::bridge(namespace = "tabby")]
 mod ffi {
+    extern "Rust" {
+        type InferenceContext;
+    }
+
     unsafe extern "C++" {
         include!("ctranslate2-bindings/include/ctranslate2.h");
@@ -20,6 +25,8 @@ mod ffi {
         fn inference(
             &self,
+            context: Box<InferenceContext>,
+            is_context_cancelled: fn(Box<InferenceContext>) -> bool,
             tokens: &[String],
             max_decoding_length: usize,
             sampling_temperature: f32,
@@ -58,6 +65,8 @@ pub struct TextInferenceOptions {
     beam_size: usize,
 }
 
+struct InferenceContext(CancellationToken);
+
 pub struct TextInferenceEngine {
     engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
     tokenizer: Tokenizer,
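For the hunk below: cancel.drop_guard() consumes the token and returns a DropGuard that cancels it when dropped, so if the caller drops this future (e.g. the request is abandoned), the clone stored in InferenceContext flips to cancelled, and the non-capturing closure passed across the FFI boundary reports that through is_cancelled(). As the FIXME on the C++ side notes, the check currently runs only once, before decoding starts.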
@@ -81,8 +90,17 @@ impl TextInferenceEngine {
     pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
         let encoding = self.tokenizer.encode(prompt, true).unwrap();
         let engine = self.engine.clone();
+
+        let cancel = CancellationToken::new();
+        let cancel_for_inference = cancel.clone();
+        let _guard = cancel.drop_guard();
+
+        let context = InferenceContext(cancel_for_inference);
         let output_tokens = tokio::task::spawn_blocking(move || {
+            let context = Box::new(context);
             engine.inference(
+                context,
+                |context| context.0.is_cancelled(),
                 encoding.get_tokens(),
                 options.max_decoding_length,
                 options.sampling_temperature,