fix
parent
d1c4db52a8
commit
900e3c4d7b
|
|
@ -36,7 +36,7 @@ struct Request {
|
|||
|
||||
|
||||
std::vector<llama_token> all_tokens;
|
||||
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 10) {
|
||||
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
|
||||
if (all_tokens.size() < ngram_size) return {};
|
||||
std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
|
||||
|
||||
|
|
@ -231,7 +231,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
}
|
||||
|
||||
llama_token next_token = -1;
|
||||
size_t n_tokens = request.tokens.size() - request.n_draft;
|
||||
size_t n_tokens = request.tokens.size() - request.n_draft - 1;
|
||||
request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
|
||||
request.n_past += n_tokens;
|
||||
|
||||
|
|
@ -280,6 +280,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
request.generated_text.clear();
|
||||
}
|
||||
|
||||
if (is_eos) {
|
||||
break;
|
||||
}
|
||||
|
||||
if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
|
||||
llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
|
||||
break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue