add-token-draft
Meng Zhang 2023-11-30 14:07:23 +08:00
parent ce20ae5b77
commit ce64207ad7
1 changed files with 13 additions and 11 deletions

View File

@ -34,8 +34,6 @@ struct Request {
std::string generated_text;
std::vector<llama_token> past_tokens;
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
if (past_tokens.size() < ngram_size) return {};
std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
@ -51,6 +49,15 @@ struct Request {
size_t n_past() {
return past_tokens.size();
}
void step(llama_token next_token, size_t n_dropped) {
past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.begin() + tokens.size() - n_dropped);
tokens.clear();
tokens.push_back(next_token);
}
private:
std::vector<llama_token> past_tokens;
};
@ -235,15 +242,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}
llama_token next_token = -1;
size_t n_tokens = request.tokens.size() - request.n_draft - 1;
request.past_tokens.insert(request.past_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
int k = -request.n_draft;
// FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
for (int k = -request.n_draft; k < 1; ++k) {
for (k = -request.n_draft; k < 1; ++k) {
auto logits = llama_get_logits_ith(ctx, i_batch + k);
next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
request.past_tokens.push_back(next_token);
const auto token_str = llama_token_to_piece(ctx, next_token);
request.generated_text += token_str;
@ -286,15 +290,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
break;
}
if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
if ((k == 0) || ((k < 0 && next_token != request.tokens[request.tokens.size() + k]))) {
request.step(next_token, -k);
llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past(), -1);
break;
}
}
request.tokens.clear();
request.tokens.push_back(next_token);
auto draft_tokens = request.find_candidate_pred_tokens();
request.n_draft = draft_tokens.size();
request.tokens.insert(request.tokens.end(), draft_tokens.begin(), draft_tokens.end());