From 552711a560d8b5ce016fab93d88d22515fbc4621 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Sat, 27 May 2023 01:26:33 -0700
Subject: [PATCH] Support causal lm (decoder only model) (#151)

* support

* support causal lm
---
 .../include/ctranslate2.h                   |  1 +
 .../ctranslate2-bindings/src/ctranslate2.cc | 81 +++++++++++++------
 crates/ctranslate2-bindings/src/lib.rs      | 15 +++-
 crates/tabby/src/serve/mod.rs               | 21 +++++
 4 files changed, 90 insertions(+), 28 deletions(-)

diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h
index 3ed8f1b..514ba57 100644
--- a/crates/ctranslate2-bindings/include/ctranslate2.h
+++ b/crates/ctranslate2-bindings/include/ctranslate2.h
@@ -18,6 +18,7 @@ class TextInferenceEngine {
 
 std::unique_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
+    rust::Str model_type,
     rust::Str device,
     rust::Slice<const int32_t> device_indices,
     size_t num_replicas_per_device
diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index 61dffb6..00c2b6f 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -1,16 +1,13 @@
 #include "ctranslate2-bindings/include/ctranslate2.h"
 
 #include "ctranslate2/translator.h"
+#include "ctranslate2/generator.h"
 
 namespace tabby {
 TextInferenceEngine::~TextInferenceEngine() {}
 
-class TextInferenceEngineImpl : public TextInferenceEngine {
+class EncoderDecoderImpl: public TextInferenceEngine {
  public:
-  TextInferenceEngineImpl(std::unique_ptr<ctranslate2::Translator> translator) : translator_(std::move(translator)) {}
-
-  ~TextInferenceEngineImpl() {}
-
   rust::Vec<rust::String> inference(
       rust::Slice<const rust::String> tokens,
       size_t max_decoding_length,
       float sampling_temperature,
       size_t beam_size
   ) const {
@@ -34,36 +31,72 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
     return output;
   }
+
+  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
+    auto impl = std::make_unique<EncoderDecoderImpl>();
+    impl->translator_ = std::make_unique<ctranslate2::Translator>(loader);
+    return impl;
+  }
  private:
   std::unique_ptr<ctranslate2::Translator> translator_;
 };
 
+class DecoderImpl: public TextInferenceEngine {
+ public:
+  rust::Vec<rust::String> inference(
+      rust::Slice<const rust::String> tokens,
+      size_t max_decoding_length,
+      float sampling_temperature,
+      size_t beam_size
+  ) const {
+    // Create options.
+    ctranslate2::GenerationOptions options;
+    options.include_prompt_in_result = false;
+    options.max_length = max_decoding_length;
+    options.sampling_temperature = sampling_temperature;
+    options.beam_size = beam_size;
+
+    // Inference.
+    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
+    ctranslate2::GenerationResult result = generator_->generate_batch_async({ input_tokens }, options)[0].get();
+    const auto& output_tokens = result.sequences[0];
+
+    // Convert to rust vec.
+    rust::Vec<rust::String> output;
+    output.reserve(output_tokens.size());
+    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
+    return output;
+  }
+
+  static std::unique_ptr<TextInferenceEngine> create(const ctranslate2::models::ModelLoader& loader) {
+    auto impl = std::make_unique<DecoderImpl>();
+    impl->generator_ = std::make_unique<ctranslate2::Generator>(loader);
+    return impl;
+  }
+ private:
+  std::unique_ptr<ctranslate2::Generator> generator_;
+};
+
 std::unique_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
+    rust::Str model_type,
     rust::Str device,
     rust::Slice<const int32_t> device_indices,
     size_t num_replicas_per_device
 ) {
-  // model_path.
-  std::string model_path_string(model_path);
-  ctranslate2::models::ModelLoader loader(model_path_string);
-
-  // device.
-  std::string device_string(device);
-  if (device_string == "cuda") {
-    loader.device = ctranslate2::Device::CUDA;
-  } else if (device_string == "cpu") {
-    loader.device = ctranslate2::Device::CPU;
-  }
-
-  // device_indices
-  loader.device_indices.clear();
-  std::copy(device_indices.begin(), device_indices.end(), std::back_inserter(loader.device_indices));
-
-  // num_replicas_per_device
+  std::string model_type_str(model_type);
+  std::string model_path_str(model_path);
+  ctranslate2::models::ModelLoader loader(model_path_str);
+  loader.device = ctranslate2::str_to_device(std::string(device));
+  loader.device_indices = std::vector<int32_t>(device_indices.begin(), device_indices.end());
   loader.num_replicas_per_device = num_replicas_per_device;
-  auto translator = std::make_unique<ctranslate2::Translator>(loader);
-  return std::make_unique<TextInferenceEngineImpl>(std::move(translator));
+  if (model_type_str == "decoder") {
+    return DecoderImpl::create(loader);
+  } else if (model_type_str == "encoder-decoder") {
+    return EncoderDecoderImpl::create(loader);
+  } else {
+    return nullptr;
+  }
 }
 }  // namespace tabby
diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index 2ce0587..5701c6c 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -1,4 +1,4 @@
-use tokenizers::tokenizer::{Model, Tokenizer};
+use tokenizers::tokenizer::{Tokenizer};
 
 #[macro_use]
 extern crate derive_builder;
@@ -12,6 +12,7 @@ mod ffi {
 
         fn create_engine(
             model_path: &str,
+            model_type: &str,
             device: &str,
             device_indices: &[i32],
             num_replicas_per_device: usize,
@@ -31,6 +32,8 @@ mod ffi {
 pub struct TextInferenceEngineCreateOptions {
     model_path: String,
 
+    model_type: String,
+
     tokenizer_path: String,
 
     device: String,
@@ -64,6 +67,7 @@ impl TextInferenceEngine {
     pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
         let engine = ffi::create_engine(
            &options.model_path,
+           &options.model_type,
            &options.device,
            &options.device_indices,
            options.num_replicas_per_device,
@@ -82,11 +86,14 @@ impl TextInferenceEngine {
            options.sampling_temperature,
            options.beam_size,
        );
-
-        let model = self.tokenizer.get_model();
         let output_ids: Vec<u32> = output_tokens
             .iter()
-            .map(|x| model.token_to_id(x).unwrap())
+            .filter_map(|x| {
+                match self.tokenizer.token_to_id(x) {
+                    Some(y) => Some(y),
+                    None => { println!("Warning: token ({}) missed in vocab", x); None }
+                }
+            })
             .collect();
         self.tokenizer.decode(output_ids, true).unwrap()
     }
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index ee27e38..a8630cb 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -43,12 +43,32 @@ impl std::fmt::Display for Device {
     }
 }
 
+#[derive(clap::ValueEnum, Clone)]
+pub enum ModelType {
+    EncoderDecoder,
+    Decoder,
+}
+
+impl std::fmt::Display for ModelType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let printable = match *self {
+            ModelType::EncoderDecoder => "encoder-decoder",
+            ModelType::Decoder => "decoder",
+        };
+        write!(f, "{}", printable)
+    }
+}
+
 #[derive(Args)]
 pub struct ServeArgs {
     /// path to model for serving
     #[clap(long)]
     model: String,
 
+    /// model type for serving
+    #[clap(long, default_value_t=ModelType::Decoder)]
+    model_type: ModelType,
+
     #[clap(long, default_value_t = 8080)]
     port: u16,
 
@@ -79,6 +99,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
                 .to_string(),
         )
         .device(device)
+        .model_type(format!("{}", args.model_type))
         .device_indices(args.device_indices.clone())
         .num_replicas_per_device(args.num_replicas_per_device)
         .build()
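
With this change, `tabby serve` accepts a `--model-type` flag (`decoder` by
default, matching causal LMs; pass `encoder-decoder` for seq2seq models), and
the same string is threaded through to `create_engine`, which selects
`ctranslate2::Generator` for "decoder", `ctranslate2::Translator` for
"encoder-decoder", and returns a null engine for anything else. The sketch
below shows how the new option is exercised when driving the bindings crate
directly. It is illustrative, not code from this patch: the builder name is
assumed from the derive_builder convention for
TextInferenceEngineCreateOptions, and the model/tokenizer paths are
placeholders.

    // Sketch: create a decoder-only (causal LM) engine via the new
    // model_type option. Assumptions: the derive_builder-generated
    // TextInferenceEngineCreateOptionsBuilder, placeholder paths; an
    // unrecognized model_type yields a null engine on the C++ side.
    use ctranslate2_bindings::{TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder};

    fn build_decoder_engine() -> TextInferenceEngine {
        let options = TextInferenceEngineCreateOptionsBuilder::default()
            .model_path("./models/ctranslate2".to_string()) // placeholder path
            .model_type("decoder".to_string()) // or "encoder-decoder"
            .tokenizer_path("./models/tokenizer.json".to_string()) // placeholder path
            .device("cpu".to_string())
            .device_indices(vec![0])
            .num_replicas_per_device(1)
            .build()
            .unwrap();
        TextInferenceEngine::create(options)
    }

Nothing else changes for callers: both engine implementations expose the same
`inference` entry point, so the serving path is identical for decoder-only and
encoder-decoder models once the engine is constructed.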