implement find_candidate_pred_tokens

fix

update

update
add-prompt-lookup
Meng Zhang 2023-11-29 13:45:55 +08:00
parent 9c905e4849
commit 49864f98c1
1 changed files with 54 additions and 19 deletions


@@ -20,20 +20,58 @@ constexpr size_t N_CTX = 4096; // # max kv history.
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) :
     id(request_id),
-    tokens(input_token_ids.begin(), input_token_ids.end()) {
+    pending_tokens(input_token_ids.begin(), input_token_ids.end()) {
   }
 
   uint32_t id = -1;
   llama_seq_id seq_id = -1;
 
-  std::vector<llama_token> tokens;
+  std::vector<llama_token> pending_tokens;
   size_t i_batch = -1;
   size_t n_past = 0;
 
   int32_t multibyte_pending = 0;
   std::string generated_text;
-};
+
+  std::vector<llama_token> tokens;
+
+  void step(llama_token id) {
+    ++n_past;
+    tokens.insert(tokens.end(), pending_tokens.begin(), pending_tokens.end());
+    pending_tokens.clear();
+    pending_tokens.push_back(id);
+  }
+
+  std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
+    for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
+      if (tokens.size() < ngram_size) continue;
+
+      std::vector<llama_token> ngram(tokens.begin() + tokens.size() - ngram_size, tokens.end());
+      const int matched = find_ngram(ngram, n_pred_tokens);
+      if (matched < 0) continue;
+
+      const int offset = matched + ngram_size;
+      return std::vector<llama_token>(tokens.begin() + offset, tokens.begin() + offset + n_pred_tokens);
+    }
+
+    return std::vector<llama_token>();
+  }
+
+ private:
+  int find_ngram(const std::vector<llama_token> & ngram, size_t n_pred_tokens) {
+    const int max = static_cast<int>(tokens.size()) - ngram.size() - n_pred_tokens;
+    for (int i = 0; i < max; ++i) {
+      const auto mismatch = std::mismatch(tokens.begin() + i, tokens.begin() + i + ngram.size(), ngram.begin());
+      if (mismatch.second == ngram.end()) {
+        // Matched
+        return i;
+      }
+    }
+    return -1;
+  }
+};
 
 std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
   std::vector<char> result(8, 0);
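
The pair of methods added above is the prompt-lookup trick: take the last max_ngram_size tokens of the history, search the same history for an earlier occurrence of that n-gram, and on a hit propose the n_pred_tokens tokens that followed it as speculative candidates, falling back to shorter n-grams when the longest one does not match. The following is a minimal, self-contained sketch of that idea for reference only; it is not part of the commit, uses plain ints in place of llama_token, and wraps a standalone find_ngram/find_candidate_pred_tokens pair modeled on the diff in a toy main() harness.

// Standalone sketch of prompt-lookup candidate prediction (illustrative only).
#include <algorithm>
#include <cstdio>
#include <vector>

using token = int;

// Search `tokens` for an earlier occurrence of `ngram`, leaving room for
// `n_pred` follow-up tokens after the match. Returns the match index or -1.
static int find_ngram(const std::vector<token>& tokens,
                      const std::vector<token>& ngram, size_t n_pred) {
  if (tokens.size() < ngram.size() + n_pred) return -1;
  const size_t last = tokens.size() - ngram.size() - n_pred;
  for (size_t i = 0; i < last; ++i) {
    if (std::equal(ngram.begin(), ngram.end(), tokens.begin() + i)) {
      return static_cast<int>(i);
    }
  }
  return -1;
}

// Try suffix n-grams from longest to shortest; on a hit, the tokens that
// followed the earlier occurrence become the speculative candidates.
static std::vector<token> find_candidate_pred_tokens(
    const std::vector<token>& tokens, size_t max_ngram = 3, size_t n_pred = 10) {
  for (size_t n = max_ngram; n > 0; --n) {
    if (tokens.size() < n) continue;
    std::vector<token> ngram(tokens.end() - n, tokens.end());
    const int at = find_ngram(tokens, ngram, n_pred);
    if (at < 0) continue;
    const size_t off = at + n;
    return std::vector<token>(tokens.begin() + off, tokens.begin() + off + n_pred);
  }
  return {};
}

int main() {
  // The history ends with the trigram {7, 8, 9}, which also appeared at the
  // start and was followed by {10, 11, ...}; those become the draft tokens.
  std::vector<token> history = {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                                1, 2, 3, 7, 8, 9};
  for (token t : find_candidate_pred_tokens(history, 3, 10)) std::printf("%d ", t);
  std::printf("\n");  // expected: 10 11 12 13 14 15 16 17 18 19
}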
@@ -54,7 +92,7 @@ std::vector<llama_token> llama_tokenize(
     const rust::Str & text,
     bool add_bos,
     bool special) {
-  // upper limit for the number of tokens
+  // upper limit for the number of pending_tokens
   int n_tokens = text.length() + add_bos;
   std::vector<llama_token> result(n_tokens);
   n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
@@ -113,12 +151,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   }
 
   virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
-    auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
-    if (tokens.size() > max_input_length) {
-      int start = tokens.size() - max_input_length;
-      tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
+    auto pending_tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
+    if (pending_tokens.size() > max_input_length) {
+      int start = pending_tokens.size() - max_input_length;
+      pending_tokens = std::vector<llama_token>(pending_tokens.begin() + start, pending_tokens.end());
     }
-    pending_requests_.push_back(Request(request_id, tokens));
+    pending_requests_.push_back(Request(request_id, pending_tokens));
   }
 
   void stop_request(uint32_t request_id) override {
@@ -168,17 +206,17 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     // Clear the batch.
     batch_.n_tokens = 0;
 
-    // Insert tokens from ongoing requests to batch.
+    // Insert pending_tokens from ongoing requests to batch.
     for (auto& request : requests_) {
       const size_t n_tokens = batch_.n_tokens;
-      for (size_t i = 0; i < request.tokens.size(); ++i) {
-        batch_.token[n_tokens + i] = request.tokens[i];
+      for (size_t i = 0; i < request.pending_tokens.size(); ++i) {
+        batch_.token[n_tokens + i] = request.pending_tokens[i];
         batch_.pos[n_tokens + i] = request.n_past + i;
         batch_.n_seq_id[n_tokens + i] = 1;
         batch_.seq_id[n_tokens + i][0] = request.id;
         batch_.logits[n_tokens + i] = false;
       }
-      batch_.n_tokens += request.tokens.size();
+      batch_.n_tokens += request.pending_tokens.size();
       batch_.logits[batch_.n_tokens - 1] = true;
       request.i_batch = batch_.n_tokens - 1;
@@ -187,7 +225,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     rust::Vec<StepOutput> result;
     result.reserve(requests_.size());
 
-    // Decode tokens in chunks
+    // Decode pending_tokens in chunks
    for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
       const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
       llama_batch batch_view = {
@@ -216,10 +254,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       auto logits = llama_get_logits_ith(ctx, i_batch);
       auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
 
-      request.n_past += request.tokens.size();
-      request.tokens.clear();
-      request.tokens.push_back(next_token);
+      request.step(next_token);
 
       const auto token_str = llama_token_to_piece(ctx, next_token);
       request.generated_text += token_str;
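
Taken together with step(), a Request now keeps two buffers: pending_tokens holds whatever still has to be fed through the model on the next decode call (the truncated prompt right after add_request, then just the freshly sampled token), while tokens accumulates everything already fed and is the history that find_candidate_pred_tokens searches. A rough trace of that bookkeeping for an assumed three-token prompt (illustrative only, not part of the diff):

// add_request(id, {p0, p1, p2})    pending_tokens = {p0, p1, p2}    tokens = {}
// decode, sample t0, step(t0)      pending_tokens = {t0}            tokens = {p0, p1, p2}
// decode, sample t1, step(t1)      pending_tokens = {t1}            tokens = {p0, p1, p2, t0}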