Support causal lm (decoder only model) (#151)

* support

* support causal lm
add-tracing
Meng Zhang 2023-05-27 01:26:33 -07:00 committed by GitHub
parent 72ed30e9ff
commit 552711a560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 28 deletions

View File

@ -18,6 +18,7 @@ class TextInferenceEngine {
std::unique_ptr<TextInferenceEngine> create_engine( std::unique_ptr<TextInferenceEngine> create_engine(
rust::Str model_path, rust::Str model_path,
rust::Str model_type,
rust::Str device, rust::Str device,
rust::Slice<const int32_t> device_indices, rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device size_t num_replicas_per_device

View File

@ -1,16 +1,13 @@
#include "ctranslate2-bindings/include/ctranslate2.h" #include "ctranslate2-bindings/include/ctranslate2.h"
#include "ctranslate2/translator.h" #include "ctranslate2/translator.h"
#include "ctranslate2/generator.h"
namespace tabby { namespace tabby {
TextInferenceEngine::~TextInferenceEngine() {} TextInferenceEngine::~TextInferenceEngine() {}
class TextInferenceEngineImpl : public TextInferenceEngine { class EncoderDecoderImpl: public TextInferenceEngine {
public: public:
TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
~TextInferenceEngineImpl() {}
rust::Vec<rust::String> inference( rust::Vec<rust::String> inference(
rust::Slice<const rust::String> tokens, rust::Slice<const rust::String> tokens,
size_t max_decoding_length, size_t max_decoding_length,
@ -34,36 +31,72 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output)); std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
return output; return output;
} }
static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
auto impl = std::make_unique<EncoderDecoderImpl>();
impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
return impl;
}
private: private:
std::unique_ptr<ctranslate2::Translator> translator_; std::unique_ptr<ctranslate2::Translator> translator_;
}; };
// Text inference engine for decoder-only (causal LM) models. All generation
// work is delegated to a ctranslate2::Generator owned by this object.
class DecoderImpl: public TextInferenceEngine {
 public:
  // Runs generation for a single prompt given as `tokens` and returns the
  // tokens of the first sequence in the generation result.
  //
  // `max_decoding_length`, `sampling_temperature` and `beam_size` are copied
  // verbatim into the ctranslate2::GenerationOptions used for the call.
  rust::Vec<rust::String> inference(
      rust::Slice<const rust::String> tokens,
      size_t max_decoding_length,
      float sampling_temperature,
      size_t beam_size
  ) const {
    ctranslate2::GenerationOptions generation_options;
    // Ask for the newly generated part only, not the prompt echoed back.
    generation_options.include_prompt_in_result = false;
    generation_options.max_length = max_decoding_length;
    generation_options.sampling_temperature = sampling_temperature;
    generation_options.beam_size = beam_size;

    // The generator works on batches; submit a batch holding just this prompt.
    std::vector<std::vector<std::string>> batch;
    batch.emplace_back(tokens.begin(), tokens.end());

    // Block until the one request in the batch has completed.
    auto pending = generator_->generate_batch_async(batch, generation_options);
    ctranslate2::GenerationResult generated = pending[0].get();
    const auto& best_sequence = generated.sequences[0];

    // Hand the tokens back as Rust-owned strings.
    rust::Vec<rust::String> converted;
    converted.reserve(best_sequence.size());
    for (const auto& token : best_sequence) {
      converted.push_back(rust::String(token));
    }
    return converted;
  }

  // Builds a DecoderImpl whose generator is constructed from `loader`.
  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
    auto engine = std::make_unique<DecoderImpl>();
    engine->generator_ = std::make_unique<ctranslate2::Generator>(loader);
    return engine;
  }

 private:
  std::unique_ptr<ctranslate2::Generator> generator_;
};
std::unique_ptr<TextInferenceEngine> create_engine( std::unique_ptr<TextInferenceEngine> create_engine(
rust::Str model_path, rust::Str model_path,
rust::Str model_type,
rust::Str device, rust::Str device,
rust::Slice<const int32_t> device_indices, rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device size_t num_replicas_per_device
) { ) {
// model_path. std::string model_type_str(model_type);
std::string model_path_string(model_path); std::string model_path_str(model_path);
ctranslate2::models::ModelLoader loader(model_path_string); ctranslate2::models::ModelLoader loader(model_path_str);
loader.device = ctranslate2::str_to_device(std::string(device));
// device. loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
std::string device_string(device);
if (device_string == "cuda") {
loader.device = ctranslate2::Device::CUDA;
} else if (device_string == "cpu") {
loader.device = ctranslate2::Device::CPU;
}
// device_indices
loader.device_indices.clear();
std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
// num_replicas_per_device
loader.num_replicas_per_device = num_replicas_per_device; loader.num_replicas_per_device = num_replicas_per_device;
auto translator = std::make_unique<ctranslate2::Translator>(loader); if (model_type_str == "decoder") {
return std::make_unique<TextInferenceEngineImpl>(std::move(translator)); return DecoderImpl::create(loader);
} else if (model_type_str == "encoder-decoder") {
return EncoderDecoderImpl::create(loader);
} else {
return nullptr;
}
} }
} // namespace tabby } // namespace tabby

View File

@ -1,4 +1,4 @@
use tokenizers::tokenizer::{Model, Tokenizer}; use tokenizers::tokenizer::{Tokenizer};
#[macro_use] #[macro_use]
extern crate derive_builder; extern crate derive_builder;
@ -12,6 +12,7 @@ mod ffi {
fn create_engine( fn create_engine(
model_path: &str, model_path: &str,
model_type: &str,
device: &str, device: &str,
device_indices: &[i32], device_indices: &[i32],
num_replicas_per_device: usize, num_replicas_per_device: usize,
@ -31,6 +32,8 @@ mod ffi {
pub struct TextInferenceEngineCreateOptions { pub struct TextInferenceEngineCreateOptions {
model_path: String, model_path: String,
model_type: String,
tokenizer_path: String, tokenizer_path: String,
device: String, device: String,
@ -64,6 +67,7 @@ impl TextInferenceEngine {
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where { pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
let engine = ffi::create_engine( let engine = ffi::create_engine(
&options.model_path, &options.model_path,
&options.model_type,
&options.device, &options.device,
&options.device_indices, &options.device_indices,
options.num_replicas_per_device, options.num_replicas_per_device,
@ -82,11 +86,14 @@ impl TextInferenceEngine {
options.sampling_temperature, options.sampling_temperature,
options.beam_size, options.beam_size,
); );
let model = self.tokenizer.get_model();
let output_ids: Vec<u32> = output_tokens let output_ids: Vec<u32> = output_tokens
.iter() .iter()
.map(|x| model.token_to_id(x).unwrap()) .filter_map(|x| {
match self.tokenizer.token_to_id(x) {
Some(y) => Some(y),
None => { println!("Warning: token ({}) missed in vocab", x); None }
}
})
.collect(); .collect();
self.tokenizer.decode(output_ids, true).unwrap() self.tokenizer.decode(output_ids, true).unwrap()
} }

View File

@ -43,12 +43,32 @@ impl std::fmt::Display for Device {
} }
} }
// Architecture family of the model being served. Deriving `clap::ValueEnum`
// lets it be used directly as the type of a command-line argument.
// Plain `//` comments are used on purpose: clap turns `///` doc comments on
// value-enum variants into help text, which would change the CLI output.
#[derive(clap::ValueEnum, Clone)]
pub enum ModelType {
// Displayed as "encoder-decoder".
EncoderDecoder,
// Displayed as "decoder".
Decoder,
}
impl std::fmt::Display for ModelType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let printable = match *self {
ModelType::EncoderDecoder => "encoder-decoder",
ModelType::Decoder => "decoder",
};
write!(f, "{}", printable)
}
}
#[derive(Args)] #[derive(Args)]
pub struct ServeArgs { pub struct ServeArgs {
/// path to model for serving /// path to model for serving
#[clap(long)] #[clap(long)]
model: String, model: String,
/// model type for serving
#[clap(long, default_value_t=ModelType::Decoder)]
model_type: ModelType,
#[clap(long, default_value_t = 8080)] #[clap(long, default_value_t = 8080)]
port: u16, port: u16,
@ -79,6 +99,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
.to_string(), .to_string(),
) )
.device(device) .device(device)
.model_type(format!("{}", args.model_type))
.device_indices(args.device_indices.clone()) .device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device) .num_replicas_per_device(args.num_replicas_per_device)
.build() .build()