fix ngram build

add-token-draft
Meng Zhang 2023-11-30 15:26:23 +08:00
parent b2c4635ced
commit de96d1b6af
1 changed files with 20 additions and 2 deletions

View File

@ -54,14 +54,16 @@ struct Request {
// Advances the request by one decoding step.
//
// Every entry of `tokens` except the trailing `n_dropped` ones is appended to
// `past_tokens`. Afterwards `n_draft` is reset to 0 and `tokens` is replaced by
// the single element `next_token`.
//
// next_token: the token that becomes the sole content of `tokens`.
// n_dropped:  number of trailing entries of `tokens` that are NOT carried over
//             into `past_tokens`; must not exceed `tokens.size()`.
void step(llama_token next_token, size_t n_dropped) {
    // `tokens.end() - n_dropped` would move before begin() (undefined
    // behaviour) if the caller asked to drop more than is currently held.
    GGML_ASSERT(n_dropped <= tokens.size());
    past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
    n_draft = 0;
    tokens.clear();
    tokens.push_back(next_token);
}
private:
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
if (past_tokens.size() < ngram_size) return {};
std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
auto ngram = build_ngram(ngram_size);
if (ngram.size() < ngram_size) return {};
const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
@ -71,6 +73,22 @@ struct Request {
return std::vector<llama_token>(begin, begin + n_pred_tokens);
}
std::vector<llama_token> build_ngram(size_t ngram_size) {
GGML_ASSERT(n_draft == 0);
std::deque<llama_token> ret;
for (int i = tokens.size() - 1; i >= 0; --i) {
if (ret.size() == ngram_size) break;
ret.push_front(tokens[i]);
}
for (int i = past_tokens.size() - 1; i >= 0; --i) {
if (ret.size() == ngram_size) break;
ret.push_front(past_tokens[i]);
}
return std::vector<llama_token>(ret.begin(), ret.end());
}
// Tokens carried over from `tokens` by step(); build_ngram() reads its tail
// when `tokens` alone cannot supply a full n-gram.
std::vector<llama_token> past_tokens;
};