fix: fix int8 compute type, fix auto compute type selection (take float32 into consideration for CUDA compute capability <= 6.0) (#291)

sweep/improve-logging-information
Meng Zhang 2023-07-12 11:09:38 +08:00 committed by GitHub
parent 2ad0b69786
commit be90047477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 14 deletions

View File

@ -117,15 +117,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(
loader.num_replicas_per_device = num_replicas_per_device;
std::string compute_type_str(compute_type);
if (compute_type_str == "auto") {
if (loader.device == ctranslate2::Device::CPU) {
loader.compute_type = ctranslate2::ComputeType::INT8;
} else if (loader.device == ctranslate2::Device::CUDA) {
loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
}
} else {
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
}
if (model_type_str == "AutoModelForCausalLM") {
return DecoderImpl::create(loader);

View File

@ -54,10 +54,7 @@ pub enum Device {
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
#[clap(rename_all = "snake_case")]
pub enum ComputeType {
/// Set quantization automatically based on device:
///
/// CPU: Int8
/// CUDA: Int8Float32
/// Use the fastest computation type that is supported on this system and device
#[strum(serialize = "auto")]
Auto,
@ -65,7 +62,7 @@ pub enum ComputeType {
///
/// On CUDA devices, embedding / linear layers run on int8, while other layers run on
/// float32.
#[strum(serialize = "cpu")]
#[strum(serialize = "int8")]
Int8,
/// Use float16 for inference, only supported on CUDA devices.
@ -78,6 +75,10 @@ pub enum ComputeType {
/// instead of float32.
#[strum(serialize = "int8_float16")]
Int8Float16,
/// Use float32 for inference.
#[strum(serialize = "float32")]
Float32,
}
#[derive(Args)]