update
parent
8c770c6404
commit
30b42bf186
|
|
@ -33,6 +33,36 @@ struct Request {
|
||||||
|
|
||||||
int32_t multibyte_pending = 0;
|
int32_t multibyte_pending = 0;
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<llama_token> all_tokens;
|
||||||
|
std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
|
||||||
|
for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
|
||||||
|
if (all_tokens.size() < ngram_size) continue;
|
||||||
|
std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
|
||||||
|
|
||||||
|
const int matched = find_ngram(ngram, n_pred_tokens);
|
||||||
|
if (matched < 0) continue;
|
||||||
|
|
||||||
|
const int offset = matched + ngram_size;
|
||||||
|
return std::vector<llama_token>(all_tokens.begin() + offset, all_tokens.begin() + offset + n_pred_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::vector<llama_token>();
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
int find_ngram(const std::vector<llama_token> & ngram, size_t n_pred_tokens) {
|
||||||
|
const int max = static_cast<int>(all_tokens.size()) - ngram.size() - n_pred_tokens;
|
||||||
|
for (int i = 0; i < max; ++i) {
|
||||||
|
const auto mismatch = std::mismatch(all_tokens.begin() + i, all_tokens.begin() + i + ngram.size(), ngram.begin());
|
||||||
|
if (mismatch.second == ngram.end()) {
|
||||||
|
// Matched
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,7 +211,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
}
|
}
|
||||||
batch_.n_tokens += request.tokens.size();
|
batch_.n_tokens += request.tokens.size();
|
||||||
|
|
||||||
batch_.logits[batch_.n_tokens - 1] = true;
|
for (int k = batch_.n_tokens - request.n_draft - 1; k <= batch_.n_tokens - 1; ++ k) {
|
||||||
|
batch_.logits[k] = true;
|
||||||
|
}
|
||||||
request.i_batch = batch_.n_tokens;
|
request.i_batch = batch_.n_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -214,16 +246,17 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token next_token = -1;
|
||||||
|
size_t n_tokens = request.tokens.size() - request.n_draft;
|
||||||
|
request.all_tokens.insert(request.all_tokens.end(), request.tokens.begin(), request.tokens.begin() + n_tokens);
|
||||||
|
request.n_past += n_tokens;
|
||||||
|
|
||||||
// FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
|
// FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
|
||||||
for (int k = -request.n_draft; k < 1; ++k) {
|
for (int k = -request.n_draft; k < 1; ++k) {
|
||||||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
auto logits = llama_get_logits_ith(ctx, i_batch + k);
|
||||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||||
|
|
||||||
request.n_past += request.tokens.size();
|
|
||||||
|
|
||||||
request.tokens.clear();
|
|
||||||
request.tokens.push_back(next_token);
|
|
||||||
|
|
||||||
|
request.all_tokens.push_back(next_token);
|
||||||
const auto token_str = llama_token_to_piece(ctx, next_token);
|
const auto token_str = llama_token_to_piece(ctx, next_token);
|
||||||
request.generated_text += token_str;
|
request.generated_text += token_str;
|
||||||
|
|
||||||
|
|
@ -262,11 +295,20 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
request.generated_text.clear();
|
request.generated_text.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (k < 0 && next_token != request.tokens[request.tokens.size() + k]) {
|
if ((k < 0 && next_token != request.tokens[request.tokens.size() + k])) {
|
||||||
// FIXME: shift kv cache
|
llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request.n_past += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request.tokens.clear();
|
||||||
|
request.tokens.push_back(next_token);
|
||||||
|
|
||||||
|
auto draft_tokens = request.find_candidate_pred_tokens();
|
||||||
|
request.n_draft = draft_tokens.size();
|
||||||
|
request.tokens.insert(request.tokens.end(), draft_tokens.begin(), draft_tokens.end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue