fix: error when outputting Unicode characters (#925)
parent
79e704458d
commit
2c2c95ccd7
|
|
@ -17,7 +17,6 @@ TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t N_BATCH = 512; // # per batch inference.
|
constexpr size_t N_BATCH = 512; // # per batch inference.
|
||||||
constexpr size_t N_CTX = 4096; // # max kv history.
|
constexpr size_t N_CTX = 4096; // # max kv history.
|
||||||
|
|
||||||
struct Request {
|
struct Request {
|
||||||
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
||||||
id(request_id),
|
id(request_id),
|
||||||
|
|
@ -31,7 +30,6 @@ struct Request {
|
||||||
size_t i_batch = -1;
|
size_t i_batch = -1;
|
||||||
size_t n_past = 0;
|
size_t n_past = 0;
|
||||||
|
|
||||||
int32_t multibyte_pending = 0;
|
|
||||||
std::string generated_text;
|
std::string generated_text;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -230,27 +228,35 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
// FIXME: Hack for codellama to simplify tabby's implementation.
|
// FIXME: Hack for codellama to simplify tabby's implementation.
|
||||||
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
|
const bool is_eos = next_token == eos_id || token_str == " <EOT>";
|
||||||
|
|
||||||
if (request.multibyte_pending > 0) {
|
bool incomplete = false;
|
||||||
request.multibyte_pending -= token_str.size();
|
for (size_t i = 1; i < 5 && i <= request.generated_text.size(); ++i)
|
||||||
} else if (token_str.size() == 1) {
|
{
|
||||||
const char c = token_str[0];
|
const char c = request.generated_text[request.generated_text.size() - i];
|
||||||
// 2-byte characters: 110xxxxx 10xxxxxx
|
if ((c & 0xC0) == 0x80)
|
||||||
if ((c & 0xE0) == 0xC0) {
|
{
|
||||||
request.multibyte_pending = 1;
|
// continuation byte: 10xxxxxx
|
||||||
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx
|
continue;
|
||||||
}
|
}
|
||||||
else if ((c & 0xF0) == 0xE0) {
|
else if ((c & 0xE0) == 0xC0)
|
||||||
request.multibyte_pending = 2;
|
{
|
||||||
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
|
// 2-byte character: 110xxxxx ...
|
||||||
} else if ((c & 0xF8) == 0xF0) {
|
incomplete = i < 2;
|
||||||
request.multibyte_pending = 3;
|
}
|
||||||
}
|
else if ((c & 0xF0) == 0xE0)
|
||||||
else {
|
{
|
||||||
request.multibyte_pending = 0;
|
// 3-byte character: 1110xxxx ...
|
||||||
}
|
incomplete = i < 3;
|
||||||
|
}
|
||||||
|
else if ((c & 0xF8) == 0xF0)
|
||||||
|
{
|
||||||
|
// 4-byte character: 11110xxx ...
|
||||||
|
incomplete = i < 4;
|
||||||
|
}
|
||||||
|
// else 1-byte character or invalid byte
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (request.multibyte_pending == 0) {
|
if (!incomplete) {
|
||||||
rust::String generated_text;
|
rust::String generated_text;
|
||||||
try {
|
try {
|
||||||
generated_text = is_eos ? "" : request.generated_text;
|
generated_text = is_eos ? "" : request.generated_text;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue