fix: fix int8 compute type, fix auto compute type selection (take float32 into consideration for cuda compute capability <= 6.0) (#291)

sweep/improve-logging-information
Meng Zhang 2023-07-12 11:09:38 +08:00 committed by GitHub
parent 2ad0b69786
commit be90047477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 14 deletions

View File

@ -117,15 +117,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(
loader.num_replicas_per_device = num_replicas_per_device; loader.num_replicas_per_device = num_replicas_per_device;
std::string compute_type_str(compute_type); std::string compute_type_str(compute_type);
if (compute_type_str == "auto") { loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
if (loader.device == ctranslate2::Device::CPU) {
loader.compute_type = ctranslate2::ComputeType::INT8;
} else if (loader.device == ctranslate2::Device::CUDA) {
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
}
} else {
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
}
if (model_type_str == "AutoModelForCausalLM") { if (model_type_str == "AutoModelForCausalLM") {
return DecoderImpl::create(loader); return DecoderImpl::create(loader);

View File

@ -54,10 +54,7 @@ pub enum Device {
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
#[clap(rename_all = "snake_case")] #[clap(rename_all = "snake_case")]
pub enum ComputeType { pub enum ComputeType {
/// Set quantization automatically based on device: /// Use the fastest computation type that is supported on this system and device
///
/// CPU: Int8
/// CUDA: Int8Float32
#[strum(serialize = "auto")] #[strum(serialize = "auto")]
Auto, Auto,
@ -65,7 +62,7 @@ pub enum ComputeType {
/// ///
/// On CUDA devices, embedding / linear layers run on int8, while other layers run on /// On CUDA devices, embedding / linear layers run on int8, while other layers run on
/// float32. /// float32.
#[strum(serialize = "cpu")] #[strum(serialize = "int8")]
Int8, Int8,
/// Use float16 for inference, only supported on CUDA devices. /// Use float16 for inference, only supported on CUDA devices.
@ -78,6 +75,10 @@ pub enum ComputeType {
/// instead of float32. /// instead of float32.
#[strum(serialize = "int8_float16")] #[strum(serialize = "int8_float16")]
Int8Float16, Int8Float16,
/// Use float32 for inference.
#[strum(serialize = "float32")]
Float32,
} }
#[derive(Args)] #[derive(Args)]