add-token-draft-v2
Meng Zhang 2023-11-29 16:53:08 +08:00
parent 9c905e4849
commit 8c0afa458c
1 changed file with 3 additions and 3 deletions

View File

@@ -181,7 +181,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
batch_.n_tokens += request.tokens.size();
batch_.logits[batch_.n_tokens - 1] = true;
request.i_batch = batch_.n_tokens - 1;
request.i_batch = batch_.n_tokens;
}
rust::Vec<StepOutput> result;
@@ -208,11 +208,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
const auto eos_id = llama_token_eos(llama_get_model(ctx));
for (auto& request : requests_) {
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
int32_t i_batch = request.i_batch - i - 1;
if ((i_batch < 0) || (i_batch >= n_tokens)) {
continue;
}
int32_t i_batch = request.i_batch - i;
auto logits = llama_get_logits_ith(ctx, i_batch);
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));