parent
72ed30e9ff
commit
552711a560
|
|
@ -18,6 +18,7 @@ class TextInferenceEngine {
|
|||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
rust::Str model_path,
|
||||
rust::Str model_type,
|
||||
rust::Str device,
|
||||
rust::Slice<const int32_t> device_indices,
|
||||
size_t num_replicas_per_device
|
||||
|
|
|
|||
|
|
@ -1,16 +1,13 @@
|
|||
#include "ctranslate2-bindings/include/ctranslate2.h"
|
||||
|
||||
#include "ctranslate2/translator.h"
|
||||
#include "ctranslate2/generator.h"
|
||||
|
||||
namespace tabby {
|
||||
// Out-of-line destructor definition for the TextInferenceEngine interface; it performs no work of its own.
TextInferenceEngine::~TextInferenceEngine() = default;
|
||||
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
class EncoderDecoderImpl: public TextInferenceEngine {
|
||||
public:
|
||||
TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
|
||||
|
||||
~TextInferenceEngineImpl() {}
|
||||
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
|
|
@ -34,36 +31,72 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||
return output;
|
||||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
auto impl = std::make_unique<EncoderDecoderImpl>();
|
||||
impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
|
||||
return impl;
|
||||
}
|
||||
private:
|
||||
std::unique_ptr<ctranslate2::Translator> translator_;
|
||||
};
|
||||
|
||||
class DecoderImpl: public TextInferenceEngine {
|
||||
public:
|
||||
rust::Vec<rust::String> inference(
|
||||
rust::Slice<const rust::String> tokens,
|
||||
size_t max_decoding_length,
|
||||
float sampling_temperature,
|
||||
size_t beam_size
|
||||
) const {
|
||||
// Create options.
|
||||
ctranslate2::GenerationOptions options;
|
||||
options.include_prompt_in_result = false;
|
||||
options.max_length = max_decoding_length;
|
||||
options.sampling_temperature = sampling_temperature;
|
||||
options.beam_size = beam_size;
|
||||
|
||||
// Inference.
|
||||
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
|
||||
ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
|
||||
const auto& output_tokens = result.sequences[0];
|
||||
|
||||
// Convert to rust vec.
|
||||
rust::Vec<rust::String> output;
|
||||
output.reserve(output_tokens.size());
|
||||
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
|
||||
return output;
|
||||
}
|
||||
|
||||
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
|
||||
auto impl = std::make_unique<DecoderImpl>();
|
||||
impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
|
||||
return impl;
|
||||
}
|
||||
private:
|
||||
std::unique_ptr<ctranslate2::Generator> generator_;
|
||||
};
|
||||
|
||||
// Builds the inference engine that matches `model_type` for the model stored
// at `model_path`.
//
//   model_path              - passed to ctranslate2::models::ModelLoader.
//   model_type              - "decoder" selects DecoderImpl,
//                             "encoder-decoder" selects EncoderDecoderImpl.
//   device                  - device name, converted with ctranslate2::str_to_device.
//   device_indices          - copied into loader.device_indices.
//   num_replicas_per_device - copied into loader.num_replicas_per_device.
//
// Returns nullptr when `model_type` is neither of the two recognised values.
std::unique_ptr<TextInferenceEngine> create_engine(
    rust::Str model_path,
    rust::Str model_type,
    rust::Str device,
    rust::Slice<const int32_t> device_indices,
    size_t num_replicas_per_device
) {
  std::string model_type_str(model_type);
  std::string model_path_str(model_path);

  // Configure a single loader; the same loader feeds whichever engine is chosen.
  ctranslate2::models::ModelLoader loader(model_path_str);
  loader.device = ctranslate2::str_to_device(std::string(device));
  loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
  loader.num_replicas_per_device = num_replicas_per_device;

  // Dispatch on the requested model type.
  if (model_type_str == "decoder") {
    return DecoderImpl::create(loader);
  } else if (model_type_str == "encoder-decoder") {
    return EncoderDecoderImpl::create(loader);
  } else {
    return nullptr;
  }
}
|
||||
} // namespace tabby
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use tokenizers::tokenizer::{Model, Tokenizer};
|
||||
use tokenizers::tokenizer::{Tokenizer};
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_builder;
|
||||
|
|
@ -12,6 +12,7 @@ mod ffi {
|
|||
|
||||
fn create_engine(
|
||||
model_path: &str,
|
||||
model_type: &str,
|
||||
device: &str,
|
||||
device_indices: &[i32],
|
||||
num_replicas_per_device: usize,
|
||||
|
|
@ -31,6 +32,8 @@ mod ffi {
|
|||
pub struct TextInferenceEngineCreateOptions {
|
||||
model_path: String,
|
||||
|
||||
model_type: String,
|
||||
|
||||
tokenizer_path: String,
|
||||
|
||||
device: String,
|
||||
|
|
@ -64,6 +67,7 @@ impl TextInferenceEngine {
|
|||
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||
let engine = ffi::create_engine(
|
||||
&options.model_path,
|
||||
&options.model_type,
|
||||
&options.device,
|
||||
&options.device_indices,
|
||||
options.num_replicas_per_device,
|
||||
|
|
@ -82,11 +86,14 @@ impl TextInferenceEngine {
|
|||
options.sampling_temperature,
|
||||
options.beam_size,
|
||||
);
|
||||
|
||||
let model = self.tokenizer.get_model();
|
||||
let output_ids: Vec<u32> = output_tokens
|
||||
.iter()
|
||||
.map(|x| model.token_to_id(x).unwrap())
|
||||
.filter_map(|x| {
|
||||
match self.tokenizer.token_to_id(x) {
|
||||
Some(y) => Some(y),
|
||||
None => { println!("Warning: token ({}) missed in vocab", x); None }
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
self.tokenizer.decode(output_ids, true).unwrap()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,12 +43,32 @@ impl std::fmt::Display for Device {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Clone)]
|
||||
pub enum ModelType {
|
||||
EncoderDecoder,
|
||||
Decoder,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let printable = match *self {
|
||||
ModelType::EncoderDecoder => "encoder-decoder",
|
||||
ModelType::Decoder => "decoder",
|
||||
};
|
||||
write!(f, "{}", printable)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct ServeArgs {
|
||||
/// path to model for serving
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
|
||||
/// model type for serving
|
||||
#[clap(long, default_value_t=ModelType::Decoder)]
|
||||
model_type: ModelType,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
|
|
@ -79,6 +99,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
|||
.to_string(),
|
||||
)
|
||||
.device(device)
|
||||
.model_type(format!("{}", args.model_type))
|
||||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.build()
|
||||
|
|
|
|||
Loading…
Reference in New Issue