temp
parent
a42fde18ac
commit
8c770c6404
|
|
@ -29,6 +29,7 @@ struct Request {
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
size_t i_batch = -1;
|
size_t i_batch = -1;
|
||||||
size_t n_past = 0;
|
size_t n_past = 0;
|
||||||
|
size_t n_draft = 0;
|
||||||
|
|
||||||
int32_t multibyte_pending = 0;
|
int32_t multibyte_pending = 0;
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
|
|
@ -213,6 +214,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME: ensure the batching logic always puts i_batch - request.n_draft in this batch.
|
||||||
|
for (int k = -request.n_draft; k < 1; ++k) {
|
||||||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||||
|
|
||||||
|
|
@ -258,6 +261,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
result.push_back({request.id, generated_text});
|
result.push_back({request.id, generated_text});
|
||||||
request.generated_text.clear();
|
request.generated_text.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (k < 0 && next_token != request.tokens[request.tokens.size() + k]) {
|
||||||
|
// FIXME: shift kv cache
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue