Support causal lm (decoder only model) (#151)

* support

* support causal lm
Meng Zhang 2023-05-27 01:26:33 -07:00 committed by GitHub
parent 72ed30e9ff
commit 552711a560
4 changed files with 90 additions and 28 deletions

ctranslate2-bindings/include/ctranslate2.h

@@ -18,6 +18,7 @@ class TextInferenceEngine {
 std::unique_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
+    rust::Str model_type,
     rust::Str device,
     rust::Slice<const int32_t> device_indices,
     size_t num_replicas_per_device

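The `rust::Str` / `rust::Slice` types come from the cxx crate: the Rust side declares the same function inside a `#[cxx::bridge]` module (see lib.rs below), and cxx checks the two signatures against each other at build time. A rough sketch of that pairing, with the namespace and the opaque type declaration assumed from this diff rather than copied from the real bridge:

// Sketch only; the actual bridge lives in lib.rs and may differ in detail.
#[cxx::bridge(namespace = "tabby")]
mod ffi {
    unsafe extern "C++" {
        include!("ctranslate2-bindings/include/ctranslate2.h");

        type TextInferenceEngine;

        // rust::Str <-> &str, rust::Slice<const int32_t> <-> &[i32],
        // size_t <-> usize, std::unique_ptr <-> UniquePtr.
        fn create_engine(
            model_path: &str,
            model_type: &str,
            device: &str,
            device_indices: &[i32],
            num_replicas_per_device: usize,
        ) -> UniquePtr<TextInferenceEngine>;
    }
}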
ctranslate2-bindings/src/ctranslate2.cc

@@ -1,16 +1,13 @@
 #include "ctranslate2-bindings/include/ctranslate2.h"

 #include "ctranslate2/translator.h"
+#include "ctranslate2/generator.h"

 namespace tabby {
 TextInferenceEngine::~TextInferenceEngine() {}

-class TextInferenceEngineImpl : public TextInferenceEngine {
+class EncoderDecoderImpl: public TextInferenceEngine {
  public:
-  TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
-  ~TextInferenceEngineImpl() {}
-
   rust::Vec<rust::String> inference(
       rust::Slice<const rust::String> tokens,
       size_t max_decoding_length,
@@ -34,36 +31,72 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
     return output;
   }
+
+  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
+    auto impl = std::make_unique<EncoderDecoderImpl>();
+    impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
+    return impl;
+  }
+
  private:
   std::unique_ptr<ctranslate2::Translator> translator_;
 };
+
+class DecoderImpl: public TextInferenceEngine {
+ public:
+  rust::Vec<rust::String> inference(
+      rust::Slice<const rust::String> tokens,
+      size_t max_decoding_length,
+      float sampling_temperature,
+      size_t beam_size
+  ) const {
+    // Create options.
+    ctranslate2::GenerationOptions options;
+    options.include_prompt_in_result = false;
+    options.max_length = max_decoding_length;
+    options.sampling_temperature = sampling_temperature;
+    options.beam_size = beam_size;
+
+    // Inference.
+    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
+    ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
+    const auto& output_tokens = result.sequences[0];
+
+    // Convert to rust vec.
+    rust::Vec<rust::String> output;
+    output.reserve(output_tokens.size());
+    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
+    return output;
+  }
+
+  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
+    auto impl = std::make_unique<DecoderImpl>();
+    impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
+    return impl;
+  }
+
+ private:
+  std::unique_ptr<ctranslate2::Generator> generator_;
+};

 std::unique_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
+    rust::Str model_type,
     rust::Str device,
     rust::Slice<const int32_t> device_indices,
     size_t num_replicas_per_device
 ) {
-  // model_path.
-  std::string model_path_string(model_path);
-  ctranslate2::models::ModelLoader loader(model_path_string);
-
-  // device.
-  std::string device_string(device);
-  if (device_string == "cuda") {
-    loader.device = ctranslate2::Device::CUDA;
-  } else if (device_string == "cpu") {
-    loader.device = ctranslate2::Device::CPU;
-  }
-
-  // device_indices
-  loader.device_indices.clear();
-  std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
-
-  // num_replicas_per_device
+  std::string model_type_str(model_type);
+  std::string model_path_str(model_path);
+  ctranslate2::models::ModelLoader loader(model_path_str);
+  loader.device = ctranslate2::str_to_device(std::string(device));
+  loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
   loader.num_replicas_per_device = num_replicas_per_device;

-  auto translator = std::make_unique<ctranslate2::Translator>(loader);
-  return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
+  if (model_type_str == "decoder") {
+    return DecoderImpl::create(loader);
+  } else if (model_type_str == "encoder-decoder") {
+    return EncoderDecoderImpl::create(loader);
+  } else {
+    return nullptr;
+  }
 }
 }  // namespace tabby

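The heart of the change: `create_engine` now dispatches on the `model_type` string, wrapping a `ctranslate2::Generator` for decoder-only (causal LM) models and a `ctranslate2::Translator` for encoder-decoder models, with `include_prompt_in_result = false` so the generator returns only the completion rather than echoing the prompt. For readers more at home on the Rust side, here is the same factory-dispatch shape as a self-contained sketch; the types are hypothetical stand-ins, not the real bindings:

// Hypothetical stand-ins for the two CTranslate2-backed engines; only the
// dispatch shape mirrors the C++ above.
trait TextInferenceEngine {
    fn inference(&self, tokens: &[String], max_decoding_length: usize) -> Vec<String>;
}

struct DecoderImpl;        // would own a ctranslate2::Generator
struct EncoderDecoderImpl; // would own a ctranslate2::Translator

impl TextInferenceEngine for DecoderImpl {
    fn inference(&self, tokens: &[String], _max_decoding_length: usize) -> Vec<String> {
        tokens.to_vec() // placeholder for generate_batch_async
    }
}

impl TextInferenceEngine for EncoderDecoderImpl {
    fn inference(&self, tokens: &[String], _max_decoding_length: usize) -> Vec<String> {
        tokens.to_vec() // placeholder for translate_batch_async
    }
}

// Same contract as create_engine: an unknown model type produces no engine
// (the C++ side returns nullptr there).
fn create_engine(model_type: &str) -> Option<Box<dyn TextInferenceEngine>> {
    match model_type {
        "decoder" => Some(Box::new(DecoderImpl)),
        "encoder-decoder" => Some(Box::new(EncoderDecoderImpl)),
        _ => None,
    }
}

fn main() {
    let engine = create_engine("decoder").expect("unknown model type");
    println!("{:?}", engine.inference(&["fn main".to_string()], 64));
}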
ctranslate2-bindings/src/lib.rs

@@ -1,4 +1,4 @@
-use tokenizers::tokenizer::{Model, Tokenizer};
+use tokenizers::tokenizer::{Tokenizer};

 #[macro_use]
 extern crate derive_builder;
@@ -12,6 +12,7 @@ mod ffi {
     fn create_engine(
         model_path: &str,
+        model_type: &str,
         device: &str,
         device_indices: &[i32],
         num_replicas_per_device: usize,
@@ -31,6 +32,8 @@ mod ffi {
 pub struct TextInferenceEngineCreateOptions {
     model_path: String,

+    model_type: String,
+
     tokenizer_path: String,

     device: String,
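These options are consumed through a builder (note the `extern crate derive_builder;` above and the `.build()` chain in serve.rs below). Assuming the struct carries `#[derive(Builder)]`, which is consistent with that usage but not shown in this hunk, the pattern looks like this with an abbreviated field set:

// Sketch of the derive_builder pattern, abbreviated field set.
#[macro_use]
extern crate derive_builder;

#[derive(Builder)]
pub struct TextInferenceEngineCreateOptions {
    model_path: String,
    model_type: String,
    device: String,
    device_indices: Vec<i32>,
    num_replicas_per_device: usize,
}

fn main() {
    // derive_builder generates TextInferenceEngineCreateOptionsBuilder;
    // build() fails if a field was never set.
    let options = TextInferenceEngineCreateOptionsBuilder::default()
        .model_path("./models/ct2".to_string())
        .model_type("decoder".to_string())
        .device("cpu".to_string())
        .device_indices(vec![0])
        .num_replicas_per_device(1)
        .build()
        .unwrap();
    println!("loading {} as {}", options.model_path, options.model_type);
}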
@@ -64,6 +67,7 @@ impl TextInferenceEngine {
     pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
         let engine = ffi::create_engine(
             &options.model_path,
+            &options.model_type,
             &options.device,
             &options.device_indices,
             options.num_replicas_per_device,
@@ -82,11 +86,14 @@ impl TextInferenceEngine {
             options.sampling_temperature,
             options.beam_size,
         );
-        let model = self.tokenizer.get_model();
         let output_ids: Vec<u32> = output_tokens
             .iter()
-            .map(|x| model.token_to_id(x).unwrap())
+            .filter_map(|x| {
+                match self.tokenizer.token_to_id(x) {
+                    Some(y) => Some(y),
+                    None => { println!("Warning: token ({}) missed in vocab", x); None }
+                }
+            })
             .collect();
         self.tokenizer.decode(output_ids, true).unwrap()
     }

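The decode path used to call `model.token_to_id(x).unwrap()`, so a single out-of-vocabulary token would panic the server; the `filter_map` version logs a warning and skips the token instead. A standalone sketch of that round trip with the `tokenizers` crate, assuming a local `tokenizer.json` and the 0.13-era `decode(Vec<u32>, bool)` signature used here:

use tokenizers::tokenizer::Tokenizer;

// Map generated token strings back to ids, dropping anything the vocab
// doesn't know about instead of panicking.
fn decode_tokens(tokenizer: &Tokenizer, output_tokens: &[String]) -> String {
    let output_ids: Vec<u32> = output_tokens
        .iter()
        .filter_map(|x| match tokenizer.token_to_id(x) {
            Some(id) => Some(id),
            None => {
                println!("Warning: token ({}) missed in vocab", x);
                None
            }
        })
        .collect();
    // `true` skips special tokens in the decoded text.
    tokenizer.decode(output_ids, true).unwrap()
}

fn main() {
    let tokenizer = Tokenizer::from_file("tokenizer.json").unwrap();
    let tokens = vec!["hello".to_string(), "<not-a-token>".to_string()];
    println!("{}", decode_tokens(&tokenizer, &tokens));
}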
tabby/src/serve.rs

@@ -43,12 +43,32 @@ impl std::fmt::Display for Device {
     }
 }

+#[derive(clap::ValueEnum, Clone)]
+pub enum ModelType {
+    EncoderDecoder,
+    Decoder,
+}
+
+impl std::fmt::Display for ModelType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let printable = match *self {
+            ModelType::EncoderDecoder => "encoder-decoder",
+            ModelType::Decoder => "decoder",
+        };
+        write!(f, "{}", printable)
+    }
+}
+
 #[derive(Args)]
 pub struct ServeArgs {
     /// path to model for serving
     #[clap(long)]
     model: String,

+    /// model type for serving
+    #[clap(long, default_value_t=ModelType::Decoder)]
+    model_type: ModelType,
+
     #[clap(long, default_value_t = 8080)]
     port: u16,
@@ -79,6 +99,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
                 .to_string(),
         )
         .device(device)
+        .model_type(format!("{}", args.model_type))
         .device_indices(args.device_indices.clone())
         .num_replicas_per_device(args.num_replicas_per_device)
         .build()
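`default_value_t` requires the value to implement `Display` (that is how clap renders the default in help text and parses it back), which is why `ModelType` gets the manual `Display` impl, and why its strings ("decoder", "encoder-decoder") line up with what the C++ `create_engine` matches on. A minimal self-contained version of the same pattern, assuming clap 4 with the derive feature:

use clap::Parser;

#[derive(clap::ValueEnum, Clone)]
enum ModelType {
    EncoderDecoder,
    Decoder,
}

// The Display strings must round-trip through the ValueEnum parser, which
// kebab-cases variants to "encoder-decoder" and "decoder".
impl std::fmt::Display for ModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let printable = match *self {
            ModelType::EncoderDecoder => "encoder-decoder",
            ModelType::Decoder => "decoder",
        };
        write!(f, "{}", printable)
    }
}

#[derive(Parser)]
struct Args {
    /// model type for serving
    #[clap(long, default_value_t = ModelType::Decoder)]
    model_type: ModelType,
}

fn main() {
    // e.g. `serve --model-type encoder-decoder`; the Display string is what
    // gets forwarded to the engine builder via format!("{}", args.model_type).
    let args = Args::parse();
    println!("model_type = {}", args.model_type);
}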