feat: add --parallelism to control throughput and vram usage (#727)

* feat: add --parallelism to control throughput and vram usage

* update default

* Revert "update default"

This reverts commit 349792c0d48d913dcd8be4ce1c9d7ce887918f29.

* cargo fmt
refactor-extract-code
Meng Zhang 2023-11-08 10:31:22 -08:00 committed by GitHub
parent 3fb8445747
commit 8ab35b2639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 37 deletions

View File

@ -15,5 +15,5 @@ class TextInferenceEngine {
virtual rust::Vec<StepOutput> step() = 0; virtual rust::Vec<StepOutput> step() = 0;
}; };
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path); std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism);
} // namespace } // namespace

View File

@ -14,17 +14,6 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {} TextInferenceEngine::~TextInferenceEngine() {}
namespace { namespace {
// Returns the number of requests the engine may process concurrently.
//
// The value is read from the LLAMA_CPP_PARALLELISM environment variable.
// When the variable is unset, is not a valid integer, overflows `int`, or is
// not strictly positive, the default of 4 is returned instead.
int get_parallelism() {
  constexpr int kDefaultParallelism = 4;

  const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
  if (!parallelism) {
    return kDefaultParallelism;
  }

  try {
    const int value = std::stoi(parallelism);
    // The result is stored into a size_t and multiplied with the context
    // size: zero would yield an empty context and a negative value would wrap
    // around to a huge unsigned number, so reject both.
    return value > 0 ? value : kDefaultParallelism;
  } catch (...) {
    // std::stoi throws on malformed or out-of-range input. This function runs
    // during static initialization, where an escaping exception would
    // terminate the process, so fall back to the default instead.
    return kDefaultParallelism;
  }
}
static size_t N_CONCURRENT_REQUESTS = get_parallelism();
constexpr size_t N_BATCH = 512; // # per batch inference. constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history. constexpr size_t N_CTX = 4096; // # max kv history.
@ -95,10 +84,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
class TextInferenceEngineImpl : public TextInferenceEngine { class TextInferenceEngineImpl : public TextInferenceEngine {
public: public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) : TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
model_(std::move(model)), model_(std::move(model)),
ctx_(std::move(ctx)) { ctx_(std::move(ctx)),
batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1); parallelism_(parallelism) {
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
// warm up // warm up
{ {
batch_.n_tokens = 16; batch_.n_tokens = 16;
@ -155,7 +145,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
} }
// Add pending requests. // Add pending requests.
while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) { while (pending_requests_.size() > 0 && requests_.size() < parallelism_) {
Request request = std::move(pending_requests_.front()); Request request = std::move(pending_requests_.front());
pending_requests_.pop_front(); pending_requests_.pop_front();
@ -283,6 +273,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::vector<Request> requests_; std::vector<Request> requests_;
std::deque<Request> pending_requests_; std::deque<Request> pending_requests_;
std::unordered_set<uint32_t> stopped_requests_; std::unordered_set<uint32_t> stopped_requests_;
uint32_t parallelism_;
}; };
static int g_llama_cpp_log_level = 0; static int g_llama_cpp_log_level = 0;
@ -310,7 +302,7 @@ struct BackendInitializer {
} // namespace } // namespace
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) { std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
static BackendInitializer initializer; static BackendInitializer initializer;
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
@ -322,13 +314,14 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
} }
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS; ctx_params.n_ctx = N_CTX * parallelism;
ctx_params.n_batch = N_BATCH; ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params); llama_context* ctx = llama_new_context_with_model(model, ctx_params);
return std::make_unique<TextInferenceEngineImpl>( return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model), owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free) owned<llama_context>(ctx, llama_free),
parallelism
); );
} }

View File

@ -23,7 +23,11 @@ mod ffi {
type TextInferenceEngine; type TextInferenceEngine;
fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>; fn create_engine(
use_gpu: bool,
model_path: &str,
parallelism: u8,
) -> UniquePtr<TextInferenceEngine>;
fn add_request( fn add_request(
self: Pin<&mut TextInferenceEngine>, self: Pin<&mut TextInferenceEngine>,
@ -43,6 +47,7 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
pub struct LlamaTextGenerationOptions { pub struct LlamaTextGenerationOptions {
model_path: String, model_path: String,
use_gpu: bool, use_gpu: bool,
parallelism: u8,
} }
pub struct LlamaTextGeneration { pub struct LlamaTextGeneration {
@ -52,7 +57,7 @@ pub struct LlamaTextGeneration {
impl LlamaTextGeneration { impl LlamaTextGeneration {
pub fn new(options: LlamaTextGenerationOptions) -> Self { pub fn new(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path); let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
if engine.is_null() { if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path); fatal!("Unable to load model: {}", options.model_path);
} }

View File

@ -14,8 +14,11 @@ pub async fn create_engine(
if fs::metadata(model_id).is_ok() { if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id); let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH); let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = let engine = create_ggml_engine(
create_ggml_engine(&args.device, model_path.display().to_string().as_str()); &args.device,
model_path.display().to_string().as_str(),
args.parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json")); let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info) (engine, engine_info)
} else { } else {
@ -23,7 +26,7 @@ pub async fn create_engine(
let registry = ModelRegistry::new(registry).await; let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string(); let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name); let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path); let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
( (
engine, engine,
EngineInfo { EngineInfo {
@ -57,10 +60,15 @@ impl EngineInfo {
} }
} }
fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> { fn create_ggml_engine(
device: &super::Device,
model_path: &str,
parallelism: u8,
) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned()) .model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu()) .use_gpu(device.ggml_use_gpu())
.parallelism(parallelism)
.build() .build()
.unwrap(); .unwrap();

View File

@ -121,15 +121,13 @@ pub struct ServeArgs {
#[clap(long, default_value_t=Device::Cpu)] #[clap(long, default_value_t=Device::Cpu)]
device: Device, device: Device,
/// DEPRECATED: Do not use. /// Parallelism for model serving - increasing this number will have a significant impact on the
#[deprecated(since = "0.5.0")] /// memory requirement e.g., GPU vRAM.
#[clap(long, hide(true))] #[clap(long, default_value_t = 1)]
device_indices: Vec<i32>, parallelism: u8,
} }
pub async fn main(config: &Config, args: &ServeArgs) { pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
if args.device != Device::ExperimentalHttp { if args.device != Device::ExperimentalHttp {
if fs::metadata(&args.model).is_ok() { if fs::metadata(&args.model).is_ok() {
info!("Loading model from local path {}", &args.model); info!("Loading model from local path {}", &args.model);
@ -252,12 +250,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
.layer(opentelemetry_tracing_layer()) .layer(opentelemetry_tracing_layer())
} }
/// Warns when the deprecated `--device-indices` flag was supplied.
///
/// Nothing is rejected: the only effect is a single `warn!` log line when
/// `args.device_indices` is non-empty.
fn valid_args(args: &ServeArgs) {
    // Nothing to report when the deprecated flag was not passed.
    if args.device_indices.is_empty() {
        return;
    }

    warn!("--device-indices is deprecated and will be removed in future release.");
}
fn start_heartbeat(args: &ServeArgs) { fn start_heartbeat(args: &ServeArgs) {
let state = HealthState::new(args); let state = HealthState::new(args);
tokio::spawn(async move { tokio::spawn(async move {