diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc
index ab633ce..ca11612 100644
--- a/crates/ctranslate2-bindings/src/ctranslate2.cc
+++ b/crates/ctranslate2-bindings/src/ctranslate2.cc
@@ -117,15 +117,7 @@ std::shared_ptr create_engine(
   loader.num_replicas_per_device = num_replicas_per_device;
 
   std::string compute_type_str(compute_type);
-  if (compute_type_str == "auto") {
-    if (loader.device == ctranslate2::Device::CPU) {
-      loader.compute_type = ctranslate2::ComputeType::INT8;
-    } else if (loader.device == ctranslate2::Device::CUDA) {
-      loader.compute_type = ctranslate2::ComputeType::INT8_FLOAT16;
-    }
-  } else {
-    loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
-  }
+  loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
 
   if (model_type_str == "AutoModelForCausalLM") {
     return DecoderImpl::create(loader);
diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs
index 79312b9..4b8d67f 100644
--- a/crates/tabby/src/serve/mod.rs
+++ b/crates/tabby/src/serve/mod.rs
@@ -54,10 +54,7 @@ pub enum Device {
 #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
 #[clap(rename_all = "snake_case")]
 pub enum ComputeType {
-    /// Set quantization automatically based on device:
-    ///
-    /// CPU: Int8
-    /// CUDA: Int8Float32
+    /// Use the fastest computation type that is supported on this system and device
     #[strum(serialize = "auto")]
     Auto,
 
@@ -65,7 +62,7 @@
     ///
     /// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
     /// float32.
-    #[strum(serialize = "cpu")]
+    #[strum(serialize = "int8")]
     Int8,
 
     /// Use float16 for inference, only supported on CUDA devices.
@@ -78,6 +75,10 @@
     /// instead of float32.
     #[strum(serialize = "int8_float16")]
     Int8Float16,
+
+    /// Use float32 for inference.
+    #[strum(serialize = "float32")]
+    Float32,
 }
 
 #[derive(Args)]