diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 54b8096..10b1bb7 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -36,7 +36,7 @@ struct Request {
 
   std::vector<llama_token> all_tokens;
 
-  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 10) {
+  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
     if (all_tokens.size() < ngram_size) return {};
 
     std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
@@ -231,7 +231,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     }
 
     llama_token next_token = -1;
-    size_t n_tokens = request.tokens.size() - request.n_draft;
+    size_t n_tokens = request.tokens.size() - request.n_draft - 1;
     request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
     request.n_past += n_tokens;
 
@@ -280,6 +280,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       request.generated_text.clear();
     }
 
+    if (is_eos) {
+      break;
+    }
+
     if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
       llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
       break;