feat: tune llama metal backend performance (#393)

* feat: support eos based stop
* feat: print performance stats after each inference
* update llama.cpp
* update commits

branch release-0.2
parent 9e54eda318
commit e93a971d0e
@@ -11,6 +11,9 @@ class TextInferenceEngine {

   virtual uint32_t start(const rust::Str prompt) const = 0;
   virtual uint32_t step(uint32_t next_token_id) const = 0;
+  virtual void end() const = 0;
+
+  virtual uint32_t eos_token() const = 0;
 };

 std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
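The extended interface gives callers a complete generation lifecycle: start() primes the context with a prompt and returns the first sampled token, step() feeds the previous token back and returns the next, eos_token() exposes the model's end-of-sequence id, and end() closes the request. A minimal caller-side sketch of how the four calls compose (not part of this commit; max_len is a hypothetical parameter):

#include <cstdint>
#include <vector>

// Drive a TextInferenceEngine until it emits EOS or exhausts a length budget.
// Assumes the declarations above; rust::Str comes from the cxx bridge header.
std::vector<uint32_t> generate(const TextInferenceEngine& engine,
                               rust::Str prompt, size_t max_len) {
  std::vector<uint32_t> output;
  const uint32_t eos = engine.eos_token();
  uint32_t token = engine.start(prompt);
  while (token != eos && output.size() < max_len) {
    output.push_back(token);
    token = engine.step(token);
  }
  engine.end();  // flush per-request performance stats
  return output;
}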
@@ -1 +1 @@
-Subproject commit bce1fef328941499dc0acb76cc7fd7ac90449c2f
+Subproject commit 06fc4020de0b92ee13407fdabca7870f53c75de5
@@ -37,6 +37,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {

   uint32_t start(const rust::Str prompt) const override {
     auto* ctx = ctx_.get();
+    llama_reset_timings(ctx);
     std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
     eval(tokens_list, /* reset = */ true);
     return sample();
@@ -47,6 +48,14 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     return sample();
   }

+  void end() const override {
+    llama_print_timings(ctx_.get());
+  }
+
+  uint32_t eos_token() const override {
+    return llama_token_eos(ctx_.get());
+  }
+
  private:
   uint32_t sample() const {
     auto* ctx = ctx_.get();
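Paired with the llama_reset_timings() call added to start() above, these two overrides bracket each request: counters are zeroed when the prompt arrives and printed when decoding finishes, so the reported stats cover a single inference rather than the whole process lifetime. A sketch of the resulting pattern (llama.cpp API as pinned by the submodule bump):

// Per-request timing bracket around one inference.
llama_reset_timings(ctx);   // start(): zero the sample/eval counters
// ... tokenize, eval, and the sample loop run here ...
llama_print_timings(ctx);   // end(): report timings and tokens/sec to stderr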
@@ -65,7 +74,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           tokens_list.data(),
           tokens_list.size(),
           reset ? 0 : llama_get_kv_cache_token_count(ctx),
-          /* n_threads = */ 1)) {
+          /* n_threads = */ 4)) {
       fprintf(stderr, "%s : failed to eval\n", __func__);
       return false;
     }
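Two things are worth noting in this hunk. The fourth argument is llama_eval()'s n_past: 0 on reset discards the KV cache for a fresh prompt, while the current cache count appends the new tokens to the existing context. The thread bump from 1 to 4 speeds up the CPU-side work that remains even when Metal handles the heavy matrix multiplications. A sketch of the call shape (llama.cpp signature of this era, with reset and tokens_list as in the surrounding eval() helper):

// n_past = how many KV-cache entries to reuse before evaluating tokens_list.
const int n_past = reset ? 0 : llama_get_kv_cache_token_count(ctx);
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past,
               /* n_threads = */ 4) != 0) {
  fprintf(stderr, "%s : failed to eval\n", __func__);
}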
@@ -92,7 +101,8 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
   static BackendInitializer initializer;

   llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_gpu_layers = 4;
+  ctx_params.n_ctx = 2048;
+  ctx_params.n_gpu_layers = 1;

   llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
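This is the Metal-specific tuning the commit title refers to. In llama.cpp builds of this vintage, any n_gpu_layers greater than zero routes graph evaluation through the ggml-metal backend, so lowering it from 4 to 1 keeps full GPU offload on Apple Silicon while the explicit n_ctx pins the context window at 2048 tokens. A sketch of the resulting setup (model_path is hypothetical; llama_new_context_with_model is from the same API era):

// Create a Metal-offloaded llama.cpp context with a 2048-token window.
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;      // prompt + generated tokens must fit in this window
ctx_params.n_gpu_layers = 1;  // any value > 0 enables the Metal path

llama_model* model = llama_load_model_from_file(model_path, ctx_params);
llama_context* ctx = llama_new_context_with_model(model, ctx_params);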
@@ -19,6 +19,9 @@ mod ffi {

        fn start(&self, prompt: &str) -> u32;
        fn step(&self, next_token_id: u32) -> u32;
+       fn end(&self);
+
+       fn eos_token(&self) -> u32;
     }
 }
@@ -62,7 +65,13 @@ impl TextGeneration for LlamaEngine {

         let output_ids = tokio::task::spawn_blocking(move || {
             let engine = engine.lock().unwrap();
+            let eos_token = engine.eos_token();
+
             let mut next_token_id = engine.start(&prompt);
+            if next_token_id == eos_token {
+                return Vec::new();
+            }
+
             let mut n_remains = options.max_decoding_length - 1;
             let mut output_ids = vec![next_token_id];
@@ -73,6 +82,10 @@ impl TextGeneration for LlamaEngine {
             }

             next_token_id = engine.step(next_token_id);
+            if next_token_id == eos_token {
+                break;
+            }
+
             if stop_condition.next_token(next_token_id) {
                 break;
             }
@@ -80,11 +93,11 @@ impl TextGeneration for LlamaEngine {
                 n_remains -= 1;
             }
-
+            engine.end();
             output_ids
         })
         .await
         .expect("Inference failed");

         self.tokenizer.decode(&output_ids, true).unwrap()
     }
 }