From 49864f98c1f97b1eb835899da0814830d66bbfa0 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Wed, 29 Nov 2023 13:45:55 +0800
Subject: [PATCH] implement find_candidate_pred_tokens

fix

update

update
---
 crates/llama-cpp-bindings/src/engine.cc | 73 ++++++++++++++++++-------
 1 file changed, 54 insertions(+), 19 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 2a762b5..6157847 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -20,20 +20,58 @@
 constexpr size_t N_CTX = 4096;  // # max kv history.
 
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) : id(request_id),
-    tokens(input_token_ids.begin(), input_token_ids.end()) {
-  }
+    pending_tokens(input_token_ids.begin(), input_token_ids.end()) {
+  }
 
   uint32_t id = -1;
   llama_seq_id seq_id = -1;
 
-  std::vector<llama_token> tokens;
+  std::vector<llama_token> pending_tokens;
   size_t i_batch = -1;
   size_t n_past = 0;
 
   int32_t multibyte_pending = 0;
   std::string generated_text;
-};
+  std::vector<llama_token> tokens;
+
+  void step(llama_token id) {
+    ++n_past;
+    tokens.insert(tokens.end(), pending_tokens.begin(), pending_tokens.end());
+
+    pending_tokens.clear();
+    pending_tokens.push_back(id);
+  }
+
+  std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
+    for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
+      if (tokens.size() < ngram_size) continue;
+      std::vector<llama_token> ngram(tokens.begin() + tokens.size() - ngram_size, tokens.end());
+
+      const int matched = find_ngram(ngram, n_pred_tokens);
+      if (matched < 0) continue;
+
+      const int offset = matched + ngram_size;
+      return std::vector<llama_token>(tokens.begin() + offset, tokens.begin() + offset + n_pred_tokens);
+    }
+
+    return std::vector<llama_token>();
+  }
+
+ private:
+  int find_ngram(const std::vector<llama_token> & ngram, size_t n_pred_tokens) {
+    const int max = static_cast<int>(tokens.size()) - ngram.size() - n_pred_tokens;
+    for (int i = 0; i < max; ++i) {
+      const auto mismatch = std::mismatch(tokens.begin() + i, tokens.begin() + i + ngram.size(), ngram.begin());
+      if (mismatch.second == ngram.end()) {
+        // Matched
+        return i;
+      }
+    }
+
+    return -1;
+  }
+};
 
 std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
   std::vector<char> result(8, 0);
@@ -54,7 +92,7 @@ std::vector<llama_token> llama_tokenize(
     const rust::Str & text,
     bool add_bos,
     bool special) {
-  // upper limit for the number of tokens
+  // upper limit for the number of pending_tokens
   int n_tokens = text.length() + add_bos;
   std::vector<llama_token> result(n_tokens);
   n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
@@ -113,12 +151,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   }
 
   virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
-    auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
-    if (tokens.size() > max_input_length) {
-      int start = tokens.size() - max_input_length;
-      tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
+    auto pending_tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
+    if (pending_tokens.size() > max_input_length) {
+      int start = pending_tokens.size() - max_input_length;
+      pending_tokens = std::vector<llama_token>(pending_tokens.begin() + start, pending_tokens.end());
     }
-    pending_requests_.push_back(Request(request_id, tokens));
+    pending_requests_.push_back(Request(request_id, pending_tokens));
   }
 
   void stop_request(uint32_t request_id) override {
@@ -168,17 +206,17 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     // Clear the batch.
     batch_.n_tokens = 0;
 
-    // Insert tokens from ongoing requests to batch.
+    // Insert pending_tokens from ongoing requests to batch.
     for (auto& request : requests_) {
       const size_t n_tokens = batch_.n_tokens;
-      for (size_t i = 0; i < request.tokens.size(); ++i) {
-        batch_.token[n_tokens + i] = request.tokens[i];
+      for (size_t i = 0; i < request.pending_tokens.size(); ++i) {
+        batch_.token[n_tokens + i] = request.pending_tokens[i];
         batch_.pos[n_tokens + i] = request.n_past + i;
         batch_.n_seq_id[n_tokens + i] = 1;
         batch_.seq_id[n_tokens + i][0] = request.id;
         batch_.logits[n_tokens + i] = false;
       }
-      batch_.n_tokens += request.tokens.size();
+      batch_.n_tokens += request.pending_tokens.size();
 
       batch_.logits[batch_.n_tokens - 1] = true;
       request.i_batch = batch_.n_tokens - 1;
@@ -187,7 +225,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     rust::Vec result;
     result.reserve(requests_.size());
 
-    // Decode tokens in chunks
+    // Decode pending_tokens in chunks
    for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
      const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
      llama_batch batch_view = {
@@ -216,10 +254,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       auto logits = llama_get_logits_ith(ctx, i_batch);
       auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
 
-      request.n_past += request.tokens.size();
-
-      request.tokens.clear();
-      request.tokens.push_back(next_token);
+      request.step(next_token);
 
       const auto token_str = llama_token_to_piece(ctx, next_token);
       request.generated_text += token_str;
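
A note on the new find_candidate_pred_tokens helper: it is a prompt-lookup style draft step. It takes the most recent ngram_size tokens of the request's history (trying max_ngram_size down to 1), searches the already-seen tokens for an earlier occurrence of that n-gram, and returns the n_pred_tokens tokens that followed it as candidate predictions. The sketch below is not part of the patch; it mirrors the same lookup using plain int in place of llama_token so it compiles without llama.cpp, and the helper names (candidate_pred_tokens, find_ngram) and the toy token ids are illustrative only.

// Standalone sketch of the n-gram lookup performed by find_candidate_pred_tokens.
// Hypothetical names and data; int stands in for llama_token.
#include <algorithm>
#include <cstdio>
#include <vector>

using token = int;

// Find an earlier occurrence of `ngram` in `tokens` that is followed by at
// least n_pred more tokens; return its index, or -1 if there is none.
static int find_ngram(const std::vector<token>& tokens,
                      const std::vector<token>& ngram, size_t n_pred) {
  if (tokens.size() < ngram.size() + n_pred) return -1;
  const size_t last = tokens.size() - ngram.size() - n_pred;
  for (size_t i = 0; i < last; ++i) {
    if (std::equal(ngram.begin(), ngram.end(), tokens.begin() + i)) {
      return static_cast<int>(i);
    }
  }
  return -1;
}

// Take the longest suffix n-gram of `tokens` (up to max_ngram), look for a
// previous occurrence, and return the tokens that followed it as a draft.
static std::vector<token> candidate_pred_tokens(const std::vector<token>& tokens,
                                                size_t max_ngram = 3, size_t n_pred = 10) {
  for (size_t n = max_ngram; n > 0; --n) {
    if (tokens.size() < n) continue;
    std::vector<token> ngram(tokens.end() - n, tokens.end());
    const int pos = find_ngram(tokens, ngram, n_pred);
    if (pos < 0) continue;
    const size_t off = pos + n;
    return std::vector<token>(tokens.begin() + off, tokens.begin() + off + n_pred);
  }
  return {};
}

int main() {
  // The suffix {7, 8, 9} occurred earlier, followed by 10..19; those ten
  // tokens become the speculative draft.
  std::vector<token> history = {1, 2, 3, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                                4, 5, 6, 7, 8, 9};
  for (token t : candidate_pred_tokens(history)) std::printf("%d ", t);
  std::printf("\n");  // prints: 10 11 12 13 14 15 16 17 18 19
  return 0;
}

The patch itself only adds the lookup plus the pending_tokens/tokens bookkeeping in Request::step; presumably a follow-up change would append the drafted tokens to the llama_batch with logits enabled and verify them against the model's own argmax before accepting them.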