From d1c4db52a8c716271b57c9d6821b4b5f0e1a00da Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Wed, 29 Nov 2023 19:45:07 +0800
Subject: [PATCH] refactor

---
 crates/llama-cpp-bindings/src/engine.cc | 35 +++++++------------------
 1 file changed, 9 insertions(+), 26 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 178c32b..54b8096 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -36,32 +36,16 @@ struct Request {
 
   std::vector<llama_token> all_tokens;
 
-  std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
-    for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
-      if (all_tokens.size() < ngram_size) continue;
-      std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
+  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 10) {
+    if (all_tokens.size() < ngram_size) return {};
+    std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
 
-      const int matched = find_ngram(ngram, n_pred_tokens);
-      if (matched < 0) continue;
+    const auto end = all_tokens.end() - ngram_size - n_pred_tokens;
+    const auto matched = std::search(all_tokens.begin(), end, ngram.begin(), ngram.end());
+    if (matched == end) return {};
 
-      const int offset = matched + ngram_size;
-      return std::vector<llama_token>(all_tokens.begin() + offset, all_tokens.begin() + offset + n_pred_tokens);
-    }
-
-    return std::vector<llama_token>();
-  }
- private:
-  int find_ngram(const std::vector<llama_token>& ngram, size_t n_pred_tokens) {
-    const int max = static_cast<int>(all_tokens.size()) - ngram.size() - n_pred_tokens;
-    for (int i = 0; i < max; ++i) {
-      const auto mismatch = std::mismatch(all_tokens.begin() + i, all_tokens.begin() + i + ngram.size(), ngram.begin());
-      if (mismatch.second == ngram.end()) {
-        // Matched
-        return i;
-      }
-    }
-
-    return -1;
+    const auto begin = matched + ngram_size;
+    return std::vector<llama_token>(begin, begin + n_pred_tokens);
   }
 };
 
@@ -259,6 +243,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       request.all_tokens.push_back(next_token);
       const auto token_str = llama_token_to_piece(ctx, next_token);
       request.generated_text += token_str;
+      request.n_past += 1;
 
       // FIXME: Hack for codellama to simplify tabby's implementation.
       const bool is_eos = next_token == eos_id || token_str == " <EOT>";
@@ -299,8 +284,6 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
           break;
         }
-
-        request.n_past += 1;
       }
 
       request.tokens.clear();
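
Note on the first hunk: the commit collapses the hand-rolled find_ngram() scan
into a single std::search over the token history. This is the
candidate-generation step of prompt lookup decoding: the last ngram_size tokens
are matched against earlier history, and the tokens that followed the earlier
occurrence are proposed as draft predictions. The sketch below is not part of
the commit; it restates the simplified logic as a self-contained program, with
plain int standing in for llama_token so it compiles outside the bindings. One
deliberate assumption: the guard here checks ngram_size + n_pred_tokens rather
than the patch's ngram_size alone, since with the weaker check `end` can land
before `begin` whenever all_tokens.size() < ngram_size + n_pred_tokens, and
std::search over an inverted range is undefined behavior.

// Sketch of the simplified n-gram lookup (int stands in for llama_token).
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<int> find_candidate_pred_tokens(const std::vector<int>& all_tokens,
                                            size_t ngram_size = 3,
                                            size_t n_pred_tokens = 10) {
  // Guard strengthened relative to the patch: bail out unless the history can
  // hold both the query n-gram and n_pred_tokens of lookahead, so the search
  // range below is never inverted.
  if (all_tokens.size() < ngram_size + n_pred_tokens) return {};

  // The query n-gram is the last ngram_size tokens generated so far.
  std::vector<int> ngram(all_tokens.end() - ngram_size, all_tokens.end());

  // Stop scanning early enough that any match still has n_pred_tokens of
  // follow-up tokens available after it.
  const auto end = all_tokens.end() - ngram_size - n_pred_tokens;
  const auto matched = std::search(all_tokens.begin(), end, ngram.begin(), ngram.end());
  if (matched == end) return {};

  // Speculate the tokens that followed the earlier occurrence of the n-gram.
  const auto begin = matched + ngram_size;
  return std::vector<int>(begin, begin + n_pred_tokens);
}

int main() {
  // The tail n-gram {3, 4, 5} also occurs at index 2, so the ten tokens after
  // that occurrence come back as drafts: 6 7 8 9 10 11 12 13 14 15.
  std::vector<int> tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
                             11, 12, 13, 14, 15, 3, 4, 5};
  for (int t : find_candidate_pred_tokens(tokens, 3, 10)) std::printf("%d ", t);
  std::printf("\n");
  return 0;
}

The remaining hunks consolidate the request.n_past += 1 bookkeeping: instead of
advancing at the bottom of the loop, the counter now advances immediately when
the sampled next_token is appended to all_tokens, keeping n_past in step with
the token history at a single point.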