diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index 665e2f6..bc46ef8 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -7,6 +7,7 @@ use serde_json::Value; use tabby_common::path::ModelDir; use tabby_inference::TextGeneration; +use super::Device; use crate::fatal; fn get_param(params: &Value, key: &str) -> String { @@ -108,13 +109,26 @@ fn create_ctranslate2_engine( ) -> Box { let device = format!("{}", args.device); let compute_type = format!("{}", args.compute_type); + let num_replicas_per_device = { + let num_cpus = std::thread::available_parallelism() + .expect("Failed to read # of cpu") + .get(); + if args.device == Device::Cuda { + // When device is cuda, set parallelism to be number of thread. + num_cpus + } else { + // Otherwise, adjust the number based on threads per replica. + // https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77 + std::cmp::max(num_cpus / 4, 1) + } + }; let options = CTranslate2EngineOptionsBuilder::default() .model_path(model_dir.ctranslate2_dir()) .tokenizer_path(model_dir.tokenizer_file()) .device(device) .model_type(metadata.auto_model.clone()) .device_indices(args.device_indices.clone()) - .num_replicas_per_device(args.num_replicas_per_device) + .num_replicas_per_device(num_replicas_per_device) .compute_type(compute_type) .build() .unwrap(); diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index b1f209b..ae9c49b 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -131,9 +131,9 @@ pub struct ServeArgs { #[clap(long, default_values_t=[0])] device_indices: Vec, - /// Number of replicas per device, only applicable for CPU. - #[clap(long, default_value_t = 1)] - num_replicas_per_device: usize, + /// DEPRECATED: Do not use. + #[clap(long)] + num_replicas_per_device: Option, /// Compute type #[clap(long, default_value_t=ComputeType::Auto)] @@ -244,6 +244,10 @@ fn fallback() -> routing::MethodRouter { } fn valid_args(args: &ServeArgs) { + if args.num_replicas_per_device.is_some() { + warn!("num_replicas_per_device is deprecated and will be removed in future release."); + } + if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0) { fatal!("CPU device only supports device indices = [0]");