add-token-draft-v2
Meng Zhang 2023-11-29 17:00:58 +08:00
parent a42fde18ac
commit 8c770c6404
1 changed files with 46 additions and 37 deletions

View File

@ -29,6 +29,7 @@ struct Request {
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
size_t i_batch = -1; size_t i_batch = -1;
size_t n_past = 0; size_t n_past = 0;
size_t n_draft = 0;
int32_t multibyte_pending = 0; int32_t multibyte_pending = 0;
std::string generated_text; std::string generated_text;
@ -213,6 +214,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
continue; continue;
} }
// FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
for (int k = -request.n_draft; k < 1; ++k) {
auto logits = llama_get_logits_ith(ctx, i_batch); auto logits = llama_get_logits_ith(ctx, i_batch);
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab)); auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
@ -258,6 +261,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
result.push_back({request.id, generated_text}); result.push_back({request.id, generated_text});
request.generated_text.clear(); request.generated_text.clear();
} }
if (k < 0 && next_token != request.tokens[request.tokens.size() + k]) {
// FIXME: shift kv cache
break;
}
}
} }
} }