feat: deprecate num_replicas_per_thread, generate default value for it

r0.2
Meng Zhang 2023-10-03 17:02:37 -07:00
parent 1afba47059
commit b3b498624c
2 changed files with 22 additions and 4 deletions

View File

@ -7,6 +7,7 @@ use serde_json::Value;
use tabby_common::path::ModelDir; use tabby_common::path::ModelDir;
use tabby_inference::TextGeneration; use tabby_inference::TextGeneration;
use super::Device;
use crate::fatal; use crate::fatal;
fn get_param(params: &Value, key: &str) -> String { fn get_param(params: &Value, key: &str) -> String {
@ -108,13 +109,26 @@ fn create_ctranslate2_engine(
) -> Box<dyn TextGeneration> { ) -> Box<dyn TextGeneration> {
let device = format!("{}", args.device); let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type); let compute_type = format!("{}", args.compute_type);
let num_replicas_per_device = {
let num_cpus = std::thread::available_parallelism()
.expect("Failed to read # of cpu")
.get();
if args.device == Device::Cuda {
// When device is cuda, set parallelism to be number of thread.
num_cpus
} else {
// Otherwise, adjust the number based on threads per replica.
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
std::cmp::max(num_cpus / 4, 1)
}
};
let options = CTranslate2EngineOptionsBuilder::default() let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir()) .model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file()) .tokenizer_path(model_dir.tokenizer_file())
.device(device) .device(device)
.model_type(metadata.auto_model.clone()) .model_type(metadata.auto_model.clone())
.device_indices(args.device_indices.clone()) .device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device) .num_replicas_per_device(num_replicas_per_device)
.compute_type(compute_type) .compute_type(compute_type)
.build() .build()
.unwrap(); .unwrap();

View File

@ -131,9 +131,9 @@ pub struct ServeArgs {
#[clap(long, default_values_t=[0])] #[clap(long, default_values_t=[0])]
device_indices: Vec<i32>, device_indices: Vec<i32>,
/// Number of replicas per device, only applicable for CPU. /// DEPRECATED: Do not use.
#[clap(long, default_value_t = 1)] #[clap(long)]
num_replicas_per_device: usize, num_replicas_per_device: Option<usize>,
/// Compute type /// Compute type
#[clap(long, default_value_t=ComputeType::Auto)] #[clap(long, default_value_t=ComputeType::Auto)]
@ -244,6 +244,10 @@ fn fallback() -> routing::MethodRouter {
} }
fn valid_args(args: &ServeArgs) { fn valid_args(args: &ServeArgs) {
if args.num_replicas_per_device.is_some() {
warn!("num_replicas_per_device is deprecated and will be removed in future release.");
}
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0) if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
{ {
fatal!("CPU device only supports device indices = [0]"); fatal!("CPU device only supports device indices = [0]");