fix: output unicode characters error (#925)

add-signin-page
xcnick 2023-12-01 12:18:26 +08:00 committed by GitHub
parent 79e704458d
commit 2c2c95ccd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 27 additions and 21 deletions

View File

@ -17,7 +17,6 @@ TextInferenceEngine::~TextInferenceEngine() {}
namespace { namespace {
constexpr size_t N_BATCH = 512; // # per batch inference. constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history. constexpr size_t N_CTX = 4096; // # max kv history.
struct Request { struct Request {
Request(size_t request_id, std::vector<llama_token> input_token_ids) : Request(size_t request_id, std::vector<llama_token> input_token_ids) :
id(request_id), id(request_id),
@ -31,7 +30,6 @@ struct Request {
size_t i_batch = -1; size_t i_batch = -1;
size_t n_past = 0; size_t n_past = 0;
int32_t multibyte_pending = 0;
std::string generated_text; std::string generated_text;
}; };
@ -230,27 +228,35 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
// FIXME: Hack for codellama to simplify tabby's implementation. // FIXME: Hack for codellama to simplify tabby's implementation.
const bool is_eos = next_token == eos_id || token_str == " <EOT>"; const bool is_eos = next_token == eos_id || token_str == " <EOT>";
if (request.multibyte_pending > 0) { bool incomplete = false;
request.multibyte_pending -= token_str.size(); for (size_t i = 1; i < 5 && i <= request.generated_text.size(); ++i)
} else if (token_str.size() == 1) { {
const char c = token_str[0]; const char c = request.generated_text[request.generated_text.size() - i];
// 2-byte characters: 110xxxxx 10xxxxxx if ((c & 0xC0) == 0x80)
if ((c & 0xE0) == 0xC0) { {
request.multibyte_pending = 1; // continuation byte: 10xxxxxx
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx continue;
} }
else if ((c & 0xF0) == 0xE0) { else if ((c & 0xE0) == 0xC0)
request.multibyte_pending = 2; {
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx // 2-byte character: 110xxxxx ...
} else if ((c & 0xF8) == 0xF0) { incomplete = i < 2;
request.multibyte_pending = 3; }
} else if ((c & 0xF0) == 0xE0)
else { {
request.multibyte_pending = 0; // 3-byte character: 1110xxxx ...
} incomplete = i < 3;
}
else if ((c & 0xF8) == 0xF0)
{
// 4-byte character: 11110xxx ...
incomplete = i < 4;
}
// else 1-byte character or invalid byte
break;
} }
if (request.multibyte_pending == 0) { if (!incomplete) {
rust::String generated_text; rust::String generated_text;
try { try {
generated_text = is_eos ? "" : request.generated_text; generated_text = is_eos ? "" : request.generated_text;