feat: add --parallelism to control throughput and vram usage (#727)
* feat: add --parallelism to control throughput and vram usage
* update default
* Revert "update default"
  This reverts commit 349792c0d48d913dcd8be4ce1c9d7ce887918f29.
* cargo fmt
parent 3fb8445747
commit 8ab35b2639
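In short: the llama.cpp backend's request concurrency is now set from the CLI (e.g. "tabby serve --model <model> --parallelism 4") instead of the LLAMA_CPP_PARALLELISM environment variable. The value is threaded from ServeArgs through create_ggml_engine and the cxx bridge into the engine, which sizes its context and batch buffers as N_CTX * parallelism. The hunks below follow that path from the C++ engine header down to the serve arguments.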
@@ -15,5 +15,5 @@ class TextInferenceEngine {
   virtual rust::Vec<StepOutput> step() = 0;
 };
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
+std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism);
 }  // namespace
@@ -14,17 +14,6 @@ namespace llama {
 TextInferenceEngine::~TextInferenceEngine() {}
 
 namespace {
-int get_parallelism() {
-  const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
-  if (parallelism) {
-    return std::stoi(parallelism);
-  } else {
-    return 4;
-  }
-}
-
-static size_t N_CONCURRENT_REQUESTS = get_parallelism();
-
 constexpr size_t N_BATCH = 512;  // # per batch inference.
 constexpr size_t N_CTX = 4096;   // # max kv history.
 
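The removed helper read LLAMA_CPP_PARALLELISM and fell back to 4; the replacement CLI flag (see the ServeArgs hunk further down) defaults to 1, so the out-of-the-box KV-cache footprint shrinks. For anyone who scripted around the old environment variable, a rough Rust counterpart of the removed lookup (a hypothetical helper, not part of this commit) is:

use std::env;

// Rough Rust counterpart of the removed C++ get_parallelism():
// read LLAMA_CPP_PARALLELISM, fall back to 4 when unset or unparsable.
fn legacy_parallelism() -> u8 {
    env::var("LLAMA_CPP_PARALLELISM")
        .ok()
        .and_then(|v| v.parse::<u8>().ok())
        .unwrap_or(4)
}

fn main() {
    println!("legacy parallelism: {}", legacy_parallelism());
}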
@@ -95,10 +84,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
 
 class TextInferenceEngineImpl : public TextInferenceEngine {
  public:
-  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
+  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
     model_(std::move(model)),
-    ctx_(std::move(ctx)) {
-    batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
+    ctx_(std::move(ctx)),
+    parallelism_(parallelism) {
+    batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
     // warm up
     {
       batch_.n_tokens = 16;
@@ -155,7 +145,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     }
 
     // Add pending requests.
-    while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
+    while (pending_requests_.size() > 0 && requests_.size() < parallelism_) {
       Request request = std::move(pending_requests_.front());
       pending_requests_.pop_front();
 
@@ -283,6 +273,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   std::vector<Request> requests_;
   std::deque<Request> pending_requests_;
   std::unordered_set<uint32_t> stopped_requests_;
+
+  uint32_t parallelism_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -310,7 +302,7 @@ struct BackendInitializer {
 
 }  // namespace
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
+std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
   static BackendInitializer initializer;
 
   llama_model_params model_params = llama_model_default_params();
@@ -322,13 +314,14 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
   }
 
   llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
+  ctx_params.n_ctx = N_CTX * parallelism;
   ctx_params.n_batch = N_BATCH;
   llama_context* ctx = llama_new_context_with_model(model, ctx_params);
 
   return std::make_unique<TextInferenceEngineImpl>(
     owned<llama_model>(model, llama_free_model),
-    owned<llama_context>(ctx, llama_free)
+    owned<llama_context>(ctx, llama_free),
+    parallelism
   );
 }
 
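Note that the context handed to llama.cpp grows linearly with the flag, which is where the vRAM impact mentioned in the flag's doc comment comes from. A quick sanity check of the sizing above, using N_CTX = 4096 from the diff (the actual byte cost depends on the model's KV-cache layout and precision, so only token counts are computed here):

// Sanity check of the buffer sizing in create_engine / TextInferenceEngineImpl:
// total KV history scales linearly with --parallelism.
const N_CTX: usize = 4096; // per-request max KV history, from the diff

fn main() {
    for parallelism in [1u8, 2, 4] {
        let n_ctx_total = N_CTX * parallelism as usize; // ctx_params.n_ctx
        println!("--parallelism {parallelism}: n_ctx = {n_ctx_total} tokens");
    }
}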
@@ -23,7 +23,11 @@ mod ffi {
 
         type TextInferenceEngine;
 
-        fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
+        fn create_engine(
+            use_gpu: bool,
+            model_path: &str,
+            parallelism: u8,
+        ) -> UniquePtr<TextInferenceEngine>;
 
         fn add_request(
             self: Pin<&mut TextInferenceEngine>,
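In the cxx bridge a Rust u8 maps to uint8_t on the C++ side, so this signature lines up with the new create_engine declaration in the C++ hunks above; widening it later (say to u16/uint16_t) would have to happen on both sides of the bridge at once.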
@@ -43,6 +47,7 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
 pub struct LlamaTextGenerationOptions {
     model_path: String,
     use_gpu: bool,
+    parallelism: u8,
 }
 
 pub struct LlamaTextGeneration {
@@ -52,7 +57,7 @@ pub struct LlamaTextGeneration {
 
 impl LlamaTextGeneration {
     pub fn new(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path);
+        let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
         if engine.is_null() {
            fatal!("Unable to load model: {}", options.model_path);
        }
@@ -14,8 +14,11 @@ pub async fn create_engine(
     if fs::metadata(model_id).is_ok() {
         let path = PathBuf::from(model_id);
         let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        let engine =
-            create_ggml_engine(&args.device, model_path.display().to_string().as_str());
+        let engine = create_ggml_engine(
+            &args.device,
+            model_path.display().to_string().as_str(),
+            args.parallelism,
+        );
         let engine_info = EngineInfo::read(path.join("tabby.json"));
         (engine, engine_info)
     } else {
@@ -23,7 +26,7 @@ pub async fn create_engine(
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(&args.device, &model_path);
+        let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
         (
             engine,
             EngineInfo {
@@ -57,10 +60,15 @@ impl EngineInfo {
     }
 }
 
-fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
+fn create_ggml_engine(
+    device: &super::Device,
+    model_path: &str,
+    parallelism: u8,
+) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
         .use_gpu(device.ggml_use_gpu())
+        .parallelism(parallelism)
         .build()
         .unwrap();
 
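For reference, constructing the generator directly with the new option looks roughly like this (a sketch using only calls visible in this diff; the model path and values are placeholders):

// Hypothetical direct construction with the new option; path and values are placeholders.
fn make_engine() -> llama_cpp_bindings::LlamaTextGeneration {
    let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
        .model_path("/path/to/ggml/model".to_owned())
        .use_gpu(false)
        .parallelism(2) // serve up to two requests concurrently
        .build()
        .unwrap();
    llama_cpp_bindings::LlamaTextGeneration::new(options)
}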
@@ -121,15 +121,13 @@ pub struct ServeArgs {
     #[clap(long, default_value_t=Device::Cpu)]
     device: Device,
 
-    /// DEPRECATED: Do not use.
-    #[deprecated(since = "0.5.0")]
-    #[clap(long, hide(true))]
-    device_indices: Vec<i32>,
+    /// Parallelism for model serving - increasing this number will have a significant impact on the
+    /// memory requirement e.g., GPU vRAM.
+    #[clap(long, default_value_t = 1)]
+    parallelism: u8,
 }
 
 pub async fn main(config: &Config, args: &ServeArgs) {
-    valid_args(args);
-
     if args.device != Device::ExperimentalHttp {
         if fs::metadata(&args.model).is_ok() {
             info!("Loading model from local path {}", &args.model);
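The flag lands on ServeArgs with a conservative default of 1. A minimal, self-contained sketch of how a clap flag with this shape parses (a stand-in struct, not the real ServeArgs):

use clap::Parser; // assumes clap with the derive feature

// Minimal stand-in for ServeArgs, mirroring the attribute added above.
#[derive(Parser)]
struct Args {
    /// Parallelism for model serving.
    #[clap(long, default_value_t = 1)]
    parallelism: u8,
}

fn main() {
    let args = Args::parse_from(["tabby", "--parallelism", "4"]);
    assert_eq!(args.parallelism, 4);

    let defaults = Args::parse_from(["tabby"]);
    assert_eq!(defaults.parallelism, 1);
}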
@@ -252,12 +250,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
         .layer(opentelemetry_tracing_layer())
 }
 
-fn valid_args(args: &ServeArgs) {
-    if !args.device_indices.is_empty() {
-        warn!("--device-indices is deprecated and will be removed in future release.");
-    }
-}
-
 fn start_heartbeat(args: &ServeArgs) {
     let state = HealthState::new(args);
     tokio::spawn(async move {
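Alongside the new flag, the deprecated --device-indices argument and the valid_args() check that warned about it are removed.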