diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 10b1bb7..476e0fb 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -28,25 +28,29 @@ struct Request {
   std::vector<llama_token> tokens;
   size_t i_batch = -1;
-  size_t n_past = 0;
 
   size_t n_draft = 0;
 
   int32_t multibyte_pending = 0;
   std::string generated_text;
 
-  std::vector<llama_token> all_tokens;
-  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
-    if (all_tokens.size() < ngram_size) return {};
-    std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
+  std::vector<llama_token> past_tokens;
 
-    const auto end = all_tokens.end() - ngram_size - n_pred_tokens;
-    const auto matched = std::search(all_tokens.begin(), end, ngram.begin(), ngram.end());
+  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
+    if (past_tokens.size() < ngram_size) return {};
+    std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
+
+    const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
+    const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
 
     if (matched == end) return {};
 
     const auto begin = matched + ngram_size;
     return std::vector<llama_token>(begin, begin + n_pred_tokens);
   }
+
+  size_t n_past() {
+    return past_tokens.size();
+  }
 };
@@ -188,7 +192,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       const size_t n_tokens = batch_.n_tokens;
       for (size_t i = 0; i < request.tokens.size(); ++i) {
         batch_.token[n_tokens + i] = request.tokens[i];
-        batch_.pos[n_tokens + i] = request.n_past + i;
+        batch_.pos[n_tokens + i] = request.n_past() + i;
         batch_.n_seq_id[n_tokens + i] = 1;
         batch_.seq_id[n_tokens + i][0] = request.id;
         batch_.logits[n_tokens + i] = false;
@@ -232,18 +236,16 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
      llama_token next_token = -1;
 
      size_t n_tokens = request.tokens.size() - request.n_draft - 1;
-     request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
-     request.n_past += n_tokens;
+     request.past_tokens.insert(request.past_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
 
      // FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
      for (int k = -request.n_draft; k < 1; ++k) {
        auto logits = llama_get_logits_ith(ctx, i_batch + k);
        next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
-       request.all_tokens.push_back(next_token);
+       request.past_tokens.push_back(next_token);
 
        const auto token_str = llama_token_to_piece(ctx, next_token);
        request.generated_text += token_str;
-       request.n_past += 1;
 
        // FIXME: Hack for codellama to simplify tabby's implementation.
        const bool is_eos = next_token == eos_id || token_str == " <EOT>";
@@ -285,7 +287,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
        }
 
        if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
-         llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
+         llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past(), -1);
          break;
        }
      }