diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h
index fde3551..473118a 100644
--- a/crates/ctranslate2-bindings/include/ctranslate2.h
+++ b/crates/ctranslate2-bindings/include/ctranslate2.h
@@ -25,8 +25,6 @@ std::shared_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
     rust::Str model_type,
     rust::Str device,
-    rust::Str compute_type,
-    rust::Slice<const int32_t> device_indices,
-    size_t num_replicas_per_device
+    rust::Slice<const int32_t> device_indices
 );
 } // namespace
diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index ca11612..68c7829 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -105,19 +105,25 @@ std::shared_ptr<TextInferenceEngine> create_engine(
     rust::Str model_path,
     rust::Str model_type,
     rust::Str device,
-    rust::Str compute_type,
-    rust::Slice<const int32_t> device_indices,
-    size_t num_replicas_per_device
+    rust::Slice<const int32_t> device_indices
 ) {
   std::string model_type_str(model_type);
   std::string model_path_str(model_path);
   ctranslate2::models::ModelLoader loader(model_path_str);
   loader.device = ctranslate2::str_to_device(std::string(device));
   loader.device_indices = std::vector<int32_t>(device_indices.begin(), device_indices.end());
-  loader.num_replicas_per_device = num_replicas_per_device;
+  loader.compute_type = ctranslate2::ComputeType::AUTO;
 
-  std::string compute_type_str(compute_type);
-  loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
+  // hardware_concurrency() may return 0 when it cannot be determined; fall back to 1.
+  const size_t num_cpus = std::max<size_t>(std::thread::hardware_concurrency(), 1);
+  if (loader.device == ctranslate2::Device::CUDA) {
+    // When the device is CUDA, set parallelism to the number of threads.
+    loader.num_replicas_per_device = num_cpus;
+  } else if (loader.device == ctranslate2::Device::CPU) {
+    // When the device is CPU, adjust the number based on threads per replica.
+    // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
+    loader.num_replicas_per_device = std::max<size_t>(num_cpus / 4, 1);
+  }
 
   if (model_type_str == "AutoModelForCausalLM") {
     return DecoderImpl::create(loader);
diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs
index f04f8ad..8cb1a62 100644
--- a/crates/ctranslate2-bindings/src/lib.rs
+++ b/crates/ctranslate2-bindings/src/lib.rs
@@ -27,9 +27,7 @@
             model_path: &str,
             model_type: &str,
             device: &str,
-            compute_type: &str,
             device_indices: &[i32],
-            num_replicas_per_device: usize,
         ) -> SharedPtr<TextInferenceEngine>;
 
         fn inference(
@@ -65,10 +63,6 @@ pub struct CTranslate2EngineOptions {
     device: String,
 
     device_indices: Vec<i32>,
-
-    num_replicas_per_device: usize,
-
-    compute_type: String,
 }
 
 pub struct InferenceContext {
@@ -103,9 +97,7 @@ impl CTranslate2Engine {
             &options.model_path,
             &options.model_type,
             &options.device,
-            &options.compute_type,
             &options.device_indices,
-            options.num_replicas_per_device,
         );
 
         return Self {
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index 99a3bd2..9eb86f9 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -5,7 +5,6 @@ use serde::Deserialize;
 use tabby_common::path::ModelDir;
 use tabby_inference::TextGeneration;
 
-use super::Device;
 use crate::fatal;
 
 pub fn create_engine(
@@ -68,28 +67,12 @@ fn create_ctranslate2_engine(
     metadata: &Metadata,
 ) -> Box<dyn TextGeneration> {
     let device = format!("{}", args.device);
-    let compute_type = format!("{}", args.compute_type);
-    let num_replicas_per_device = {
-        let num_cpus = std::thread::available_parallelism()
-            .expect("Failed to read # of cpu")
-            .get();
-        if args.device == Device::Cuda {
-            // When device is cuda, set parallelism to be number of thread.
-            num_cpus
-        } else {
-            // Otherwise, adjust the number based on threads per replica.
-            // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
-            std::cmp::max(num_cpus / 4, 1)
-        }
-    };
     let options = CTranslate2EngineOptionsBuilder::default()
         .model_path(model_dir.ctranslate2_dir())
         .tokenizer_path(model_dir.tokenizer_file())
         .device(device)
         .model_type(metadata.auto_model.clone())
         .device_indices(args.device_indices.clone())
-        .num_replicas_per_device(num_replicas_per_device)
-        .compute_type(compute_type)
         .build()
         .unwrap();
     Box::new(CTranslate2Engine::create(options))
diff --git a/crates/tabby/src/serve/health.rs b/crates/tabby/src/serve/health.rs
index 49dcd51..0a5e0bf 100644
--- a/crates/tabby/src/serve/health.rs
+++ b/crates/tabby/src/serve/health.rs
@@ -13,7 +13,6 @@ pub struct HealthState {
     #[serde(skip_serializing_if = "Option::is_none")]
     chat_model: Option<String>,
     device: String,
-    compute_type: String,
     arch: String,
     cpu_info: String,
     cpu_count: usize,
@@ -34,7 +33,6 @@ impl HealthState {
             model: args.model.clone(),
             chat_model: args.chat_model.clone(),
             device: args.device.to_string(),
-            compute_type: args.compute_type.to_string(),
             arch: ARCH.to_string(),
             cpu_info,
             cpu_count,
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index bee4682..9398c1f 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -88,36 +88,6 @@ pub enum Device {
     ExperimentalHttp,
 }
 
-#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
-#[clap(rename_all = "snake_case")]
-pub enum ComputeType {
-    /// Use the fastest computation type that is supported on this system and device
-    #[strum(serialize = "auto")]
-    Auto,
-
-    /// Quantize model weights to use int8 for inference.
-    ///
-    /// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
-    /// float32.
-    #[strum(serialize = "int8")]
-    Int8,
-
-    /// Use float16 for inference, only supported on CUDA devices.
-    #[strum(serialize = "float16")]
-    Float16,
-
-    /// Use int8 / float16 mixed precision for inference, only supported on CUDA devices.
-    ///
-    /// This mode is the same as int8 for CUDA devices, but all non quantized layers are run in float16
-    /// instead of float32.
-    #[strum(serialize = "int8_float16")]
-    Int8Float16,
-
-    /// Use float32 for inference.
-    #[strum(serialize = "float32")]
-    Float32,
-}
-
 #[derive(Args)]
 pub struct ServeArgs {
     /// Model id for `/completions` API endpoint.
@@ -140,12 +110,12 @@ pub struct ServeArgs {
     device_indices: Vec<i32>,
 
     /// DEPRECATED: Do not use.
-    #[clap(long)]
+    #[clap(long, hide(true))]
     num_replicas_per_device: Option<usize>,
 
-    /// Compute type
-    #[clap(long, default_value_t=ComputeType::Auto)]
-    compute_type: ComputeType,
+    /// DEPRECATED: Do not use.
+    #[clap(long, hide(true))]
+    compute_type: Option<String>,
 }
 
 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
@@ -282,7 +252,7 @@ fn fallback() -> routing::MethodRouter {
 fn valid_args(args: &ServeArgs) {
     if args.num_replicas_per_device.is_some() {
-        warn!("num_replicas_per_device is deprecated and will be removed in future release.");
+        warn!("--num-replicas-per-device is deprecated and will be removed in future release.");
     }
 
     if args.device == Device::Cpu
         && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
@@ -290,11 +260,8 @@ fn valid_args(args: &ServeArgs) {
         fatal!("CPU device only supports device indices = [0]");
     }
 
-    if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 {
-        match args.compute_type {
-            ComputeType::Auto | ComputeType::Int8 => {}
-            _ => fatal!("CPU device only supports int8 compute type"),
-        }
+    if args.compute_type.is_some() {
+        warn!("--compute-type is deprecated and will be removed in future release.");
     }
 }
 