parent
72ed30e9ff
commit
552711a560
|
|
@ -18,6 +18,7 @@ class TextInferenceEngine {
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
rust::Slice<const int32_t> device_indices,
|
rust::Slice<const int32_t> device_indices,
|
||||||
size_t num_replicas_per_device
|
size_t num_replicas_per_device
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,13 @@
|
||||||
#include "ctranslate2-bindings/include/ctranslate2.h"
|
#include "ctranslate2-bindings/include/ctranslate2.h"
|
||||||
|
|
||||||
#include "ctranslate2/translator.h"
|
#include "ctranslate2/translator.h"
|
||||||
|
#include "ctranslate2/generator.h"
|
||||||
|
|
||||||
namespace tabby {
|
namespace tabby {
|
||||||
TextInferenceEngine::~TextInferenceEngine() {}
|
TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
|
|
||||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
class EncoderDecoderImpl: public TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
|
|
||||||
|
|
||||||
~TextInferenceEngineImpl() {}
|
|
||||||
|
|
||||||
rust::Vec<rust::String> inference(
|
rust::Vec<rust::String> inference(
|
||||||
rust::Slice<const rust::String> tokens,
|
rust::Slice<const rust::String> tokens,
|
||||||
size_t max_decoding_length,
|
size_t max_decoding_length,
|
||||||
|
|
@ -34,36 +31,72 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||||
|
auto impl = std::make_unique<EncoderDecoderImpl>();
|
||||||
|
impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
||||||
|
return impl;
|
||||||
|
}
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<ctranslate2::Translator> translator_;
|
std::unique_ptr<ctranslate2::Translator> translator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class DecoderImpl: public TextInferenceEngine {
|
||||||
|
public:
|
||||||
|
rust::Vec<rust::String> inference(
|
||||||
|
rust::Slice<const rust::String> tokens,
|
||||||
|
size_t max_decoding_length,
|
||||||
|
float sampling_temperature,
|
||||||
|
size_t beam_size
|
||||||
|
) const {
|
||||||
|
// Create options.
|
||||||
|
ctranslate2::GenerationOptions options;
|
||||||
|
options.include_prompt_in_result = false;
|
||||||
|
options.max_length = max_decoding_length;
|
||||||
|
options.sampling_temperature = sampling_temperature;
|
||||||
|
options.beam_size = beam_size;
|
||||||
|
|
||||||
|
// Inference.
|
||||||
|
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||||
|
ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
|
||||||
|
const auto& output_tokens = result.sequences[0];
|
||||||
|
|
||||||
|
// Convert to rust vec.
|
||||||
|
rust::Vec<rust::String> output;
|
||||||
|
output.reserve(output_tokens.size());
|
||||||
|
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||||
|
auto impl = std::make_unique<DecoderImpl>();
|
||||||
|
impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
|
||||||
|
return impl;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
std::unique_ptr<ctranslate2::Generator> generator_;
|
||||||
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
rust::Slice<const int32_t> device_indices,
|
rust::Slice<const int32_t> device_indices,
|
||||||
size_t num_replicas_per_device
|
size_t num_replicas_per_device
|
||||||
) {
|
) {
|
||||||
// model_path.
|
std::string model_type_str(model_type);
|
||||||
std::string model_path_string(model_path);
|
std::string model_path_str(model_path);
|
||||||
ctranslate2::models::ModelLoader loader(model_path_string);
|
ctranslate2::models::ModelLoader loader(model_path_str);
|
||||||
|
loader.device = ctranslate2::str_to_device(std::string(device));
|
||||||
// device.
|
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
||||||
std::string device_string(device);
|
|
||||||
if (device_string == "cuda") {
|
|
||||||
loader.device = ctranslate2::Device::CUDA;
|
|
||||||
} else if (device_string == "cpu") {
|
|
||||||
loader.device = ctranslate2::Device::CPU;
|
|
||||||
}
|
|
||||||
|
|
||||||
// device_indices
|
|
||||||
loader.device_indices.clear();
|
|
||||||
std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
|
|
||||||
|
|
||||||
// num_replicas_per_device
|
|
||||||
loader.num_replicas_per_device = num_replicas_per_device;
|
loader.num_replicas_per_device = num_replicas_per_device;
|
||||||
|
|
||||||
auto translator = std::make_unique<ctranslate2::Translator>(loader);
|
if (model_type_str == "decoder") {
|
||||||
return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
|
return DecoderImpl::create(loader);
|
||||||
|
} else if (model_type_str == "encoder-decoder") {
|
||||||
|
return EncoderDecoderImpl::create(loader);
|
||||||
|
} else {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace tabby
|
} // namespace tabby
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use tokenizers::tokenizer::{Model, Tokenizer};
|
use tokenizers::tokenizer::{Tokenizer};
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate derive_builder;
|
extern crate derive_builder;
|
||||||
|
|
@ -12,6 +12,7 @@ mod ffi {
|
||||||
|
|
||||||
fn create_engine(
|
fn create_engine(
|
||||||
model_path: &str,
|
model_path: &str,
|
||||||
|
model_type: &str,
|
||||||
device: &str,
|
device: &str,
|
||||||
device_indices: &[i32],
|
device_indices: &[i32],
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: usize,
|
||||||
|
|
@ -31,6 +32,8 @@ mod ffi {
|
||||||
pub struct TextInferenceEngineCreateOptions {
|
pub struct TextInferenceEngineCreateOptions {
|
||||||
model_path: String,
|
model_path: String,
|
||||||
|
|
||||||
|
model_type: String,
|
||||||
|
|
||||||
tokenizer_path: String,
|
tokenizer_path: String,
|
||||||
|
|
||||||
device: String,
|
device: String,
|
||||||
|
|
@ -64,6 +67,7 @@ impl TextInferenceEngine {
|
||||||
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||||
let engine = ffi::create_engine(
|
let engine = ffi::create_engine(
|
||||||
&options.model_path,
|
&options.model_path,
|
||||||
|
&options.model_type,
|
||||||
&options.device,
|
&options.device,
|
||||||
&options.device_indices,
|
&options.device_indices,
|
||||||
options.num_replicas_per_device,
|
options.num_replicas_per_device,
|
||||||
|
|
@ -82,11 +86,14 @@ impl TextInferenceEngine {
|
||||||
options.sampling_temperature,
|
options.sampling_temperature,
|
||||||
options.beam_size,
|
options.beam_size,
|
||||||
);
|
);
|
||||||
|
|
||||||
let model = self.tokenizer.get_model();
|
|
||||||
let output_ids: Vec<u32> = output_tokens
|
let output_ids: Vec<u32> = output_tokens
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| model.token_to_id(x).unwrap())
|
.filter_map(|x| {
|
||||||
|
match self.tokenizer.token_to_id(x) {
|
||||||
|
Some(y) => Some(y),
|
||||||
|
None => { println!("Warning: token ({}) missed in vocab", x); None }
|
||||||
|
}
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
self.tokenizer.decode(output_ids, true).unwrap()
|
self.tokenizer.decode(output_ids, true).unwrap()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -43,12 +43,32 @@ impl std::fmt::Display for Device {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(clap::ValueEnum, Clone)]
|
||||||
|
pub enum ModelType {
|
||||||
|
EncoderDecoder,
|
||||||
|
Decoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ModelType {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
let printable = match *self {
|
||||||
|
ModelType::EncoderDecoder => "encoder-decoder",
|
||||||
|
ModelType::Decoder => "decoder",
|
||||||
|
};
|
||||||
|
write!(f, "{}", printable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct ServeArgs {
|
pub struct ServeArgs {
|
||||||
/// path to model for serving
|
/// path to model for serving
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
model: String,
|
model: String,
|
||||||
|
|
||||||
|
/// model type for serving
|
||||||
|
#[clap(long, default_value_t=ModelType::Decoder)]
|
||||||
|
model_type: ModelType,
|
||||||
|
|
||||||
#[clap(long, default_value_t = 8080)]
|
#[clap(long, default_value_t = 8080)]
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
|
|
@ -79,6 +99,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
||||||
.to_string(),
|
.to_string(),
|
||||||
)
|
)
|
||||||
.device(device)
|
.device(device)
|
||||||
|
.model_type(format!("{}", args.model_type))
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
.build()
|
.build()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue