feat: support setting compute_type through command-line arguments

improve-workflow
Meng Zhang 2023-06-13 12:04:07 -07:00
parent ba7e04d030
commit 9c9e46c6f4
5 changed files with 54 additions and 4 deletions

View File

@ -25,6 +25,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path, rust::Str model_path,
rust::Str model_type, rust::Str model_type,
rust::Str device, rust::Str device,
rust::Str compute_type,
rust::Slice<const int32_t> device_indices, rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device size_t num_replicas_per_device
); );

View File

@ -105,6 +105,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path, rust::Str model_path,
rust::Str model_type, rust::Str model_type,
rust::Str device, rust::Str device,
rust::Str compute_type,
rust::Slice<const int32_t> device_indices, rust::Slice<const int32_t> device_indices,
size_t num_replicas_per_device size_t num_replicas_per_device
) { ) {
@ -115,10 +116,15 @@ std::shared_ptr<TextInferenceEngine> create_engine(
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end()); loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
loader.num_replicas_per_device = num_replicas_per_device; loader.num_replicas_per_device = num_replicas_per_device;
if (loader.device == ctranslate2::Device::CPU) { std::string compute_type_str(compute_type);
loader.compute_type = ctranslate2::ComputeType::INT8; if (compute_type_str == "auto") {
} else if (loader.device == ctranslate2::Device::CUDA) { if (loader.device == ctranslate2::Device::CPU) {
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16; loader.compute_type = ctranslate2::ComputeType::INT8;
} else if (loader.device == ctranslate2::Device::CUDA) {
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
}
} else {
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
} }
if (model_type_str == "AutoModelForCausalLM") { if (model_type_str == "AutoModelForCausalLM") {

View File

@ -21,6 +21,7 @@ mod ffi {
model_path: &str, model_path: &str,
model_type: &str, model_type: &str,
device: &str, device: &str,
compute_type: &str,
device_indices: &[i32], device_indices: &[i32],
num_replicas_per_device: usize, num_replicas_per_device: usize,
) -> SharedPtr<TextInferenceEngine>; ) -> SharedPtr<TextInferenceEngine>;
@ -60,6 +61,8 @@ pub struct TextInferenceEngineCreateOptions {
device_indices: Vec<i32>, device_indices: Vec<i32>,
num_replicas_per_device: usize, num_replicas_per_device: usize,
compute_type: String,
} }
#[derive(Builder, Debug)] #[derive(Builder, Debug)]
@ -101,6 +104,7 @@ impl TextInferenceEngine {
&options.model_path, &options.model_path,
&options.model_type, &options.model_type,
&options.device, &options.device,
&options.compute_type,
&options.device_indices, &options.device_indices,
options.num_replicas_per_device, options.num_replicas_per_device,
); );

View File

@ -127,6 +127,7 @@ impl CompletionState {
let metadata = read_metadata(&model_dir); let metadata = read_metadata(&model_dir);
let device = format!("{}", args.device); let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type);
let options = TextInferenceEngineCreateOptionsBuilder::default() let options = TextInferenceEngineCreateOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir()) .model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file()) .tokenizer_path(model_dir.tokenizer_file())
@ -134,6 +135,7 @@ impl CompletionState {
.model_type(metadata.auto_model) .model_type(metadata.auto_model)
.device_indices(args.device_indices.clone()) .device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device) .num_replicas_per_device(args.num_replicas_per_device)
.compute_type(compute_type)
.build() .build()
.unwrap(); .unwrap();
let engine = TextInferenceEngine::create(options); let engine = TextInferenceEngine::create(options);

View File

@ -49,6 +49,35 @@ pub enum Device {
Cuda, Cuda,
} }
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
#[clap(rename_all = "snake_case")]
pub enum ComputeType {
    /// Set quantization automatically based on device:
    ///
    /// CPU: Int8
    /// CUDA: Int8Float16
    #[strum(serialize = "auto")]
    Auto,
    /// Quantize model weights to use int8 for inference.
    ///
    /// On CUDA devices, embedding / linear layers run on int8, while other layers run on
    /// float32.
    // FIX: was `serialize = "cpu"` — strum Display is what gets forwarded to
    // ctranslate2::str_to_compute_type on the C++ side, and "cpu" is not a valid
    // compute type there; it must round-trip as "int8".
    #[strum(serialize = "int8")]
    Int8,
    /// Use float16 for inference, only supported on CUDA devices.
    #[strum(serialize = "float16")]
    Float16,
    /// Use int8 / float16 mixed precision for inference, only supported on CUDA devices.
    ///
    /// This mode is the same as int8 for CUDA devices, but all non quantized layers are run in
    /// float16 instead of float32.
    #[strum(serialize = "int8_float16")]
    Int8Float16,
}
#[derive(Args)] #[derive(Args)]
pub struct ServeArgs { pub struct ServeArgs {
/// Model id for serving. /// Model id for serving.
@ -69,6 +98,10 @@ pub struct ServeArgs {
/// Number of replicas per device, only applicable for CPU. /// Number of replicas per device, only applicable for CPU.
#[clap(long, default_value_t = 1)] #[clap(long, default_value_t = 1)]
num_replicas_per_device: usize, num_replicas_per_device: usize,
/// Compute type
#[clap(long, default_value_t=ComputeType::Auto)]
compute_type: ComputeType,
} }
pub async fn main(args: &ServeArgs) { pub async fn main(args: &ServeArgs) {
@ -124,6 +157,10 @@ fn valid_args(args: &ServeArgs) {
{ {
fatal!("CPU device only supports device indices = [0]"); fatal!("CPU device only supports device indices = [0]");
} }
if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 {
fatal!("CPU device only supports int8 compute type");
}
} }
#[utoipa::path( #[utoipa::path(