From b2c4635cede31e04eef99e330af0b583f2848ad1 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Thu, 30 Nov 2023 14:31:07 +0800
Subject: [PATCH] update

---
 crates/llama-cpp-bindings/src/engine.cc | 54 +++++++++++++++----------
 1 file changed, 33 insertions(+), 21 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index f834381..d3bb336 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -16,6 +16,9 @@ TextInferenceEngine::~TextInferenceEngine() {}
 namespace {
 constexpr size_t N_BATCH = 512;  // # per batch inference.
 constexpr size_t N_CTX = 4096;   // # max kv history.
+
+constexpr size_t DRAFT_N_GRAM_SIZE = 3;
+constexpr size_t DRAFT_N_PRED_TOKENS = 10;
 
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) :
@@ -34,7 +37,29 @@ struct Request {
 
   std::string generated_text;
 
-  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
+  void draft_tokens(int n_draft_quota) {
+    if (n_draft_quota < DRAFT_N_PRED_TOKENS) {
+      n_draft = 0;
+      return;
+    }
+
+    auto draft = find_candidate_pred_tokens(DRAFT_N_GRAM_SIZE, DRAFT_N_PRED_TOKENS);
+    n_draft = draft.size();
+    tokens.insert(tokens.end(), draft.begin(), draft.end());
+  }
+
+  size_t n_past() {
+    return past_tokens.size();
+  }
+
+  void step(llama_token next_token, size_t n_dropped) {
+    past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
+    tokens.clear();
+    tokens.push_back(next_token);
+  }
+
+ private:
+  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
     if (past_tokens.size() < ngram_size) return {};
 
     std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
@@ -46,17 +71,6 @@ struct Request {
     return std::vector<llama_token>(begin, begin + n_pred_tokens);
   }
 
-  size_t n_past() {
-    return past_tokens.size();
-  }
-
-  void step(llama_token next_token, size_t n_dropped) {
-    past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.begin() + tokens.size() - n_dropped);
-    tokens.clear();
-    tokens.push_back(next_token);
-  }
-
- private:
   std::vector<llama_token> past_tokens;
 };
 
@@ -197,6 +211,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     // Insert tokens from ongoing requests to batch.
     for (auto& request : requests_) {
       const size_t n_tokens = batch_.n_tokens;
+
+      // Ensure the draft logits always fall into the same batch.
+      const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
+      request.draft_tokens(n_draft_quota);
+
       for (size_t i = 0; i < request.tokens.size(); ++i) {
         batch_.token[n_tokens + i] = request.tokens[i];
         batch_.pos[n_tokens + i] = request.n_past() + i;
@@ -241,12 +260,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
         continue;
       }
 
-      llama_token next_token = -1;
-      int k = -request.n_draft;
-      // FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
-      for (k = -request.n_draft; k < 1; ++k) {
+      for (int k = -request.n_draft; k < 1; ++k) {
         auto logits = llama_get_logits_ith(ctx, i_batch + k);
-        next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
+        llama_token next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
 
         const auto token_str = llama_token_to_piece(ctx, next_token);
         request.generated_text += token_str;
@@ -296,10 +312,6 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           break;
         }
       }
-
-      auto draft_tokens = request.find_candidate_pred_tokens();
-      request.n_draft = draft_tokens.size();
-      request.tokens.insert(request.tokens.end(), draft_tokens.begin(), draft_tokens.end());
     }
   }
 
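Note on the drafting scheme (editorial, not part of the patch): find_candidate_pred_tokens implements prompt-lookup drafting. It takes the last DRAFT_N_GRAM_SIZE tokens of the generated history, searches the history for an earlier occurrence of that n-gram, and proposes the tokens that followed it as a speculative draft of up to DRAFT_N_PRED_TOKENS tokens. The sketch below is a minimal, self-contained illustration of that matching step, assuming plain int token ids in place of llama_token and a free function instead of the Request method; it adds a clamp at the end of the history for safety and mirrors the technique, not the engine's exact code.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    using token = int;  // stand-in for llama_token (an int32_t in llama.cpp)

    // Search `past_tokens` for an earlier occurrence of its trailing
    // `ngram_size` tokens; if one is found, propose the tokens that followed
    // that occurrence as the speculative draft.
    std::vector<token> find_candidate_pred_tokens(const std::vector<token>& past_tokens,
                                                  size_t ngram_size,
                                                  size_t n_pred_tokens) {
      if (past_tokens.size() < ngram_size) return {};

      // The pattern to match: the last `ngram_size` tokens generated so far.
      std::vector<token> ngram(past_tokens.end() - ngram_size, past_tokens.end());

      // Scan the history, excluding the trailing n-gram itself.
      auto search_end = past_tokens.end() - ngram_size;
      auto it = std::search(past_tokens.begin(), search_end, ngram.begin(), ngram.end());
      if (it == search_end) return {};

      // Draft the continuation of the matched n-gram, clamped to the history.
      auto begin = it + ngram_size;
      const size_t available = past_tokens.end() - begin;
      return std::vector<token>(begin, begin + std::min(n_pred_tokens, available));
    }

    int main() {
      // "7 8 9" occurs twice; the continuation of its first occurrence,
      // "4 5", becomes the draft.
      std::vector<token> history = {1, 2, 7, 8, 9, 4, 5, 6, 7, 8, 9};
      for (token t : find_candidate_pred_tokens(history, 3, 2)) std::printf("%d ", t);
      std::printf("\n");  // prints: 4 5
    }

With DRAFT_N_GRAM_SIZE = 3 and DRAFT_N_PRED_TOKENS = 10 as in the patch, up to 10 tokens are drafted per decoding step; draft_tokens() sets n_draft to 0 whenever fewer than DRAFT_N_PRED_TOKENS slots remain before the next N_BATCH boundary, which is what guarantees that all of a request's draft logits land in the same batch.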