add-token-draft-v2
Meng Zhang 2023-11-29 16:53:08 +08:00
parent 9c905e4849
commit 8c0afa458c
1 changed file with 3 additions and 3 deletions

View File

@ -181,7 +181,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
batch_.n_tokens += request.tokens.size(); batch_.n_tokens += request.tokens.size();
batch_.logits[batch_.n_tokens - 1] = true; batch_.logits[batch_.n_tokens - 1] = true;
request.i_batch = batch_.n_tokens - 1; request.i_batch = batch_.n_tokens;
} }
rust::Vec<StepOutput> result; rust::Vec<StepOutput> result;
@ -208,11 +208,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
const auto eos_id = llama_token_eos(llama_get_model(ctx)); const auto eos_id = llama_token_eos(llama_get_model(ctx));
for (auto& request : requests_) { for (auto& request : requests_) {
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) { int32_t i_batch = request.i_batch - i - 1;
if ((i_batch < 0) || (i_batch >= n_tokens)) {
continue; continue;
} }
int32_t i_batch = request.i_batch - i;
auto logits = llama_get_logits_ith(ctx, i_batch); auto logits = llama_get_logits_ith(ctx, i_batch);
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab)); auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));