fix ngram build
parent
b2c4635ced
commit
de96d1b6af
|
|
@ -54,14 +54,16 @@ struct Request {
|
||||||
|
|
||||||
void step(llama_token next_token, size_t n_dropped) {
|
void step(llama_token next_token, size_t n_dropped) {
|
||||||
past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
|
past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
|
||||||
|
|
||||||
|
n_draft = 0;
|
||||||
tokens.clear();
|
tokens.clear();
|
||||||
tokens.push_back(next_token);
|
tokens.push_back(next_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
|
std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
|
||||||
if (past_tokens.size() < ngram_size) return {};
|
auto ngram = build_ngram(ngram_size);
|
||||||
std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
|
if (ngram.size() < ngram_size) return {};
|
||||||
|
|
||||||
const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
|
const auto end = past_tokens.end() - ngram_size - n_pred_tokens;
|
||||||
const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
|
const auto matched = std::search(past_tokens.begin(), end, ngram.begin(), ngram.end());
|
||||||
|
|
@ -71,6 +73,22 @@ struct Request {
|
||||||
return std::vector<llama_token>(begin, begin + n_pred_tokens);
|
return std::vector<llama_token>(begin, begin + n_pred_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> build_ngram(size_t ngram_size) {
|
||||||
|
GGML_ASSERT(n_draft == 0);
|
||||||
|
std::deque<llama_token> ret;
|
||||||
|
for (int i = tokens.size() - 1; i >= 0; --i) {
|
||||||
|
if (ret.size() == ngram_size) break;
|
||||||
|
ret.push_front(tokens[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = past_tokens.size() - 1; i >= 0; --i) {
|
||||||
|
if (ret.size() == ngram_size) break;
|
||||||
|
ret.push_front(past_tokens[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::vector<llama_token>(ret.begin(), ret.end());
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<llama_token> past_tokens;
|
std::vector<llama_token> past_tokens;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue