refactor: extract TextInferenceEngineImpl to reduce duplication between EncoderDecoderImpl and DecoderImpl #189
parent
6de61f45bb
commit
2bf5bcd0cf
|
|
@ -491,6 +491,7 @@ dependencies = [
|
||||||
"rust-cxx-cmake-bridge",
|
"rust-cxx-cmake-bridge",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
|
|
@ -17,3 +17,4 @@ lazy_static = "1.4.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serdeconv = "0.4.1"
|
serdeconv = "0.4.1"
|
||||||
tokio = "1.28"
|
tokio = "1.28"
|
||||||
|
tokio-util = "0.7"
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ cxx = "1.0"
|
||||||
derive_builder = "0.12.0"
|
derive_builder = "0.12.0"
|
||||||
tokenizers = "0.13.3"
|
tokenizers = "0.13.3"
|
||||||
tokio = { workspace = true, features = ["rt"] }
|
tokio = { workspace = true, features = ["rt"] }
|
||||||
|
tokio-util = { workspace = true }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
cxx-build = "1.0"
|
cxx-build = "1.0"
|
||||||
|
|
@ -15,5 +16,5 @@ cmake = { version = "0.1", optional = true }
|
||||||
rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
|
rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "dep:cmake", "dep:rust-cxx-cmake-bridge" ]
|
default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"]
|
||||||
link_shared = []
|
link_shared = []
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,14 @@
|
||||||
|
|
||||||
namespace tabby {
|
namespace tabby {
|
||||||
|
|
||||||
|
struct InferenceContext;
|
||||||
|
|
||||||
class TextInferenceEngine {
|
class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
virtual rust::Vec<rust::String> inference(
|
virtual rust::Vec<rust::String> inference(
|
||||||
|
rust::Box<InferenceContext> context,
|
||||||
|
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
size_t max_decoding_length,
|
size_t max_decoding_length,
|
||||||
float sampling_temperature,
|
float sampling_temperature,
|
||||||
|
|
|
||||||
|
|
@ -6,24 +6,32 @@
|
||||||
namespace tabby {
|
namespace tabby {
|
||||||
TextInferenceEngine::~TextInferenceEngine() {}
|
TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
|
|
||||||
class EncoderDecoderImpl: public TextInferenceEngine {
|
template <class Model, class Child>
|
||||||
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
|
protected:
|
||||||
|
struct Options {
|
||||||
|
size_t max_decoding_length;
|
||||||
|
float sampling_temperature;
|
||||||
|
size_t beam_size;
|
||||||
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
rust::Vec<rust::String> inference(
|
rust::Vec<rust::String> inference(
|
||||||
|
rust::Box<InferenceContext> context,
|
||||||
|
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
size_t max_decoding_length,
|
size_t max_decoding_length,
|
||||||
float sampling_temperature,
|
float sampling_temperature,
|
||||||
size_t beam_size
|
size_t beam_size
|
||||||
) const {
|
) const {
|
||||||
// Create options.
|
// FIXME(meng): implement real cancellation; this checks only once, before inference starts, and consumes the context.
|
||||||
ctranslate2::TranslationOptions options;
|
if (is_context_cancelled(std::move(context))) {
|
||||||
options.max_decoding_length = max_decoding_length;
|
return rust::Vec<rust::String>();
|
||||||
options.sampling_temperature = sampling_temperature;
|
}
|
||||||
options.beam_size = beam_size;
|
|
||||||
|
|
||||||
// Inference.
|
// Inference.
|
||||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||||
ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
|
const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size});
|
||||||
const auto& output_tokens = result.output();
|
|
||||||
|
|
||||||
// Convert to rust vec.
|
// Convert to rust vec.
|
||||||
rust::Vec<rust::String> output;
|
rust::Vec<rust::String> output;
|
||||||
|
|
@ -33,48 +41,43 @@ class EncoderDecoderImpl: public TextInferenceEngine {
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||||
auto impl = std::make_unique<EncoderDecoderImpl>();
|
auto impl = std::make_unique<Child>();
|
||||||
impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
impl->model_ = std::make_unique<Model>(loader);
|
||||||
return impl;
|
return impl;
|
||||||
}
|
}
|
||||||
private:
|
|
||||||
std::unique_ptr<ctranslate2::Translator> translator_;
|
protected:
|
||||||
|
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const = 0;
|
||||||
|
std::unique_ptr<Model> model_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class DecoderImpl: public TextInferenceEngine {
|
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
||||||
public:
|
protected:
|
||||||
rust::Vec<rust::String> inference(
|
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
|
||||||
rust::Slice<const rust::String> tokens,
|
ctranslate2::TranslationOptions x;
|
||||||
size_t max_decoding_length,
|
x.max_decoding_length = options.max_decoding_length;
|
||||||
float sampling_temperature,
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
size_t beam_size
|
x.beam_size = options.beam_size;
|
||||||
) const {
|
ctranslate2::TranslationResult result = model_->translate_batch(
|
||||||
// Create options.
|
{ tokens },
|
||||||
ctranslate2::GenerationOptions options;
|
ctranslate2::TranslationOptions{
|
||||||
options.include_prompt_in_result = false;
|
}
|
||||||
options.max_length = max_decoding_length;
|
)[0];
|
||||||
options.sampling_temperature = sampling_temperature;
|
return std::move(result.output());
|
||||||
options.beam_size = beam_size;
|
|
||||||
|
|
||||||
// Inference.
|
|
||||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
|
||||||
ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
|
|
||||||
const auto& output_tokens = result.sequences[0];
|
|
||||||
|
|
||||||
// Convert to rust vec.
|
|
||||||
rust::Vec<rust::String> output;
|
|
||||||
output.reserve(output_tokens.size());
|
|
||||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
|
||||||
return output;
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
||||||
auto impl = std::make_unique<DecoderImpl>();
|
protected:
|
||||||
impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
|
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
|
||||||
return impl;
|
ctranslate2::GenerationOptions x;
|
||||||
|
x.include_prompt_in_result = false;
|
||||||
|
x.max_length = options.max_decoding_length;
|
||||||
|
x.sampling_temperature = options.sampling_temperature;
|
||||||
|
x.beam_size = options.beam_size;
|
||||||
|
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||||
|
return std::move(result.sequences[0]);
|
||||||
}
|
}
|
||||||
private:
|
|
||||||
std::unique_ptr<ctranslate2::Generator> generator_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::shared_ptr<TextInferenceEngine> create_engine(
|
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,15 @@
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate derive_builder;
|
extern crate derive_builder;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "tabby")]
|
#[cxx::bridge(namespace = "tabby")]
|
||||||
mod ffi {
|
mod ffi {
|
||||||
|
extern "Rust" {
|
||||||
|
type InferenceContext;
|
||||||
|
}
|
||||||
|
|
||||||
unsafe extern "C++" {
|
unsafe extern "C++" {
|
||||||
include!("ctranslate2-bindings/include/ctranslate2.h");
|
include!("ctranslate2-bindings/include/ctranslate2.h");
|
||||||
|
|
||||||
|
|
@ -20,6 +25,8 @@ mod ffi {
|
||||||
|
|
||||||
fn inference(
|
fn inference(
|
||||||
&self,
|
&self,
|
||||||
|
context: Box<InferenceContext>,
|
||||||
|
is_context_cancelled: fn(Box<InferenceContext>) -> bool,
|
||||||
tokens: &[String],
|
tokens: &[String],
|
||||||
max_decoding_length: usize,
|
max_decoding_length: usize,
|
||||||
sampling_temperature: f32,
|
sampling_temperature: f32,
|
||||||
|
|
@ -58,6 +65,8 @@ pub struct TextInferenceOptions {
|
||||||
beam_size: usize,
|
beam_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct InferenceContext(CancellationToken);
|
||||||
|
|
||||||
pub struct TextInferenceEngine {
|
pub struct TextInferenceEngine {
|
||||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
|
|
@ -81,8 +90,17 @@ impl TextInferenceEngine {
|
||||||
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
let engine = self.engine.clone();
|
let engine = self.engine.clone();
|
||||||
|
|
||||||
|
let cancel = CancellationToken::new();
|
||||||
|
let cancel_for_inference = cancel.clone();
|
||||||
|
let _guard = cancel.drop_guard();
|
||||||
|
|
||||||
|
let context = InferenceContext(cancel_for_inference);
|
||||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||||
|
let context = Box::new(context);
|
||||||
engine.inference(
|
engine.inference(
|
||||||
|
context,
|
||||||
|
|context| context.0.is_cancelled(),
|
||||||
encoding.get_tokens(),
|
encoding.get_tokens(),
|
||||||
options.max_decoding_length,
|
options.max_decoding_length,
|
||||||
options.sampling_temperature,
|
options.sampling_temperature,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue