From be900474775da992a2891f465bd0d1db3cea6ac1 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Wed, 12 Jul 2023 11:09:38 +0800
Subject: [PATCH] fix: fix int8 compute type, fix auto compute type selection
 (include float32 into consideration for cuda compute capability <= 6.0)
 (#291)

---
 crates/ctranslate2-bindings/src/ctranslate2.cc | 10 +---------
 crates/tabby/src/serve/mod.rs                  | 11 ++++++-----
 2 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index ab633ce..ca11612 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -117,15 +117,7 @@ std::shared_ptr create_engine(
   loader.num_replicas_per_device = num_replicas_per_device;
 
   std::string compute_type_str(compute_type);
-  if (compute_type_str == "auto") {
-    if (loader.device == ctranslate2::Device::CPU) {
-      loader.compute_type = ctranslate2::ComputeType::INT8;
-    } else if (loader.device == ctranslate2::Device::CUDA) {
-      loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
-    }
-  } else {
-    loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
-  }
+  loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
 
   if (model_type_str == "AutoModelForCausalLM") {
     return DecoderImpl::create(loader);
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 79312b9..4b8d67f 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -54,10 +54,7 @@ pub enum Device {
 #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
 #[clap(rename_all = "snake_case")]
 pub enum ComputeType {
-    /// Set quantization automatically based on device:
-    ///
-    /// CPU: Int8
-    /// CUDA: Int8Float32
+    /// Use the fastest computation type that is supported on this system and device
     #[strum(serialize = "auto")]
     Auto,
 
@@ -65,7 +62,7 @@ pub enum ComputeType {
     /// Use int8 for inference.
     ///
     /// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
     /// float32.
-    #[strum(serialize = "cpu")]
+    #[strum(serialize = "int8")]
     Int8,
 
     /// Use float16 for inference, only supported on CUDA devices.
@@ -78,6 +75,10 @@ pub enum ComputeType {
     /// instead of float32.
     #[strum(serialize = "int8_float16")]
     Int8Float16,
+
+    /// Use float32 for inference.
+    #[strum(serialize = "float32")]
+    Float32,
 }
 
 #[derive(Args)]