feat: add --parallelism to control throughput and vram usage (#727)

* feat: add --parallelism to control throughput and vram usage

* update default

* Revert "update default"

This reverts commit 349792c0d48d913dcd8be4ce1c9d7ce887918f29.

* cargo fmt
refactor-extract-code
Meng Zhang 2023-11-08 10:31:22 -08:00 committed by GitHub
parent 3fb8445747
commit 8ab35b2639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 37 deletions

View File

@ -15,5 +15,5 @@ class TextInferenceEngine {
virtual rust::Vec<StepOutput> step() = 0;
};
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
} // namespace

View File

@ -14,17 +14,6 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}
namespace {
int get_parallelism() {
const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
if (parallelism) {
return std::stoi(parallelism);
} else {
return 4;
}
}
static size_t N_CONCURRENT_REQUESTS = get_parallelism();
constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history.
@ -95,10 +84,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
model_(std::move(model)),
ctx_(std::move(ctx)) {
batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
ctx_(std::move(ctx)),
parallelism_(parallelism) {
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
// warm up
{
batch_.n_tokens = 16;
@ -155,7 +145,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}
// Add pending requests.
while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
while (pending_requests_.size() > 0 && requests_.size() < parallelism_) {
Request request = std::move(pending_requests_.front());
pending_requests_.pop_front();
@ -283,6 +273,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::vector<Request> requests_;
std::deque<Request> pending_requests_;
std::unordered_set<uint32_t> stopped_requests_;
uint32_t parallelism_;
};
static int g_llama_cpp_log_level = 0;
@ -310,7 +302,7 @@ struct BackendInitializer {
} // namespace
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
static BackendInitializer initializer;
llama_model_params model_params = llama_model_default_params();
@ -322,13 +314,14 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
ctx_params.n_ctx = N_CTX * parallelism;
ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free)
owned<llama_context>(ctx, llama_free),
parallelism
);
}

View File

@ -23,7 +23,11 @@ mod ffi {
type TextInferenceEngine;
fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn create_engine(
use_gpu: bool,
model_path: &str,
parallelism: u8,
) -> UniquePtr<TextInferenceEngine>;
fn add_request(
self: Pin<&mut TextInferenceEngine>,
@ -43,6 +47,7 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
pub struct LlamaTextGenerationOptions {
model_path: String,
use_gpu: bool,
parallelism: u8,
}
pub struct LlamaTextGeneration {
@ -52,7 +57,7 @@ pub struct LlamaTextGeneration {
impl LlamaTextGeneration {
pub fn new(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path);
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path);
}

View File

@ -14,8 +14,11 @@ pub async fn create_engine(
if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine =
create_ggml_engine(&args.device, model_path.display().to_string().as_str());
let engine = create_ggml_engine(
&args.device,
model_path.display().to_string().as_str(),
args.parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info)
} else {
@ -23,7 +26,7 @@ pub async fn create_engine(
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
(
engine,
EngineInfo {
@ -57,10 +60,15 @@ impl EngineInfo {
}
}
fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
fn create_ggml_engine(
device: &super::Device,
model_path: &str,
parallelism: u8,
) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu())
.parallelism(parallelism)
.build()
.unwrap();

View File

@ -121,15 +121,13 @@ pub struct ServeArgs {
#[clap(long, default_value_t=Device::Cpu)]
device: Device,
/// DEPRECATED: Do not use.
#[deprecated(since = "0.5.0")]
#[clap(long, hide(true))]
device_indices: Vec<i32>,
/// Parallelism for model serving - increasing this number will have a significant impact on the
/// memory requirement e.g., GPU vRAM.
#[clap(long, default_value_t = 1)]
parallelism: u8,
}
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
if args.device != Device::ExperimentalHttp {
if fs::metadata(&args.model).is_ok() {
info!("Loading model from local path {}", &args.model);
@ -252,12 +250,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
.layer(opentelemetry_tracing_layer())
}
fn valid_args(args: &ServeArgs) {
if !args.device_indices.is_empty() {
warn!("--device-indices is deprecated and will be removed in future release.");
}
}
fn start_heartbeat(args: &ServeArgs) {
let state = HealthState::new(args);
tokio::spawn(async move {