parent
e125ab82fc
commit
a9f1829a52
|
|
@ -25,8 +25,6 @@ std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
rust::Str model_type,
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
rust::Str compute_type,
|
rust::Slice<const int32_t> device_indices
|
||||||
rust::Slice<const int32_t> device_indices,
|
|
||||||
size_t num_replicas_per_device
|
|
||||||
);
|
);
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -105,19 +105,24 @@ std::shared_ptr<TextInferenceEngine> create_engine(
|
||||||
rust::Str model_path,
|
rust::Str model_path,
|
||||||
rust::Str model_type,
|
rust::Str model_type,
|
||||||
rust::Str device,
|
rust::Str device,
|
||||||
rust::Str compute_type,
|
rust::Slice<const int32_t> device_indices
|
||||||
rust::Slice<const int32_t> device_indices,
|
|
||||||
size_t num_replicas_per_device
|
|
||||||
) {
|
) {
|
||||||
std::string model_type_str(model_type);
|
std::string model_type_str(model_type);
|
||||||
std::string model_path_str(model_path);
|
std::string model_path_str(model_path);
|
||||||
ctranslate2::models::ModelLoader loader(model_path_str);
|
ctranslate2::models::ModelLoader loader(model_path_str);
|
||||||
loader.device = ctranslate2::str_to_device(std::string(device));
|
loader.device = ctranslate2::str_to_device(std::string(device));
|
||||||
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
||||||
loader.num_replicas_per_device = num_replicas_per_device;
|
loader.compute_type = ctranslate2::ComputeType::AUTO;
|
||||||
|
|
||||||
std::string compute_type_str(compute_type);
|
const size_t num_cpus = std::thread::hardware_concurrency();
|
||||||
loader.compute_type = ctranslate2::str_to_compute_type(compute_type_str);
|
if (loader.device == ctranslate2::Device::CUDA) {
|
||||||
|
// When device is cuda, set parallelism to be number of thread.
|
||||||
|
loader.num_replicas_per_device = num_cpus;
|
||||||
|
} else if (loader.device == ctranslate2::Device::CPU){
|
||||||
|
// When device is cpu, adjust the number based on threads per replica.
|
||||||
|
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
|
||||||
|
loader.num_replicas_per_device = std::max<int32_t>(num_cpus / 4, 1);
|
||||||
|
}
|
||||||
|
|
||||||
if (model_type_str == "AutoModelForCausalLM") {
|
if (model_type_str == "AutoModelForCausalLM") {
|
||||||
return DecoderImpl::create(loader);
|
return DecoderImpl::create(loader);
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,7 @@ mod ffi {
|
||||||
model_path: &str,
|
model_path: &str,
|
||||||
model_type: &str,
|
model_type: &str,
|
||||||
device: &str,
|
device: &str,
|
||||||
compute_type: &str,
|
|
||||||
device_indices: &[i32],
|
device_indices: &[i32],
|
||||||
num_replicas_per_device: usize,
|
|
||||||
) -> SharedPtr<TextInferenceEngine>;
|
) -> SharedPtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn inference(
|
fn inference(
|
||||||
|
|
@ -65,10 +63,6 @@ pub struct CTranslate2EngineOptions {
|
||||||
device: String,
|
device: String,
|
||||||
|
|
||||||
device_indices: Vec<i32>,
|
device_indices: Vec<i32>,
|
||||||
|
|
||||||
num_replicas_per_device: usize,
|
|
||||||
|
|
||||||
compute_type: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct InferenceContext {
|
pub struct InferenceContext {
|
||||||
|
|
@ -103,9 +97,7 @@ impl CTranslate2Engine {
|
||||||
&options.model_path,
|
&options.model_path,
|
||||||
&options.model_type,
|
&options.model_type,
|
||||||
&options.device,
|
&options.device,
|
||||||
&options.compute_type,
|
|
||||||
&options.device_indices,
|
&options.device_indices,
|
||||||
options.num_replicas_per_device,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
return Self {
|
return Self {
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ use serde::Deserialize;
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
use tabby_inference::TextGeneration;
|
use tabby_inference::TextGeneration;
|
||||||
|
|
||||||
use super::Device;
|
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
pub fn create_engine(
|
pub fn create_engine(
|
||||||
|
|
@ -68,28 +67,12 @@ fn create_ctranslate2_engine(
|
||||||
metadata: &Metadata,
|
metadata: &Metadata,
|
||||||
) -> Box<dyn TextGeneration> {
|
) -> Box<dyn TextGeneration> {
|
||||||
let device = format!("{}", args.device);
|
let device = format!("{}", args.device);
|
||||||
let compute_type = format!("{}", args.compute_type);
|
|
||||||
let num_replicas_per_device = {
|
|
||||||
let num_cpus = std::thread::available_parallelism()
|
|
||||||
.expect("Failed to read # of cpu")
|
|
||||||
.get();
|
|
||||||
if args.device == Device::Cuda {
|
|
||||||
// When device is cuda, set parallelism to be number of thread.
|
|
||||||
num_cpus
|
|
||||||
} else {
|
|
||||||
// Otherwise, adjust the number based on threads per replica.
|
|
||||||
// https://github.com/OpenNMT/CTranslate2/blob/master/src/utils.cc#L77
|
|
||||||
std::cmp::max(num_cpus / 4, 1)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let options = CTranslate2EngineOptionsBuilder::default()
|
let options = CTranslate2EngineOptionsBuilder::default()
|
||||||
.model_path(model_dir.ctranslate2_dir())
|
.model_path(model_dir.ctranslate2_dir())
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
.device(device)
|
.device(device)
|
||||||
.model_type(metadata.auto_model.clone())
|
.model_type(metadata.auto_model.clone())
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(num_replicas_per_device)
|
|
||||||
.compute_type(compute_type)
|
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
Box::new(CTranslate2Engine::create(options))
|
Box::new(CTranslate2Engine::create(options))
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ pub struct HealthState {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
chat_model: Option<String>,
|
chat_model: Option<String>,
|
||||||
device: String,
|
device: String,
|
||||||
compute_type: String,
|
|
||||||
arch: String,
|
arch: String,
|
||||||
cpu_info: String,
|
cpu_info: String,
|
||||||
cpu_count: usize,
|
cpu_count: usize,
|
||||||
|
|
@ -34,7 +33,6 @@ impl HealthState {
|
||||||
model: args.model.clone(),
|
model: args.model.clone(),
|
||||||
chat_model: args.chat_model.clone(),
|
chat_model: args.chat_model.clone(),
|
||||||
device: args.device.to_string(),
|
device: args.device.to_string(),
|
||||||
compute_type: args.compute_type.to_string(),
|
|
||||||
arch: ARCH.to_string(),
|
arch: ARCH.to_string(),
|
||||||
cpu_info,
|
cpu_info,
|
||||||
cpu_count,
|
cpu_count,
|
||||||
|
|
|
||||||
|
|
@ -88,36 +88,6 @@ pub enum Device {
|
||||||
ExperimentalHttp,
|
ExperimentalHttp,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
|
||||||
#[clap(rename_all = "snake_case")]
|
|
||||||
pub enum ComputeType {
|
|
||||||
/// Use the fastest computation type that is supported on this system and device
|
|
||||||
#[strum(serialize = "auto")]
|
|
||||||
Auto,
|
|
||||||
|
|
||||||
/// Quantize model weights to use int8 for inference.
|
|
||||||
///
|
|
||||||
/// On CUDA devices, embedding / linear layers runs on int8, while other layers runs on
|
|
||||||
/// float32.
|
|
||||||
#[strum(serialize = "int8")]
|
|
||||||
Int8,
|
|
||||||
|
|
||||||
/// Use float16 for inference, only supported on CUDA devices.
|
|
||||||
#[strum(serialize = "float16")]
|
|
||||||
Float16,
|
|
||||||
|
|
||||||
/// Use int8 / float16 mixed precision for inference, only supported on CUDA devices.
|
|
||||||
///
|
|
||||||
/// This mode is the same as int8 for CUDA devices, but all non quantized layers are run in float16
|
|
||||||
/// instead of float32.
|
|
||||||
#[strum(serialize = "int8_float16")]
|
|
||||||
Int8Float16,
|
|
||||||
|
|
||||||
/// Use float32 for inference.
|
|
||||||
#[strum(serialize = "float32")]
|
|
||||||
Float32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct ServeArgs {
|
pub struct ServeArgs {
|
||||||
/// Model id for `/completions` API endpoint.
|
/// Model id for `/completions` API endpoint.
|
||||||
|
|
@ -140,12 +110,12 @@ pub struct ServeArgs {
|
||||||
device_indices: Vec<i32>,
|
device_indices: Vec<i32>,
|
||||||
|
|
||||||
/// DEPRECATED: Do not use.
|
/// DEPRECATED: Do not use.
|
||||||
#[clap(long)]
|
#[clap(long, hide(true))]
|
||||||
num_replicas_per_device: Option<usize>,
|
num_replicas_per_device: Option<usize>,
|
||||||
|
|
||||||
/// Compute type
|
/// DEPRECATED: Do not use.
|
||||||
#[clap(long, default_value_t=ComputeType::Auto)]
|
#[clap(long, hide(true))]
|
||||||
compute_type: ComputeType,
|
compute_type: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||||
|
|
@ -282,7 +252,7 @@ fn fallback() -> routing::MethodRouter {
|
||||||
|
|
||||||
fn valid_args(args: &ServeArgs) {
|
fn valid_args(args: &ServeArgs) {
|
||||||
if args.num_replicas_per_device.is_some() {
|
if args.num_replicas_per_device.is_some() {
|
||||||
warn!("num_replicas_per_device is deprecated and will be removed in future release.");
|
warn!("--num-replicas-per-device is deprecated and will be removed in future release.");
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
||||||
|
|
@ -290,11 +260,8 @@ fn valid_args(args: &ServeArgs) {
|
||||||
fatal!("CPU device only supports device indices = [0]");
|
fatal!("CPU device only supports device indices = [0]");
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 {
|
if args.compute_type.is_some() {
|
||||||
match args.compute_type {
|
warn!("--compute-type is deprecated and will be removed in future release.");
|
||||||
ComputeType::Auto | ComputeType::Int8 => {}
|
|
||||||
_ => fatal!("CPU device only supports int8 compute type"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue