feat: tune llama metal backend performance (#393)

* feat: support eos based stop

* feat: print performance stats after each inference

* update llama.cpp

* update commits
release-0.2
Meng Zhang 2023-09-05 10:14:29 +08:00 committed by GitHub
parent 9e54eda318
commit e93a971d0e
4 changed files with 30 additions and 4 deletions


@@ -11,6 +11,9 @@ class TextInferenceEngine {
   virtual uint32_t start(const rust::Str prompt) const = 0;
   virtual uint32_t step(uint32_t next_token_id) const = 0;
+
+  virtual void end() const = 0;
+  virtual uint32_t eos_token() const = 0;
 };
 std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
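
The new eos_token() accessor is the hook for EOS-based stopping on the caller's side. A minimal caller-side sketch, not code from this commit (the max_tokens budget and the includes are assumptions), mirroring the Rust decoding loop further down:

#include <cstddef>
#include <cstdint>
#include <vector>
#include "rust/cxx.h"  // rust::Str comes in via the cxx bridge

// Hypothetical caller: decode until the model emits EOS or a token
// budget runs out; the EOS id itself is never kept in the output.
std::vector<uint32_t> generate(TextInferenceEngine& engine,
                               rust::Str prompt, std::size_t max_tokens) {
  const uint32_t eos = engine.eos_token();
  std::vector<uint32_t> output;
  uint32_t token = engine.start(prompt);
  while (token != eos && output.size() < max_tokens) {
    output.push_back(token);
    token = engine.step(token);
  }
  engine.end();  // flush per-request stats, see the implementation below
  return output;
}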

@@ -1 +1 @@
-Subproject commit bce1fef328941499dc0acb76cc7fd7ac90449c2f
+Subproject commit 06fc4020de0b92ee13407fdabca7870f53c75de5


@@ -37,6 +37,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   uint32_t start(const rust::Str prompt) const override {
     auto* ctx = ctx_.get();
+    llama_reset_timings(ctx);
     std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
     eval(tokens_list, /* reset = */ true);
     return sample();
@@ -47,6 +48,14 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     return sample();
   }

+  void end() const override {
+    llama_print_timings(ctx_.get());
+  }
+
+  uint32_t eos_token() const override {
+    return llama_token_eos(ctx_.get());
+  }
+
 private:
   uint32_t sample() const {
     auto* ctx = ctx_.get();
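
start() and end() now bracket every request with llama.cpp's timing API: llama_reset_timings() zeroes the per-context counters and llama_print_timings() writes load, sample, prompt-eval, and eval times to stderr. The same pattern in isolation, as a hedged sketch (the run_request callback stands in for the start/step loop and is not part of the commit):

#include "llama.h"

// Scope llama.cpp's timing counters to one request so the printed stats
// describe a single inference rather than the whole process lifetime.
template <typename Fn>
void with_timings(llama_context* ctx, Fn&& run_request) {
  llama_reset_timings(ctx);   // what start() does first
  run_request();              // tokenize + eval + sample loop
  llama_print_timings(ctx);   // what end() does last
}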
@@ -65,7 +74,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
         tokens_list.data(),
         tokens_list.size(),
         reset ? 0 : llama_get_kv_cache_token_count(ctx),
-        /* n_threads = */ 1)) {
+        /* n_threads = */ 4)) {
       fprintf(stderr, "%s : failed to eval\n", __func__);
       return false;
     }
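
Raising n_threads from 1 to 4 speeds up the CPU portion of llama_eval, which still matters on the Metal backend for any work that stays off the GPU. A hedged alternative to the hard-coded constant; this helper and its cap of 8 are assumptions, not part of the commit:

#include <algorithm>
#include <thread>

// Derive the eval thread count from the host instead of hard-coding 4.
// hardware_concurrency() may return 0, so clamp from below as well; the
// upper cap is arbitrary and worth benchmarking on Apple Silicon.
static unsigned default_n_threads() {
  const unsigned hw = std::thread::hardware_concurrency();
  return std::clamp(hw, 1u, 8u);
}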
@@ -92,7 +101,8 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
   static BackendInitializer initializer;
   llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_gpu_layers = 4;
+  ctx_params.n_ctx = 2048;
+  ctx_params.n_gpu_layers = 1;
   llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
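
Two tuning changes land here: n_ctx = 2048 raises the context window above llama.cpp's 512-token default, and n_gpu_layers drops from 4 to 1; on the Metal backend of this vintage any nonzero n_gpu_layers reportedly enables GPU execution of the whole graph, so 1 suffices. A sketch of the creation path under these parameters (the error check is an addition; llama_new_context_with_model is the API of this llama.cpp era):

#include <string>
#include "llama.h"

// Sketch: build a llama.cpp context with the tuned parameters above.
llama_context* create_context(const std::string& model_path) {
  llama_context_params params = llama_context_default_params();
  params.n_ctx = 2048;      // room for prompt plus completion
  params.n_gpu_layers = 1;  // nonzero enables the Metal backend
  llama_model* model =
      llama_load_model_from_file(model_path.c_str(), params);
  if (model == nullptr) {
    return nullptr;  // the commit itself elides this check
  }
  return llama_new_context_with_model(model, params);
}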


@@ -19,6 +19,9 @@ mod ffi {
         fn start(&self, prompt: &str) -> u32;
         fn step(&self, next_token_id: u32) -> u32;
+
+        fn end(&self);
+        fn eos_token(&self) -> u32;
     }
 }
@ -62,7 +65,13 @@ impl TextGeneration for LlamaEngine {
let output_ids = tokio::task::spawn_blocking(move || {
let engine = engine.lock().unwrap();
let eos_token = engine.eos_token();
let mut next_token_id = engine.start(&prompt);
if next_token_id == eos_token {
return Vec::new();
}
let mut n_remains = options.max_decoding_length - 1;
let mut output_ids = vec![next_token_id];
@@ -73,6 +82,10 @@ impl TextGeneration for LlamaEngine {
             }
             next_token_id = engine.step(next_token_id);
+            if next_token_id == eos_token {
+                break;
+            }
+
             if stop_condition.next_token(next_token_id) {
                 break;
             }
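
Note the ordering inside the loop: the EOS check runs right after step() and before the user-supplied stop_condition, and both appear to break out before the token is appended, so the EOS id never reaches output_ids and leaves no artifact in the decoded text. The first-token check in the previous hunk handles the same case for start().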
@@ -80,11 +93,11 @@ impl TextGeneration for LlamaEngine {
                 n_remains -= 1;
             }
+            engine.end();
             output_ids
         })
         .await
         .expect("Inference failed");
         self.tokenizer.decode(&output_ids, true).unwrap()
     }
 }