From de96d1b6af271a979ffb4eed4544617d31749590 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Thu, 30 Nov 2023 15:26:23 +0800
Subject: [PATCH] fix ngram build

---
 crates/llama-cpp-bindings/src/engine.cc | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index d3bb336..9a9b7b6 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -54,14 +54,16 @@ struct Request {
 
   void step(llama_token next_token, size_t n_dropped) {
     past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
+
+    n_draft = 0;
     tokens.clear();
     tokens.push_back(next_token);
   }
 
  private:
   std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
-    if (past_tokens.size() < ngram_size) return {};
-    std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
+    auto ngram = build_ngram(ngram_size);
+    if (ngram.size() < ngram_size) return {};
 
     const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
     const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
@@ -71,6 +73,22 @@ struct Request {
     return std::vector<llama_token>(begin, begin + n_pred_tokens);
   }
 
+  std::vector<llama_token> build_ngram(size_t ngram_size) {
+    GGML_ASSERT(n_draft == 0);
+    std::deque<llama_token> ret;
+    for (int i = tokens.size() - 1; i >= 0; --i) {
+      if (ret.size() == ngram_size) break;
+      ret.push_front(tokens[i]);
+    }
+
+    for (int i = past_tokens.size() - 1; i >= 0; --i) {
+      if (ret.size() == ngram_size) break;
+      ret.push_front(past_tokens[i]);
+    }
+
+    return std::vector<llama_token>(ret.begin(), ret.end());
+  }
+
   std::vector<llama_token> past_tokens;
 };
 