add-token-draft-v2
Meng Zhang 2023-11-29 20:18:11 +08:00
parent d1c4db52a8
commit 900e3c4d7b
1 changed file with 6 additions and 2 deletions

View File

@ -36,7 +36,7 @@ struct Request {
std::vector<llama_token> all_tokens;
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 10) {
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
if (all_tokens.size() < ngram_size) return {};
std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
@ -231,7 +231,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}
llama_token next_token = -1;
size_t n_tokens = request.tokens.size() - request.n_draft;
size_t n_tokens = request.tokens.size() - request.n_draft - 1;
request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
request.n_past += n_tokens;
@ -280,6 +280,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
request.generated_text.clear();
}
if (is_eos) {
break;
}
if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
break;