Revert "Revert "refactor: deprecate --compute-type (#505)""

This reverts commit aa6f39985c.
r0.3
Meng Zhang 2023-10-20 00:36:17 -07:00
parent e125ab82fc
commit a9f1829a52
6 changed files with 19 additions and 76 deletions

View File

@@ -25,8 +25,6 @@ std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str model_type,
rust::Str device,
rust::Str compute_type,
rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device
rust::Slice<const int32_t> device_indices
);
} // namespace

View File

@@ -105,19 +105,24 @@ std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str model_type,
rust::Str device,
rust::Str compute_type,
rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device
rust::Slice<const int32_t> device_indices
) {
std::string model_type_str(model_type);
std::string model_path_str(model_path);
ctranslate2::models::ModelLoader loader(model_path_str);
loader.device = ctranslate2::str_to_device(std::string(device));
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
loader.num_replicas_per_device = num_replicas_per_device;
loader.compute_type = ctranslate2::ComputeType::AUTO;
std::string compute_type_str(compute_type);
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
const size_t num_cpus = std::thread::hardware_concurrency();
if (loader.device == ctranslate2::Device::CUDA) {
// When device is cuda, set parallelism to be number of thread.
loader.num_replicas_per_device = num_cpus;
} else if (loader.device == ctranslate2::Device::CPU){
// When device is cpu, adjust the number based on threads per replica.
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
loader.num_replicas_per_device = std::max<int32_t>(num_cpus / 4, 1);
}
if (model_type_str == "AutoModelForCausalLM") {
return DecoderImpl::create(loader);

View File

@@ -27,9 +27,7 @@ mod ffi {
model_path: &str,
model_type: &str,
device: &str,
compute_type: &str,
device_indices: &[i32],
num_replicas_per_device: usize,
) -> SharedPtr<TextInferenceEngine>;
fn inference(
@@ -65,10 +63,6 @@ pub struct CTranslate2EngineOptions {
device: String,
device_indices: Vec<i32>,
num_replicas_per_device: usize,
compute_type: String,
}
pub struct InferenceContext {
@@ -103,9 +97,7 @@ impl CTranslate2Engine {
&options.model_path,
&options.model_type,
&options.device,
&options.compute_type,
&options.device_indices,
options.num_replicas_per_device,
);
return Self {

View File

@@ -5,7 +5,6 @@ use serde::Deserialize;
use tabby_common::path::ModelDir;
use tabby_inference::TextGeneration;
use super::Device;
use crate::fatal;
pub fn create_engine(
@@ -68,28 +67,12 @@ fn create_ctranslate2_engine(
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type);
let num_replicas_per_device = {
let num_cpus = std::thread::available_parallelism()
.expect("Failed to read # of cpu")
.get();
if args.device == Device::Cuda {
// When device is cuda, set parallelism to be number of thread.
num_cpus
} else {
// Otherwise, adjust the number based on threads per replica.
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
std::cmp::max(num_cpus / 4, 1)
}
};
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.auto_model.clone())
.device_indices(args.device_indices.clone())
.num_replicas_per_device(num_replicas_per_device)
.compute_type(compute_type)
.build()
.unwrap();
Box::new(CTranslate2Engine::create(options))

View File

@@ -13,7 +13,6 @@ pub struct HealthState {
#[serde(skip_serializing_if = "Option::is_none")]
chat_model: Option<String>,
device: String,
compute_type: String,
arch: String,
cpu_info: String,
cpu_count: usize,
@@ -34,7 +33,6 @@ impl HealthState {
model: args.model.clone(),
chat_model: args.chat_model.clone(),
device: args.device.to_string(),
compute_type: args.compute_type.to_string(),
arch: ARCH.to_string(),
cpu_info,
cpu_count,

View File

@@ -88,36 +88,6 @@ pub enum Device {
ExperimentalHttp,
}
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
#[clap(rename_all = "snake_case")]
pub enum ComputeType {
/// Use the fastest computation type that is supported on this system and device
#[strum(serialize = "auto")]
Auto,
/// Quantize model weights to use int8 for inference.
///
/// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
/// float32.
#[strum(serialize = "int8")]
Int8,
/// Use float16 for inference, only supported on CUDA devices.
#[strum(serialize = "float16")]
Float16,
/// Use int8 / float16 mixed precision for inference, only supported on CUDA devices.
///
/// This mode is the same as int8 for CUDA devices, but all non quantized layers are run in float16
/// instead of float32.
#[strum(serialize = "int8_float16")]
Int8Float16,
/// Use float32 for inference.
#[strum(serialize = "float32")]
Float32,
}
#[derive(Args)]
pub struct ServeArgs {
/// Model id for `/completions` API endpoint.
@@ -140,12 +110,12 @@ pub struct ServeArgs {
device_indices: Vec<i32>,
/// DEPRECATED: Do not use.
#[clap(long)]
#[clap(long, hide(true))]
num_replicas_per_device: Option<usize>,
/// Compute type
#[clap(long, default_value_t=ComputeType::Auto)]
compute_type: ComputeType,
/// DEPRECATED: Do not use.
#[clap(long, hide(true))]
compute_type: Option<String>,
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
@@ -282,7 +252,7 @@ fn fallback() -> routing::MethodRouter {
fn valid_args(args: &ServeArgs) {
if args.num_replicas_per_device.is_some() {
warn!("num_replicas_per_device is deprecated and will be removed in future release.");
warn!("--num-replicas-per-device is deprecated and will be removed in future release.");
}
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
@@ -290,11 +260,8 @@ fn valid_args(args: &ServeArgs) {
fatal!("CPU device only supports device indices = [0]");
}
if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 {
match args.compute_type {
ComputeType::Auto | ComputeType::Int8 => {}
_ => fatal!("CPU device only supports int8 compute type"),
}
if args.compute_type.is_some() {
warn!("--compute-type is deprecated and will be removed in future release.");
}
}