update
parent
9c905e4849
commit
8c0afa458c
|
|
@ -181,7 +181,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
batch_.n_tokens += request.tokens.size();
|
batch_.n_tokens += request.tokens.size();
|
||||||
|
|
||||||
batch_.logits[batch_.n_tokens - 1] = true;
|
batch_.logits[batch_.n_tokens - 1] = true;
|
||||||
request.i_batch = batch_.n_tokens - 1;
|
request.i_batch = batch_.n_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
rust::Vec<StepOutput> result;
|
rust::Vec<StepOutput> result;
|
||||||
|
|
@ -208,11 +208,11 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
|
|
||||||
const auto eos_id = llama_token_eos(llama_get_model(ctx));
|
const auto eos_id = llama_token_eos(llama_get_model(ctx));
|
||||||
for (auto& request : requests_) {
|
for (auto& request : requests_) {
|
||||||
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
|
int32_t i_batch = request.i_batch - i - 1;
|
||||||
|
if ((i_batch < 0) || (i_batch >= n_tokens)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t i_batch = request.i_batch - i;
|
|
||||||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue