From 8c0afa458c99129e485f05904309c746df622fa0 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Wed, 29 Nov 2023 16:53:08 +0800 Subject: [PATCH] update --- crates/llama-cpp-bindings/src/engine.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index 2a762b5..9ac9b9a 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -181,7 +181,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine { batch_.n_tokens += request.tokens.size(); batch_.logits[batch_.n_tokens - 1] = true; - request.i_batch = batch_.n_tokens - 1; + request.i_batch = batch_.n_tokens; } rust::Vec result; @@ -208,11 +208,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine { const auto eos_id = llama_token_eos(llama_get_model(ctx)); for (auto& request : requests_) { - if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) { + int32_t i_batch = request.i_batch - i - 1; + if ((i_batch < 0) || (i_batch >= n_tokens)) { continue; } - int32_t i_batch = request.i_batch - i; auto logits = llama_get_logits_ith(ctx, i_batch); auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));