add-token-draft
Meng Zhang 2023-11-30 12:51:14 +08:00
parent 900e3c4d7b
commit ce20ae5b77
1 changed files with 15 additions and 13 deletions

View File

@ -28,25 +28,29 @@ struct Request {
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
size_t i_batch = -1; size_t i_batch = -1;
size_t n_past = 0;
size_t n_draft = 0; size_t n_draft = 0;
int32_t multibyte_pending = 0; int32_t multibyte_pending = 0;
std::string generated_text; std::string generated_text;
std::vector<llama_token> all_tokens; std::vector<llama_token> past_tokens;
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
if (all_tokens.size() < ngram_size) return {};
std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
const auto end = all_tokens.end() - ngram_size - n_pred_tokens; std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
const auto matched = std::search(all_tokens.begin(), end, ngram.begin(), ngram.end()); if (past_tokens.size() < ngram_size) return {};
std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
if (matched == end) return {}; if (matched == end) return {};
const auto begin = matched + ngram_size; const auto begin = matched + ngram_size;
return std::vector<llama_token>(begin, begin + n_pred_tokens); return std::vector<llama_token>(begin, begin + n_pred_tokens);
} }
size_t n_past() {
return past_tokens.size();
}
}; };
@ -188,7 +192,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
const size_t n_tokens = batch_.n_tokens; const size_t n_tokens = batch_.n_tokens;
for (size_t i = 0; i < request.tokens.size(); ++i) { for (size_t i = 0; i < request.tokens.size(); ++i) {
batch_.token[n_tokens + i] = request.tokens[i]; batch_.token[n_tokens + i] = request.tokens[i];
batch_.pos[n_tokens + i] = request.n_past + i; batch_.pos[n_tokens + i] = request.n_past() + i;
batch_.n_seq_id[n_tokens + i] = 1; batch_.n_seq_id[n_tokens + i] = 1;
batch_.seq_id[n_tokens + i][0] = request.id; batch_.seq_id[n_tokens + i][0] = request.id;
batch_.logits[n_tokens + i] = false; batch_.logits[n_tokens + i] = false;
@ -232,18 +236,16 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
llama_token next_token = -1; llama_token next_token = -1;
size_t n_tokens = request.tokens.size() - request.n_draft - 1; size_t n_tokens = request.tokens.size() - request.n_draft - 1;
request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens); request.past_tokens.insert(request.past_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
request.n_past += n_tokens;
// FIXME: ensure batching logic always put i_batch - request.n_draft in this batch. // FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
for (int k = -request.n_draft; k < 1; ++k) { for (int k = -request.n_draft; k < 1; ++k) {
auto logits = llama_get_logits_ith(ctx, i_batch + k); auto logits = llama_get_logits_ith(ctx, i_batch + k);
next_token = std::distance(logits, std::max_element(logits, logits + n_vocab)); next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
request.all_tokens.push_back(next_token); request.past_tokens.push_back(next_token);
const auto token_str = llama_token_to_piece(ctx, next_token); const auto token_str = llama_token_to_piece(ctx, next_token);
request.generated_text += token_str; request.generated_text += token_str;
request.n_past += 1;
// FIXME: Hack for codellama to simplify tabby's implementation. // FIXME: Hack for codellama to simplify tabby's implementation.
const bool is_eos = next_token == eos_id || token_str == " <EOT>"; const bool is_eos = next_token == eos_id || token_str == " <EOT>";
@ -285,7 +287,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
} }
if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) { if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1); llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past(), -1);
break; break;
} }
} }