update
parent
9c905e4849
commit
8c0afa458c
|
|
@ -181,7 +181,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
batch_.n_tokens += request.tokens.size();
|
||||
|
||||
batch_.logits[batch_.n_tokens - 1] = true;
|
||||
request.i_batch = batch_.n_tokens - 1;
|
||||
request.i_batch = batch_.n_tokens;
|
||||
}
|
||||
|
||||
rust::Vec<StepOutput> result;
|
||||
|
|
@ -208,11 +208,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
|
||||
const auto eos_id = llama_token_eos(llama_get_model(ctx));
|
||||
for (auto& request : requests_) {
|
||||
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
|
||||
int32_t i_batch = request.i_batch - i - 1;
|
||||
if ((i_batch < 0) || (i_batch >= n_tokens)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int32_t i_batch = request.i_batch - i;
|
||||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue