From ce64207ad71e5c0a4ba809620cd4f059633cf81f Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Thu, 30 Nov 2023 14:07:23 +0800
Subject: [PATCH] cleanup

---
 crates/llama-cpp-bindings/src/engine.cc | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 476e0fb..f834381 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -34,8 +34,6 @@ struct Request {
 
   std::string generated_text;
 
-  std::vector<llama_token> past_tokens;
-
   std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
     if (past_tokens.size() < ngram_size) return {};
     std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
@@ -51,6 +49,15 @@ struct Request {
 
   size_t n_past() {
     return past_tokens.size();
   }
+
+  void step(llama_token next_token, size_t n_dropped) {
+    past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.begin() + tokens.size() - n_dropped);
+    tokens.clear();
+    tokens.push_back(next_token);
+  }
+
+ private:
+  std::vector<llama_token> past_tokens;
 };
 
@@ -235,15 +242,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       }
 
       llama_token next_token = -1;
-      size_t n_tokens = request.tokens.size() - request.n_draft - 1;
-      request.past_tokens.insert(request.past_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
-
+      int k = -request.n_draft;
       // FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
-      for (int k = -request.n_draft; k < 1; ++k) {
+      for (k = -request.n_draft; k < 1; ++k) {
         auto logits = llama_get_logits_ith(ctx, i_batch + k);
         next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
-        request.past_tokens.push_back(next_token);
 
         const auto token_str = llama_token_to_piece(ctx, next_token);
         request.generated_text += token_str;
 
@@ -286,15 +290,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           break;
         }
 
-        if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
+        if ((k == 0) || ((k < 0 && next_token != request.tokens[request.tokens.size() + k]))) {
+          request.step(next_token, -k);
           llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past(), -1);
           break;
         }
       }
 
-      request.tokens.clear();
-      request.tokens.push_back(next_token);
-
       auto draft_tokens = request.find_candidate_pred_tokens();
       request.n_draft = draft_tokens.size();
       request.tokens.insert(request.tokens.end(), draft_tokens.begin(), draft_tokens.end());
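
For context, `find_candidate_pred_tokens` is a prompt-lookup style draft step: it takes the trailing n-gram of the tokens generated so far, searches the earlier history for the same n-gram, and proposes the tokens that followed that occurrence as draft continuations. The sketch below is a minimal, self-contained reconstruction of that lookup; the hunk only shows the first lines of the function, so the match-scan loop is an assumption based on the standard prompt-lookup scheme, and `token` is a stand-in for `llama_token` (an `int32_t` in llama.cpp).

    // Minimal sketch of the n-gram draft lookup, under the assumptions above.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using token = int32_t;  // stand-in for llama_token

    // Take the trailing `ngram_size` tokens of `past`, look for an earlier
    // occurrence of that n-gram, and return up to `n_pred_tokens` tokens
    // that followed it as draft candidates (empty if there is no match).
    std::vector<token> find_candidate_pred_tokens(const std::vector<token>& past,
                                                  size_t ngram_size = 3,
                                                  size_t n_pred_tokens = 8) {
      if (past.size() < ngram_size) return {};
      const std::vector<token> ngram(past.end() - ngram_size, past.end());

      // Scan right-to-left so the most recent occurrence wins; the start
      // index past.size() - ngram_size is decremented before use, which
      // skips the trailing n-gram itself.
      for (size_t i = past.size() - ngram_size; i-- > 0;) {
        if (std::equal(ngram.begin(), ngram.end(), past.begin() + i)) {
          const size_t begin = i + ngram_size;
          const size_t end = std::min(begin + n_pred_tokens, past.size());
          return {past.begin() + begin, past.begin() + end};
        }
      }
      return {};
    }

    int main() {
      // History ends in the bigram {1, 2}, which also occurred at the start;
      // the tokens after that earlier occurrence become the draft.
      const std::vector<token> past = {1, 2, 3, 4, 1, 5, 6, 1, 2};
      for (token t : find_candidate_pred_tokens(past, /*ngram_size=*/2)) {
        std::printf("%d ", t);  // prints: 3 4 1 5 6 1 2
      }
      std::printf("\n");
      return 0;
    }

As far as the hunks show, this plugs into the verification loop as follows: `k == 0` means every draft token was verified, while `k < 0` with a token mismatch marks the first rejected draft. In both cases `request.step(next_token, -k)` commits the accepted prefix to the now-private `past_tokens` and drops the `-k` unverified tokens, after which `llama_kv_cache_seq_rm` rolls the KV cache back to `request.n_past()`.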