parent
9c905e4849
commit
49864f98c1
|
|
@ -20,20 +20,58 @@ constexpr size_t N_CTX = 4096; // # max kv history.
|
|||
struct Request {
|
||||
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
||||
id(request_id),
|
||||
tokens(input_token_ids.begin(), input_token_ids.end()) {
|
||||
}
|
||||
pending_tokens(input_token_ids.begin(), input_token_ids.end()) {
|
||||
}
|
||||
|
||||
uint32_t id = -1;
|
||||
llama_seq_id seq_id = -1;
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
std::vector<llama_token> pending_tokens;
|
||||
size_t i_batch = -1;
|
||||
size_t n_past = 0;
|
||||
|
||||
int32_t multibyte_pending = 0;
|
||||
std::string generated_text;
|
||||
};
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
// Commits the tokens that were just decoded and queues the newly sampled token.
//
// `pending_tokens` holds what was submitted in the most recent batch: the whole prompt
// on the first step (the constructor seeds it from the input tokens) and a single token
// afterwards (this method leaves exactly one behind). Those tokens are appended to
// `tokens`, `n_past` is advanced by their count, and `id` becomes the only pending token.
//
// `id` is the token selected for this request in the current decode step.
void step(llama_token id) {
  // Advance by the number of tokens decoded rather than by one: the batching code places
  // the i-th pending token at position `n_past + i`, so after a multi-token prompt the
  // next token must start right after the last prompt position.
  n_past += pending_tokens.size();

  tokens.insert(tokens.end(), pending_tokens.begin(), pending_tokens.end());

  pending_tokens.clear();
  pending_tokens.push_back(id);
}
|
||||
|
||||
std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
|
||||
for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
|
||||
if (tokens.size() < ngram_size) continue;
|
||||
std::vector<llama_token> ngram(tokens.begin() + tokens.size() - ngram_size, tokens.end());
|
||||
|
||||
const int matched = find_ngram(ngram, n_pred_tokens);
|
||||
if (matched < 0) continue;
|
||||
|
||||
const int offset = matched + ngram_size;
|
||||
return std::vector<llama_token>(tokens.begin() + offset, tokens.begin() + offset + n_pred_tokens);
|
||||
}
|
||||
|
||||
return std::vector<llama_token>();
|
||||
}
|
||||
|
||||
private:
|
||||
int find_ngram(const std::vector<llama_token> & ngram, size_t n_pred_tokens) {
|
||||
const int max = static_cast<int>(tokens.size()) - ngram.size() - n_pred_tokens;
|
||||
for (int i = 0; i < max; ++i) {
|
||||
const auto mismatch = std::mismatch(tokens.begin() + i, tokens.begin() + i + ngram.size(), ngram.begin());
|
||||
if (mismatch.second == ngram.end()) {
|
||||
// Matched
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
||||
std::vector<char> result(8, 0);
|
||||
|
|
@ -54,7 +92,7 @@ std::vector<llama_token> llama_tokenize(
|
|||
const rust::Str & text,
|
||||
bool add_bos,
|
||||
bool special) {
|
||||
// upper limit for the number of tokens
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + add_bos;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
|
||||
|
|
@ -113,12 +151,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
}
|
||||
|
||||
virtual void add_request(uint32_t request_id, rust::Str text, size_t max_input_length) override {
|
||||
auto tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
|
||||
if (tokens.size() > max_input_length) {
|
||||
int start = tokens.size() - max_input_length;
|
||||
tokens = std::vector<llama_token>(tokens.begin() + start, tokens.end());
|
||||
auto pending_tokens = llama_tokenize(llama_get_model(ctx_.get()), text, false, true);
|
||||
if (pending_tokens.size() > max_input_length) {
|
||||
int start = pending_tokens.size() - max_input_length;
|
||||
pending_tokens = std::vector<llama_token>(pending_tokens.begin() + start, pending_tokens.end());
|
||||
}
|
||||
pending_requests_.push_back(Request(request_id, tokens));
|
||||
pending_requests_.push_back(Request(request_id, pending_tokens));
|
||||
}
|
||||
|
||||
void stop_request(uint32_t request_id) override {
|
||||
|
|
@ -168,17 +206,17 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
// Clear the batch.
|
||||
batch_.n_tokens = 0;
|
||||
|
||||
// Insert tokens from ongoing requests to batch.
|
||||
// Insert pending_tokens from ongoing requests to batch.
|
||||
for (auto& request : requests_) {
|
||||
const size_t n_tokens = batch_.n_tokens;
|
||||
for (size_t i = 0; i < request.tokens.size(); ++i) {
|
||||
batch_.token[n_tokens + i] = request.tokens[i];
|
||||
for (size_t i = 0; i < request.pending_tokens.size(); ++i) {
|
||||
batch_.token[n_tokens + i] = request.pending_tokens[i];
|
||||
batch_.pos[n_tokens + i] = request.n_past + i;
|
||||
batch_.n_seq_id[n_tokens + i] = 1;
|
||||
batch_.seq_id[n_tokens + i][0] = request.id;
|
||||
batch_.logits[n_tokens + i] = false;
|
||||
}
|
||||
batch_.n_tokens += request.tokens.size();
|
||||
batch_.n_tokens += request.pending_tokens.size();
|
||||
|
||||
batch_.logits[batch_.n_tokens - 1] = true;
|
||||
request.i_batch = batch_.n_tokens - 1;
|
||||
|
|
@ -187,7 +225,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
rust::Vec<StepOutput> result;
|
||||
result.reserve(requests_.size());
|
||||
|
||||
// Decode tokens in chunks
|
||||
// Decode pending_tokens in chunks
|
||||
for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
|
||||
const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
|
||||
llama_batch batch_view = {
|
||||
|
|
@ -216,10 +254,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
|
||||
request.n_past += request.tokens.size();
|
||||
|
||||
request.tokens.clear();
|
||||
request.tokens.push_back(next_token);
|
||||
request.step(next_token);
|
||||
|
||||
const auto token_str = llama_token_to_piece(ctx, next_token);
|
||||
request.generated_text += token_str;
|
||||
|
|
|
|||
Loading…
Reference in New Issue