temp
parent
a42fde18ac
commit
8c770c6404
|
|
@ -29,6 +29,7 @@ struct Request {
|
|||
std::vector<llama_token> tokens;
|
||||
size_t i_batch = -1;
|
||||
size_t n_past = 0;
|
||||
size_t n_draft = 0;
|
||||
|
||||
int32_t multibyte_pending = 0;
|
||||
std::string generated_text;
|
||||
|
|
@ -213,50 +214,58 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
// FIXME: ensure batching logic always put i_batch - request.n_draft in this batch.
|
||||
for (int k = -request.n_draft; k < 1; ++k) {
|
||||
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
|
||||
request.n_past += request.tokens.size();
|
||||
request.n_past += request.tokens.size();
|
||||
|
||||
request.tokens.clear();
|
||||
request.tokens.push_back(next_token);
|
||||
request.tokens.clear();
|
||||
request.tokens.push_back(next_token);
|
||||
|
||||
const auto token_str = llama_token_to_piece(ctx, next_token);
|
||||
request.generated_text += token_str;
|
||||
const auto token_str = llama_token_to_piece(ctx, next_token);
|
||||
request.generated_text += token_str;
|
||||
|
||||
// FIXME: Hack for codellama to simplify tabby's implementation.
|
||||
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
|
||||
// FIXME: Hack for codellama to simplify tabby's implementation.
|
||||
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
|
||||
|
||||
if (request.multibyte_pending > 0) {
|
||||
request.multibyte_pending -= token_str.size();
|
||||
} else if (token_str.size() == 1) {
|
||||
const char c = token_str[0];
|
||||
// 2-byte characters: 110xxxxx 10xxxxxx
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
request.multibyte_pending = 1;
|
||||
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF0) == 0xE0) {
|
||||
request.multibyte_pending = 2;
|
||||
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
request.multibyte_pending = 3;
|
||||
}
|
||||
else {
|
||||
request.multibyte_pending = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (request.multibyte_pending == 0) {
|
||||
rust::String generated_text;
|
||||
try {
|
||||
generated_text = is_eos ? "" : request.generated_text;
|
||||
} catch (const std::invalid_argument& e) {
|
||||
fprintf(stderr, "%s:%d [%s] - ignoring non utf-8/utf-16 output\n", __FILE__, __LINE__, __func__);
|
||||
if (request.multibyte_pending > 0) {
|
||||
request.multibyte_pending -= token_str.size();
|
||||
} else if (token_str.size() == 1) {
|
||||
const char c = token_str[0];
|
||||
// 2-byte characters: 110xxxxx 10xxxxxx
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
request.multibyte_pending = 1;
|
||||
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
|
||||
}
|
||||
else if ((c & 0xF0) == 0xE0) {
|
||||
request.multibyte_pending = 2;
|
||||
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
request.multibyte_pending = 3;
|
||||
}
|
||||
else {
|
||||
request.multibyte_pending = 0;
|
||||
}
|
||||
}
|
||||
|
||||
result.push_back({request.id, generated_text});
|
||||
request.generated_text.clear();
|
||||
if (request.multibyte_pending == 0) {
|
||||
rust::String generated_text;
|
||||
try {
|
||||
generated_text = is_eos ? "" : request.generated_text;
|
||||
} catch (const std::invalid_argument& e) {
|
||||
fprintf(stderr, "%s:%d [%s] - ignoring non utf-8/utf-16 output\n", __FILE__, __LINE__, __func__);
|
||||
}
|
||||
|
||||
result.push_back({request.id, generated_text});
|
||||
request.generated_text.clear();
|
||||
}
|
||||
|
||||
if (k < 0 && next_token != request.tokens[request.tokens.size() + k]) {
|
||||
// FIXME: shift kv cache
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue