feat: deprecate num_replicas_per_thread, generate default value for it
parent
1afba47059
commit
b3b498624c
|
|
@ -7,6 +7,7 @@ use serde_json::Value;
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
use tabby_inference::TextGeneration;
|
use tabby_inference::TextGeneration;
|
||||||
|
|
||||||
|
use super::Device;
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
fn get_param(params: &Value, key: &str) -> String {
|
fn get_param(params: &Value, key: &str) -> String {
|
||||||
|
|
@ -108,13 +109,26 @@ fn create_ctranslate2_engine(
|
||||||
) -> Box<dyn TextGeneration> {
|
) -> Box<dyn TextGeneration> {
|
||||||
let device = format!("{}", args.device);
|
let device = format!("{}", args.device);
|
||||||
let compute_type = format!("{}", args.compute_type);
|
let compute_type = format!("{}", args.compute_type);
|
||||||
|
let num_replicas_per_device = {
|
||||||
|
let num_cpus = std::thread::available_parallelism()
|
||||||
|
.expect("Failed to read # of cpu")
|
||||||
|
.get();
|
||||||
|
if args.device == Device::Cuda {
|
||||||
|
// When device is cuda, set parallelism to be number of thread.
|
||||||
|
num_cpus
|
||||||
|
} else {
|
||||||
|
// Otherwise, adjust the number based on threads per replica.
|
||||||
|
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
|
||||||
|
std::cmp::max(num_cpus / 4, 1)
|
||||||
|
}
|
||||||
|
};
|
||||||
let options = CTranslate2EngineOptionsBuilder::default()
|
let options = CTranslate2EngineOptionsBuilder::default()
|
||||||
.model_path(model_dir.ctranslate2_dir())
|
.model_path(model_dir.ctranslate2_dir())
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
.device(device)
|
.device(device)
|
||||||
.model_type(metadata.auto_model.clone())
|
.model_type(metadata.auto_model.clone())
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(num_replicas_per_device)
|
||||||
.compute_type(compute_type)
|
.compute_type(compute_type)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -131,9 +131,9 @@ pub struct ServeArgs {
|
||||||
#[clap(long, default_values_t=[0])]
|
#[clap(long, default_values_t=[0])]
|
||||||
device_indices: Vec<i32>,
|
device_indices: Vec<i32>,
|
||||||
|
|
||||||
/// Number of replicas per device, only applicable for CPU.
|
/// DEPRECATED: Do not use.
|
||||||
#[clap(long, default_value_t = 1)]
|
#[clap(long)]
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: Option<usize>,
|
||||||
|
|
||||||
/// Compute type
|
/// Compute type
|
||||||
#[clap(long, default_value_t=ComputeType::Auto)]
|
#[clap(long, default_value_t=ComputeType::Auto)]
|
||||||
|
|
@ -244,6 +244,10 @@ fn fallback() -> routing::MethodRouter {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn valid_args(args: &ServeArgs) {
|
fn valid_args(args: &ServeArgs) {
|
||||||
|
if args.num_replicas_per_device.is_some() {
|
||||||
|
warn!("num_replicas_per_device is deprecated and will be removed in future release.");
|
||||||
|
}
|
||||||
|
|
||||||
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
||||||
{
|
{
|
||||||
fatal!("CPU device only supports device indices = [0]");
|
fatal!("CPU device only supports device indices = [0]");
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue