diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h
index 473118a..fde3551 100644
--- a/crates/ctranslate2-bindings/include/ctranslate2.h
+++ b/crates/ctranslate2-bindings/include/ctranslate2.h
@@ -25,6 +25,8 @@ std::shared_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
     rust::Str model_type,
     rust::Str device,
-    rust::Slice<const int32_t> device_indices
+    rust::Str compute_type,
+    rust::Slice<const int32_t> device_indices,
+    size_t num_replicas_per_device
 );
 } // namespace
diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index 68c7829..ca11612 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -105,24 +105,19 @@ std::shared_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
     rust::Str model_type,
     rust::Str device,
-    rust::Slice<const int32_t> device_indices
+    rust::Str compute_type,
+    rust::Slice<const int32_t> device_indices,
+    size_t num_replicas_per_device
 ) {
   std::string model_type_str(model_type);
   std::string model_path_str(model_path);
   ctranslate2::models::ModelLoader loader(model_path_str);
   loader.device = ctranslate2::str_to_device(std::string(device));
   loader.device_indices = std::vector<int32_t>(device_indices.begin(), device_indices.end());
-  loader.compute_type = ctranslate2::ComputeType::AUTO;
+  loader.num_replicas_per_device = num_replicas_per_device;
 
-  const size_t num_cpus = std::thread::hardware_concurrency();
-  if (loader.device == ctranslate2::Device::CUDA) {
-    // When device is cuda, set parallelism to be number of thread.
-    loader.num_replicas_per_device = num_cpus;
-  } else if (loader.device == ctranslate2::Device::CPU){
-    // When device is cpu, adjust the number based on threads per replica.
-    // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
-    loader.num_replicas_per_device = std::max(num_cpus / 4, 1);
-  }
+  std::string compute_type_str(compute_type);
+  loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
 
   if (model_type_str == "AutoModelForCausalLM") {
     return DecoderImpl::create(loader);
diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index 8cb1a62..f04f8ad 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -27,7 +27,9 @@ mod ffi {
            model_path: &str,
            model_type: &str,
            device: &str,
+           compute_type: &str,
            device_indices: &[i32],
+           num_replicas_per_device: usize,
        ) -> SharedPtr<TextInferenceEngine>;
 
        fn inference(
@@ -63,6 +65,10 @@ pub struct CTranslate2EngineOptions {
    device: String,
 
    device_indices: Vec<i32>,
+
+   num_replicas_per_device: usize,
+
+   compute_type: String,
 }
 
 pub struct InferenceContext {
@@ -97,7 +103,9 @@ impl CTranslate2Engine {
            &options.model_path,
            &options.model_type,
            &options.device,
+           &options.compute_type,
            &options.device_indices,
+           options.num_replicas_per_device,
        );
 
        return Self {
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index 9eb86f9..99a3bd2 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -5,6 +5,7 @@
 use serde::Deserialize;
 use tabby_common::path::ModelDir;
 use tabby_inference::TextGeneration;
+use super::Device;
 use crate::fatal;
 
 pub fn create_engine(
@@ -67,12 +68,28 @@ fn create_ctranslate2_engine(
     metadata: &Metadata,
 ) -> Box<dyn TextGeneration> {
     let device = format!("{}", args.device);
+    let compute_type = format!("{}", args.compute_type);
+    let num_replicas_per_device = {
+        let num_cpus = std::thread::available_parallelism()
+            .expect("Failed to read # of cpu")
+            .get();
+        if args.device == Device::Cuda {
+            // When device is cuda, set parallelism to be number of thread.
+            num_cpus
+        } else {
+            // Otherwise, adjust the number based on threads per replica.
+            // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
+            std::cmp::max(num_cpus / 4, 1)
+        }
+    };
     let options = CTranslate2EngineOptionsBuilder::default()
         .model_path(model_dir.ctranslate2_dir())
         .tokenizer_path(model_dir.tokenizer_file())
         .device(device)
         .model_type(metadata.auto_model.clone())
         .device_indices(args.device_indices.clone())
+        .num_replicas_per_device(num_replicas_per_device)
+        .compute_type(compute_type)
         .build()
         .unwrap();
     Box::new(CTranslate2Engine::create(options))
diff --git a/crates/tabby/src/serve/health.rs b/crates/tabby/src/serve/health.rs
index 0a5e0bf..49dcd51 100644
--- a/crates/tabby/src/serve/health.rs
+++ b/crates/tabby/src/serve/health.rs
@@ -13,6 +13,7 @@ pub struct HealthState {
     #[serde(skip_serializing_if = "Option::is_none")]
     chat_model: Option<String>,
     device: String,
+    compute_type: String,
     arch: String,
     cpu_info: String,
     cpu_count: usize,
@@ -33,6 +34,7 @@ impl HealthState {
             model: args.model.clone(),
             chat_model: args.chat_model.clone(),
             device: args.device.to_string(),
+            compute_type: args.compute_type.to_string(),
             arch: ARCH.to_string(),
             cpu_info,
             cpu_count,
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 9398c1f..bee4682 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -88,6 +88,36 @@ pub enum Device {
     ExperimentalHttp,
 }
 
+#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
+#[clap(rename_all = "snake_case")]
+pub enum ComputeType {
+    /// Use the fastest computation type that is supported on this system and device
+    #[strum(serialize = "auto")]
+    Auto,
+
+    /// Quantize model weights to use int8 for inference.
+    ///
+    /// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
+    /// float32.
+    #[strum(serialize = "int8")]
+    Int8,
+
+    /// Use float16 for inference, only supported on CUDA devices.
+    #[strum(serialize = "float16")]
+    Float16,
+
+    /// Use int8 / float16 mixed precision for inference, only supported on CUDA devices.
+    ///
+    /// This mode is the same as int8 for CUDA devices, but all non quantized layers are run in float16
+    /// instead of float32.
+    #[strum(serialize = "int8_float16")]
+    Int8Float16,
+
+    /// Use float32 for inference.
+    #[strum(serialize = "float32")]
+    Float32,
+}
+
 #[derive(Args)]
 pub struct ServeArgs {
     /// Model id for `/completions` API endpoint.
@@ -110,12 +140,12 @@ pub struct ServeArgs {
     device_indices: Vec<i32>,
 
     /// DEPRECATED: Do not use.
     #[clap(long, hide(true))]
     num_replicas_per_device: Option<usize>,
 
-    /// DEPRECATED: Do not use.
-    #[clap(long, hide(true))]
-    compute_type: Option<String>,
+    /// Compute type
+    #[clap(long, default_value_t=ComputeType::Auto)]
+    compute_type: ComputeType,
 }
 
 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
@@ -252,7 +282,7 @@ fn fallback() -> routing::MethodRouter {
 fn valid_args(args: &ServeArgs) {
     if args.num_replicas_per_device.is_some() {
-        warn!("--num-replicas-per-device is deprecated and will be removed in future release.");
+        warn!("num_replicas_per_device is deprecated and will be removed in future release.");
     }
 
     if args.device == Device::Cpu
         && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
@@ -260,8 +290,11 @@ fn valid_args(args: &ServeArgs) {
     {
         fatal!("CPU device only supports device indices = [0]");
     }
 
-    if args.compute_type.is_some() {
-        warn!("--compute-type is deprecated and will be removed in future release.");
+    if args.device == Device::Cpu {
+        match args.compute_type {
+            ComputeType::Auto | ComputeType::Int8 => {}
+            _ => fatal!("CPU device only supports auto or int8 compute type"),
+        }
     }
 }