diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 3c2e980..178c32b 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -33,6 +33,36 @@ struct Request {
   int32_t multibyte_pending = 0;
   std::string generated_text;
+
+
+  std::vector<llama_token> all_tokens;
+  std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
+    for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
+      if (all_tokens.size() < ngram_size) continue;
+      std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
+
+      const int matched = find_ngram(ngram, n_pred_tokens);
+      if (matched < 0) continue;
+
+      const int offset = matched + ngram_size;
+      return std::vector<llama_token>(all_tokens.begin() + offset, all_tokens.begin() + offset + n_pred_tokens);
+    }
+
+    return std::vector<llama_token>();
+  }
+ private:
+  int find_ngram(const std::vector<llama_token> & ngram, size_t n_pred_tokens) {
+    const int max = static_cast<int>(all_tokens.size()) - static_cast<int>(ngram.size() + n_pred_tokens);
+    for (int i = 0; i < max; ++i) {
+      const auto mismatch = std::mismatch(all_tokens.begin() + i, all_tokens.begin() + i + ngram.size(), ngram.begin());
+      if (mismatch.second == ngram.end()) {
+        // Matched
+        return i;
+      }
+    }
+
+    return -1;
+  }
 };
@@ -181,7 +211,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     }
 
     batch_.n_tokens += request.tokens.size();
-    batch_.logits[batch_.n_tokens - 1] = true;
+    for (int k = batch_.n_tokens - request.n_draft - 1; k <= batch_.n_tokens - 1; ++ k) {
+      batch_.logits[k] = true;
+    }
     request.i_batch = batch_.n_tokens;
   }
@@ -214,16 +246,17 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
         continue;
       }
 
+      llama_token next_token = -1;
+      size_t n_tokens = request.tokens.size() - request.n_draft;
+      request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
+      request.n_past += n_tokens;
+
+      // FIXME: ensure batching logic always puts i_batch - request.n_draft in this batch.
       for (int k = -request.n_draft; k < 1; ++k) {
-        auto logits = llama_get_logits_ith(ctx, i_batch);
-        auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
-
-        request.n_past += request.tokens.size();
-
-        request.tokens.clear();
-        request.tokens.push_back(next_token);
+        auto logits = llama_get_logits_ith(ctx, i_batch + k);
+        next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
+        request.all_tokens.push_back(next_token);
 
         const auto token_str = llama_token_to_piece(ctx, next_token);
         request.generated_text += token_str;
@@ -262,11 +295,20 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           request.generated_text.clear();
         }
 
-        if (k < 0 && next_token != request.tokens[request.tokens.size() + k]) {
-          // FIXME: shift kv cache
+        if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
+          llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
           break;
         }
+
+        request.n_past += 1;
       }
+
+      request.tokens.clear();
+      request.tokens.push_back(next_token);
+
+      auto draft_tokens = request.find_candidate_pred_tokens();
+      request.n_draft = draft_tokens.size();
+      request.tokens.insert(request.tokens.end(), draft_tokens.begin(), draft_tokens.end());
     }
   }
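
For readers unfamiliar with the technique, the `find_candidate_pred_tokens` method added above implements prompt lookup decoding: instead of running a smaller draft model, it searches the tokens processed so far for an earlier occurrence of the current suffix n-gram and speculates that the tokens which followed that occurrence will repeat. Below is a minimal standalone sketch of the same lookup with llama.cpp stripped out; `Token` is a stand-in for `llama_token`, and the function names merely mirror the patch:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

using Token = int;  // stand-in for llama_token

// Search `tokens` for an occurrence of `ngram` that still has `n_pred`
// tokens after it; return the match position, or -1 if there is none.
static int find_ngram(const std::vector<Token>& tokens,
                      const std::vector<Token>& ngram, size_t n_pred) {
  const int max = static_cast<int>(tokens.size()) -
                  static_cast<int>(ngram.size() + n_pred);
  for (int i = 0; i < max; ++i) {
    if (std::equal(ngram.begin(), ngram.end(), tokens.begin() + i)) {
      return i;
    }
  }
  return -1;
}

// Prompt lookup: try the longest suffix n-gram first, fall back to shorter
// ones, and return the tokens that followed the match as the draft.
static std::vector<Token> find_candidate_pred_tokens(
    const std::vector<Token>& tokens, size_t max_ngram_size = 3,
    size_t n_pred = 10) {
  for (size_t n = max_ngram_size; n > 0; --n) {
    if (tokens.size() < n) continue;
    std::vector<Token> ngram(tokens.end() - n, tokens.end());
    const int pos = find_ngram(tokens, ngram, n_pred);
    if (pos < 0) continue;
    const size_t offset = pos + n;
    return std::vector<Token>(tokens.begin() + offset,
                              tokens.begin() + offset + n_pred);
  }
  return {};  // no match: caller falls back to ordinary decoding
}

int main() {
  // The suffix "2 3" occurred earlier at position 1, so the tokens that
  // followed it ("4 5 6 7") are proposed as the draft.
  std::vector<Token> tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2, 3};
  for (Token t : find_candidate_pred_tokens(tokens, 3, 4)) {
    std::printf("%d ", t);  // prints: 4 5 6 7
  }
  std::printf("\n");
  return 0;
}
```

Trying the longest suffix n-gram first favors more specific matches, and an empty result simply means the engine decodes one token normally for that step, so speculation never changes greedy outputs, only how many tokens each decode call can confirm.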
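The decode-side hunks then verify the draft greedily. Because the batching change enables logits for the last `n_draft + 1` positions, a single decode yields the model's argmax at every drafted position; each argmax is emitted, and on the first position where it disagrees with the drafted token the speculative tail is evicted from the KV cache via `llama_kv_cache_seq_rm` before the loop breaks. The sketch below isolates just that accept/rollback bookkeeping; `verify_draft` and `VerifyResult` are illustrative names, not part of the patch, and the model is replaced by a precomputed argmax vector:

```cpp
#include <cstdio>
#include <vector>

using Token = int;  // stand-in for llama_token

struct VerifyResult {
  std::vector<Token> accepted;  // tokens actually emitted this step
  size_t n_past;                // cache cursor after accept/rollback
};

// `argmax` holds the model's greedy pick at each of the n_draft + 1
// positions that had logits enabled (every draft position plus one
// "bonus" position at the end). `draft` holds the n_draft speculated
// tokens; `n_past` is the cache cursor going into verification.
VerifyResult verify_draft(const std::vector<Token>& argmax,
                          const std::vector<Token>& draft, size_t n_past) {
  VerifyResult r{{}, n_past};
  for (size_t k = 0; k < argmax.size(); ++k) {
    r.accepted.push_back(argmax[k]);  // the model's own token is always kept
    if (k < draft.size() && argmax[k] != draft[k]) {
      // Draft diverged: cache entries past r.n_past are stale. At this
      // point the patch calls
      // llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1).
      break;
    }
    r.n_past += 1;  // position confirmed, mirroring request.n_past += 1
  }
  return r;
}

int main() {
  // Draft {6, 7} fully accepted: three tokens emitted for one decode call.
  auto full = verify_draft({6, 7, 42}, {6, 7}, 100);
  std::printf("accepted=%zu n_past=%zu\n", full.accepted.size(), full.n_past);
  // Draft {6} rejected immediately: the corrected token is still emitted.
  auto miss = verify_draft({5, 9}, {6}, 100);
  std::printf("accepted=%zu n_past=%zu\n", miss.accepted.size(), miss.n_past);
  return 0;
}
```

Note the loop always emits at least one token: even a fully rejected draft still yields the corrected argmax at its first position, so the worst case degrades to ordinary decoding plus the wasted draft computation.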