fix ngram build
parent
b2c4635ced
commit
de96d1b6af
|
|
@ -54,14 +54,16 @@ struct Request {
|
|||
|
||||
// Closes the current round and starts the next one with `next_token`.
//
// All entries of `tokens` except the trailing `n_dropped` are appended to
// `past_tokens`; afterwards `n_draft` is reset to 0 and `tokens` holds only
// `next_token`.
//
// Precondition: n_dropped <= tokens.size() (the iterator arithmetic below
// would step before `tokens.begin()` otherwise).
void step(llama_token next_token, size_t n_dropped) {
    // One-past-the-last entry of `tokens` that is kept.
    const auto kept_end = tokens.end() - n_dropped;
    past_tokens.insert(past_tokens.end(), tokens.begin(), kept_end);

    n_draft = 0;

    // Equivalent to clear() followed by push_back(next_token).
    tokens.assign(size_t{1}, next_token);
}
|
||||
|
||||
private:
|
||||
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
|
||||
if (past_tokens.size() < ngram_size) return {};
|
||||
std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
|
||||
auto ngram = build_ngram(ngram_size);
|
||||
if (ngram.size() < ngram_size) return {};
|
||||
|
||||
const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
|
||||
const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
|
||||
|
|
@ -71,6 +73,22 @@ struct Request {
|
|||
return std::vector<llama_token>(begin, begin + n_pred_tokens);
|
||||
}
|
||||
|
||||
std::vector<llama_token> build_ngram(size_t ngram_size) {
|
||||
GGML_ASSERT(n_draft == 0);
|
||||
std::deque<llama_token> ret;
|
||||
for (int i = tokens.size() - 1; i >= 0; --i) {
|
||||
if (ret.size() == ngram_size) break;
|
||||
ret.push_front(tokens[i]);
|
||||
}
|
||||
|
||||
for (int i = past_tokens.size() - 1; i >= 0; --i) {
|
||||
if (ret.size() == ngram_size) break;
|
||||
ret.push_front(past_tokens[i]);
|
||||
}
|
||||
|
||||
return std::vector<llama_token>(ret.begin(), ret.end());
|
||||
}
|
||||
|
||||
std::vector<llama_token> past_tokens;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue