update
parent ce64207ad7
commit b2c4635ced
@@ -16,6 +16,9 @@ TextInferenceEngine::~TextInferenceEngine() {}
 namespace {
 constexpr size_t N_BATCH = 512; // # per batch inference.
 constexpr size_t N_CTX = 4096;  // # max kv history.
 
+constexpr size_t DRAFT_N_GRAM_SIZE = 3;
+constexpr size_t DRAFT_N_PRED_TOKENS = 10;
+
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) :
@@ -34,7 +37,29 @@ struct Request {
   std::string generated_text;
 
-  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 8) {
+  void draft_tokens(int n_draft_quota) {
+    if (n_draft_quota < DRAFT_N_PRED_TOKENS) {
+      n_draft = 0;
+      return;
+    }
+
+    auto draft = find_candidate_pred_tokens(DRAFT_N_GRAM_SIZE, DRAFT_N_PRED_TOKENS);
+    n_draft = draft.size();
+    tokens.insert(tokens.end(), draft.begin(), draft.end());
+  }
+
+  size_t n_past() {
+    return past_tokens.size();
+  }
+
+  void step(llama_token next_token, size_t n_dropped) {
+    past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
+    tokens.clear();
+    tokens.push_back(next_token);
+  }
+
+ private:
+  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size, size_t n_pred_tokens) {
     if (past_tokens.size() < ngram_size) return {};
     std::vector<llama_token> ngram(past_tokens.begin() + past_tokens.size() - ngram_size, past_tokens.end());
 
 
@@ -46,17 +71,6 @@ struct Request {
     return std::vector<llama_token>(begin, begin + n_pred_tokens);
   }
 
-  size_t n_past() {
-    return past_tokens.size();
-  }
-
-  void step(llama_token next_token, size_t n_dropped) {
-    past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.begin() + tokens.size() - n_dropped);
-    tokens.clear();
-    tokens.push_back(next_token);
-  }
-
- private:
   std::vector<llama_token> past_tokens;
 };
 
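The matching loop of find_candidate_pred_tokens sits in the unchanged lines between the two hunks above, so it is not visible here. A minimal standalone sketch of the same n-gram prompt-lookup idea, using plain int token ids and a hypothetical lookup_draft helper rather than the engine's types:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Propose draft tokens by finding the most recent earlier occurrence of the
    // last `ngram_size` tokens of the history and returning up to
    // `n_pred_tokens` tokens that followed that occurrence.
    std::vector<int> lookup_draft(const std::vector<int>& history,
                                  size_t ngram_size, size_t n_pred_tokens) {
      if (history.size() < ngram_size) return {};
      const auto ngram_begin = history.end() - ngram_size;
      // Scan candidate match positions from most recent to oldest.
      for (size_t start = history.size() - ngram_size; start-- > 0;) {
        if (std::equal(ngram_begin, history.end(), history.begin() + start)) {
          const size_t cont = start + ngram_size;  // first token after the match
          const size_t n = std::min(n_pred_tokens, history.size() - cont);
          return std::vector<int>(history.begin() + cont, history.begin() + cont + n);
        }
      }
      return {};
    }

This sketch clamps the continuation to what the history actually contains; the context line above simply slices begin + n_pred_tokens.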
@@ -197,6 +211,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     // Insert tokens from ongoing requests to batch.
     for (auto& request : requests_) {
       const size_t n_tokens = batch_.n_tokens;
+
+      // Ensure the draft logits always fall into the same batch.
+      const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
+      request.draft_tokens(n_draft_quota);
+
       for (size_t i = 0; i < request.tokens.size(); ++i) {
         batch_.token[n_tokens + i] = request.tokens[i];
         batch_.pos[n_tokens + i] = request.n_past() + i;
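The quota keeps a request's pending tokens plus its draft from straddling an N_BATCH boundary, so all of the draft logits come back from the same decode call (this is what the FIXME removed in the next hunk asked for). A tiny worked example of the arithmetic, with hypothetical counts:

    #include <cstdio>

    int main() {
      constexpr int N_BATCH = 512;
      constexpr int DRAFT_N_PRED_TOKENS = 10;
      const int n_tokens = 500;      // tokens already queued in the batch (hypothetical)
      const int request_tokens = 8;  // this request's pending tokens (hypothetical)
      // 512 - (500 + 8) % 512 = 4 slots left before the batch boundary.
      const int n_draft_quota = N_BATCH - (n_tokens + request_tokens) % N_BATCH;
      // 4 < 10, so draft_tokens() would set n_draft = 0 and skip drafting this step.
      std::printf("quota = %d\n", n_draft_quota);
      return 0;
    }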
@@ -241,12 +260,9 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
         continue;
       }
 
-      llama_token next_token = -1;
-      int k = -request.n_draft;
-      // FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
-      for (k = -request.n_draft; k < 1; ++k) {
+      for (int k = -request.n_draft; k < 1; ++k) {
         auto logits = llama_get_logits_ith(ctx, i_batch + k);
-        next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
+        llama_token next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
 
         const auto token_str = llama_token_to_piece(ctx, next_token);
         request.generated_text += token_str;
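The loop above reads one logits row per draft position (i_batch - n_draft through i_batch); the acceptance test for the drafted tokens is not part of this hunk. A heavily hedged sketch of greedy draft verification in the same spirit, assuming hypothetical drafted/argmax vectors:

    #include <cstddef>
    #include <vector>

    // drafted[j] is the j-th draft token; argmax[j] is the model's greedy pick at
    // the position just before drafted[j]. A draft token survives only if every
    // earlier draft token survived and it matches the greedy pick.
    size_t count_accepted(const std::vector<int>& drafted,
                          const std::vector<int>& argmax) {
      size_t accepted = 0;
      while (accepted < drafted.size() && drafted[accepted] == argmax[accepted]) {
        ++accepted;
      }
      return accepted;  // drafted.size() - accepted would be the n_dropped passed to step()
    }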
@@ -296,10 +312,6 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           break;
         }
       }
 
-      auto draft_tokens = request.find_candidate_pred_tokens();
-      request.n_draft = draft_tokens.size();
-      request.tokens.insert(request.tokens.end(), draft_tokens.begin(), draft_tokens.end());
-
     }
   }
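With the old call site removed, drafting only happens while the batch is being filled (hunk at -197 above). A rough end-to-end sketch of one request's bookkeeping across a decode step, with hypothetical token ids and the same step() semantics as the struct above:

    #include <cstddef>
    #include <vector>

    struct MiniRequest {
      std::vector<int> tokens;       // queued for the next decode (sampled token + drafts)
      std::vector<int> past_tokens;  // already committed to the history

      void step(int next_token, size_t n_dropped) {
        // Commit everything except the rejected tail of the draft, then queue
        // the freshly sampled token for the next decode.
        past_tokens.insert(past_tokens.end(), tokens.begin(), tokens.end() - n_dropped);
        tokens.clear();
        tokens.push_back(next_token);
      }
    };

    int main() {
      MiniRequest r;
      r.past_tokens = {1, 2, 3, 4};
      r.tokens = {5, 6, 7};                       // last sampled token plus two draft tokens
      r.step(/*next_token=*/8, /*n_dropped=*/1);  // one draft token was rejected
      // past_tokens is now {1, 2, 3, 4, 5, 6}; tokens is {8}.
      return 0;
    }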