feat: support setting compute_type through command-line arguments
parent
ba7e04d030
commit
9c9e46c6f4
|
|
@ -25,6 +25,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
rust::Str model_type,
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
|
rust::Str compute_type,
|
||||||
rust::Slice<const int32_t> device_indices,
|
rust::Slice<const int32_t> device_indices,
|
||||||
size_t num_replicas_per_device
|
size_t num_replicas_per_device
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -105,6 +105,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
rust::Str model_type,
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
|
rust::Str compute_type,
|
||||||
rust::Slice<const int32_t> device_indices,
|
rust::Slice<const int32_t> device_indices,
|
||||||
size_t num_replicas_per_device
|
size_t num_replicas_per_device
|
||||||
) {
|
) {
|
||||||
|
|
@ -115,10 +116,15 @@ std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
||||||
loader.num_replicas_per_device = num_replicas_per_device;
|
loader.num_replicas_per_device = num_replicas_per_device;
|
||||||
|
|
||||||
if (loader.device == ctranslate2::Device::CPU) {
|
std::string compute_type_str(compute_type);
|
||||||
loader.compute_type = ctranslate2::ComputeType::INT8;
|
if (compute_type_str == "auto") {
|
||||||
} else if (loader.device == ctranslate2::Device::CUDA) {
|
if (loader.device == ctranslate2::Device::CPU) {
|
||||||
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
|
loader.compute_type = ctranslate2::ComputeType::INT8;
|
||||||
|
} else if (loader.device == ctranslate2::Device::CUDA) {
|
||||||
|
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_type_str == "AutoModelForCausalLM") {
|
if (model_type_str == "AutoModelForCausalLM") {
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ mod ffi {
|
||||||
model_path: &str,
|
model_path: &str,
|
||||||
model_type: &str,
|
model_type: &str,
|
||||||
device: &str,
|
device: &str,
|
||||||
|
compute_type: &str,
|
||||||
device_indices: &[i32],
|
device_indices: &[i32],
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: usize,
|
||||||
) -> SharedPtr<TextInferenceEngine>;
|
) -> SharedPtr<TextInferenceEngine>;
|
||||||
|
|
@ -60,6 +61,8 @@ pub struct TextInferenceEngineCreateOptions {
|
||||||
device_indices: Vec<i32>,
|
device_indices: Vec<i32>,
|
||||||
|
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: usize,
|
||||||
|
|
||||||
|
compute_type: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Builder, Debug)]
|
#[derive(Builder, Debug)]
|
||||||
|
|
@ -101,6 +104,7 @@ impl TextInferenceEngine {
|
||||||
&options.model_path,
|
&options.model_path,
|
||||||
&options.model_type,
|
&options.model_type,
|
||||||
&options.device,
|
&options.device,
|
||||||
|
&options.compute_type,
|
||||||
&options.device_indices,
|
&options.device_indices,
|
||||||
options.num_replicas_per_device,
|
options.num_replicas_per_device,
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -127,6 +127,7 @@ impl CompletionState {
|
||||||
let metadata = read_metadata(&model_dir);
|
let metadata = read_metadata(&model_dir);
|
||||||
|
|
||||||
let device = format!("{}", args.device);
|
let device = format!("{}", args.device);
|
||||||
|
let compute_type = format!("{}", args.compute_type);
|
||||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||||
.model_path(model_dir.ctranslate2_dir())
|
.model_path(model_dir.ctranslate2_dir())
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
|
|
@ -134,6 +135,7 @@ impl CompletionState {
|
||||||
.model_type(metadata.auto_model)
|
.model_type(metadata.auto_model)
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
|
.compute_type(compute_type)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let engine = TextInferenceEngine::create(options);
|
let engine = TextInferenceEngine::create(options);
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,35 @@ pub enum Device {
|
||||||
Cuda,
|
Cuda,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||||
|
#[clap(rename_all = "snake_case")]
|
||||||
|
pub enum ComputeType {
|
||||||
|
/// Set quantization automatically based on device:
|
||||||
|
///
|
||||||
|
/// CPU: Int8
|
||||||
|
/// CUDA: Int8Float32
|
||||||
|
#[strum(serialize = "auto")]
|
||||||
|
Auto,
|
||||||
|
|
||||||
|
/// Quantize model weights to use int8 for inference.
|
||||||
|
///
|
||||||
|
/// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
|
||||||
|
/// float32.
|
||||||
|
#[strum(serialize = "cpu")]
|
||||||
|
Int8,
|
||||||
|
|
||||||
|
/// Use float16 for inference, only supported on CUDA devices.
|
||||||
|
#[strum(serialize = "float16")]
|
||||||
|
Float16,
|
||||||
|
|
||||||
|
/// Use int8 / float16 mixed precision for inference, only supported on CUDA devices.
|
||||||
|
///
|
||||||
|
/// This mode is the same as int8 for CUDA devices, but all non quantized layers are run in float16
|
||||||
|
/// instead of float32.
|
||||||
|
#[strum(serialize = "int8_float16")]
|
||||||
|
Int8Float16,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct ServeArgs {
|
pub struct ServeArgs {
|
||||||
/// Model id for serving.
|
/// Model id for serving.
|
||||||
|
|
@ -69,6 +98,10 @@ pub struct ServeArgs {
|
||||||
/// Number of replicas per device, only applicable for CPU.
|
/// Number of replicas per device, only applicable for CPU.
|
||||||
#[clap(long, default_value_t = 1)]
|
#[clap(long, default_value_t = 1)]
|
||||||
num_replicas_per_device: usize,
|
num_replicas_per_device: usize,
|
||||||
|
|
||||||
|
/// Compute type
|
||||||
|
#[clap(long, default_value_t=ComputeType::Auto)]
|
||||||
|
compute_type: ComputeType,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &ServeArgs) {
|
pub async fn main(args: &ServeArgs) {
|
||||||
|
|
@ -124,6 +157,10 @@ fn valid_args(args: &ServeArgs) {
|
||||||
{
|
{
|
||||||
fatal!("CPU device only supports device indices = [0]");
|
fatal!("CPU device only supports device indices = [0]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 {
|
||||||
|
fatal!("CPU device only supports int8 compute type");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue