refactor: extract TextInferenceEngineImpl to reduce duplications between EncoderDecoderImpl and DecoderImpl #189
parent
6de61f45bb
commit
2bf5bcd0cf
|
|
@ -491,6 +491,7 @@ dependencies = [
|
|||
"rust-cxx-cmake-bridge",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -17,3 +17,4 @@ lazy_static = "1.4.0"
|
|||
serde = { version = "1.0", features = ["derive"] }
|
||||
serdeconv = "0.4.1"
|
||||
tokio = "1.28"
|
||||
tokio-util = "0.7"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ cxx = "1.0"
|
|||
derive_builder = "0.12.0"
|
||||
tokenizers = "0.13.3"
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tokio-util = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
@ -15,5 +16,5 @@ cmake = { version = "0.1", optional = true }
|
|||
rust-cxx-cmake-bridge = { path = "../rust-cxx-cmake-bridge", optional = true }
|
||||
|
||||
[features]
|
||||
default = [ "dep:cmake", "dep:rust-cxx-cmake-bridge" ]
|
||||
default = ["dep:cmake", "dep:rust-cxx-cmake-bridge"]
|
||||
link_shared = []
|
||||
|
|
|
|||
|
|
@ -5,10 +5,14 @@
|
|||
|
||||
namespace tabby {
|
||||
|
||||
struct InferenceContext;
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
virtual rust::Vec<rust::String> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature,
|
||||
|
|
|
|||
|
|
@ -6,24 +6,32 @@
|
|||
namespace tabby {
|
||||
TextInferenceEngine::~TextInferenceEngine() {}
|
||||
|
||||
class EncoderDecoderImpl: public TextInferenceEngine {
|
||||
template <class Model, class Child>
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
protected:
|
||||
struct Options {
|
||||
size_t max_decoding_length;
|
||||
float sampling_temperature;
|
||||
size_t beam_size;
|
||||
};
|
||||
|
||||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Box<InferenceContext> context,
|
||||
rust::Fn<bool(rust::Box<InferenceContext>)> is_context_cancelled,
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature,
|
||||
size_t beam_size
|
||||
) const {
|
||||
// Create options.
|
||||
ctranslate2::TranslationOptions options;
|
||||
options.max_decoding_length = max_decoding_length;
|
||||
options.sampling_temperature = sampling_temperature;
|
||||
options.beam_size = beam_size;
|
||||
// FIXME(meng): implement the cancellation.
|
||||
if (is_context_cancelled(std::move(context))) {
|
||||
return rust::Vec<rust::String>();
|
||||
}
|
||||
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
|
||||
const auto& output_tokens = result.output();
|
||||
const auto output_tokens = process(input_tokens, Options{max_decoding_length, sampling_temperature, beam_size});
|
||||
|
||||
// Convert to rust vec.
|
||||
rust::Vec<rust::String> output;
|
||||
|
|
@ -33,48 +41,43 @@ class EncoderDecoderImpl: public TextInferenceEngine {
|
|||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
auto impl = std::make_unique<EncoderDecoderImpl>();
|
||||
impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
||||
auto impl = std::make_unique<Child>();
|
||||
impl->model_ = std::make_unique<Model>(loader);
|
||||
return impl;
|
||||
}
|
||||
private:
|
||||
std::unique_ptr<ctranslate2::Translator> translator_;
|
||||
|
||||
protected:
|
||||
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const = 0;
|
||||
std::unique_ptr<Model> model_;
|
||||
};
|
||||
|
||||
class DecoderImpl: public TextInferenceEngine {
|
||||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature,
|
||||
size_t beam_size
|
||||
) const {
|
||||
// Create options.
|
||||
ctranslate2::GenerationOptions options;
|
||||
options.include_prompt_in_result = false;
|
||||
options.max_length = max_decoding_length;
|
||||
options.sampling_temperature = sampling_temperature;
|
||||
options.beam_size = beam_size;
|
||||
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
|
||||
const auto& output_tokens = result.sequences[0];
|
||||
|
||||
// Convert to rust vec.
|
||||
rust::Vec<rust::String> output;
|
||||
output.reserve(output_tokens.size());
|
||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||
return output;
|
||||
class EncoderDecoderImpl : public TextInferenceEngineImpl<ctranslate2::Translator, EncoderDecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
|
||||
ctranslate2::TranslationOptions x;
|
||||
x.max_decoding_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = options.beam_size;
|
||||
ctranslate2::TranslationResult result = model_->translate_batch(
|
||||
{ tokens },
|
||||
ctranslate2::TranslationOptions{
|
||||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
auto impl = std::make_unique<DecoderImpl>();
|
||||
impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
|
||||
return impl;
|
||||
)[0];
|
||||
return std::move(result.output());
|
||||
}
|
||||
};
|
||||
|
||||
class DecoderImpl : public TextInferenceEngineImpl<ctranslate2::Generator, DecoderImpl> {
|
||||
protected:
|
||||
virtual std::vector<std::string> process(const std::vector<std::string>& tokens, const Options& options) const override {
|
||||
ctranslate2::GenerationOptions x;
|
||||
x.include_prompt_in_result = false;
|
||||
x.max_length = options.max_decoding_length;
|
||||
x.sampling_temperature = options.sampling_temperature;
|
||||
x.beam_size = options.beam_size;
|
||||
ctranslate2::GenerationResult result = model_->generate_batch_async({ tokens }, x)[0].get();
|
||||
return std::move(result.sequences[0]);
|
||||
}
|
||||
private:
|
||||
std::unique_ptr<ctranslate2::Generator> generator_;
|
||||
};
|
||||
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_builder;
|
||||
|
||||
#[cxx::bridge(namespace = "tabby")]
|
||||
mod ffi {
|
||||
extern "Rust" {
|
||||
type InferenceContext;
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("ctranslate2-bindings/include/ctranslate2.h");
|
||||
|
||||
|
|
@ -20,6 +25,8 @@ mod ffi {
|
|||
|
||||
fn inference(
|
||||
&self,
|
||||
context: Box<InferenceContext>,
|
||||
is_context_cancelled: fn(Box<InferenceContext>) -> bool,
|
||||
tokens: &[String],
|
||||
max_decoding_length: usize,
|
||||
sampling_temperature: f32,
|
||||
|
|
@ -58,6 +65,8 @@ pub struct TextInferenceOptions {
|
|||
beam_size: usize,
|
||||
}
|
||||
|
||||
struct InferenceContext(CancellationToken);
|
||||
|
||||
pub struct TextInferenceEngine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
tokenizer: Tokenizer,
|
||||
|
|
@ -81,8 +90,17 @@ impl TextInferenceEngine {
|
|||
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let engine = self.engine.clone();
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let context = InferenceContext(cancel_for_inference);
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
context,
|
||||
|context| context.0.is_cancelled(),
|
||||
encoding.get_tokens(),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
|
|
|
|||
Loading…
Reference in New Issue