From 9c9e46c6f44058f44dec527bab271e9d3531b442 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 13 Jun 2023 12:04:07 -0700 Subject: [PATCH] feat: support set compute_type through commandline arguments --- .../include/ctranslate2.h | 1 + .../ctranslate2-bindings/src/ctranslate2.cc | 14 +++++-- crates/ctranslate2-bindings/src/lib.rs | 4 ++ crates/tabby/src/serve/completions.rs | 2 + crates/tabby/src/serve/mod.rs | 37 +++++++++++++++++++ 5 files changed, 54 insertions(+), 4 deletions(-) diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index aa67db2..fde3551 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -25,6 +25,7 @@ std::shared_ptr create_engine( rust::Str model_path, rust::Str model_type, rust::Str device, + rust::Str compute_type, rust::Slice device_indices, size_t num_replicas_per_device ); diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index 62dacf3..ab633ce 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -105,6 +105,7 @@ std::shared_ptr create_engine( rust::Str model_path, rust::Str model_type, rust::Str device, + rust::Str compute_type, rust::Slice device_indices, size_t num_replicas_per_device ) { @@ -115,10 +116,15 @@ std::shared_ptr create_engine( loader.device_indices = std::vector(device_indices.begin(), device_indices.end()); loader.num_replicas_per_device = num_replicas_per_device; - if (loader.device == ctranslate2::Device::CPU) { - loader.compute_type = ctranslate2::ComputeType::INT8; - } else if (loader.device == ctranslate2::Device::CUDA) { - loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16; + std::string compute_type_str(compute_type); + if (compute_type_str == "auto") { + if (loader.device == ctranslate2::Device::CPU) { + loader.compute_type = 
ctranslate2::ComputeType::INT8; + } else if (loader.device == ctranslate2::Device::CUDA) { + loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16; + } + } else { + loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str); } if (model_type_str == "AutoModelForCausalLM") { diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 152c6d9..664048c 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -21,6 +21,7 @@ mod ffi { model_path: &str, model_type: &str, device: &str, + compute_type: &str, device_indices: &[i32], num_replicas_per_device: usize, ) -> SharedPtr; @@ -60,6 +61,8 @@ pub struct TextInferenceEngineCreateOptions { device_indices: Vec, num_replicas_per_device: usize, + + compute_type: String, } #[derive(Builder, Debug)] @@ -101,6 +104,7 @@ impl TextInferenceEngine { &options.model_path, &options.model_type, &options.device, + &options.compute_type, &options.device_indices, options.num_replicas_per_device, ); diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 21d7674..fff57d3 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -127,6 +127,7 @@ impl CompletionState { let metadata = read_metadata(&model_dir); let device = format!("{}", args.device); + let compute_type = format!("{}", args.compute_type); let options = TextInferenceEngineCreateOptionsBuilder::default() .model_path(model_dir.ctranslate2_dir()) .tokenizer_path(model_dir.tokenizer_file()) @@ -134,6 +135,7 @@ impl CompletionState { .model_type(metadata.auto_model) .device_indices(args.device_indices.clone()) .num_replicas_per_device(args.num_replicas_per_device) + .compute_type(compute_type) .build() .unwrap(); let engine = TextInferenceEngine::create(options); diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 952663a..85841f1 100644 --- 
a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -49,6 +49,35 @@ pub enum Device { Cuda, } +#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] +#[clap(rename_all = "snake_case")] +pub enum ComputeType { + /// Set quantization automatically based on device: + /// + /// CPU: Int8 + /// CUDA: Int8Float16 + #[strum(serialize = "auto")] + Auto, + + /// Quantize model weights to use int8 for inference. + /// + /// On CUDA devices, embedding / linear layers run on int8, while other layers run on + /// float32. + #[strum(serialize = "int8")] + Int8, + + /// Use float16 for inference, only supported on CUDA devices. + #[strum(serialize = "float16")] + Float16, + + /// Use int8 / float16 mixed precision for inference, only supported on CUDA devices. + /// + /// This mode is the same as int8 for CUDA devices, but all non-quantized layers are run in float16 + /// instead of float32. + #[strum(serialize = "int8_float16")] + Int8Float16, +} + #[derive(Args)] pub struct ServeArgs { /// Model id for serving. @@ -69,6 +98,10 @@ pub struct ServeArgs { /// Number of replicas per device, only applicable for CPU. #[clap(long, default_value_t = 1)] num_replicas_per_device: usize, + + /// Compute type + #[clap(long, default_value_t=ComputeType::Auto)] + compute_type: ComputeType, } pub async fn main(args: &ServeArgs) { @@ -124,6 +157,10 @@ fn valid_args(args: &ServeArgs) { { fatal!("CPU device only supports device indices = [0]"); } + + if args.device == Device::Cpu && args.compute_type != ComputeType::Auto && args.compute_type != ComputeType::Int8 { + fatal!("CPU device only supports auto / int8 compute type"); + } } #[utoipa::path(