diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index 7b3f603..468584b 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -17,7 +17,6 @@ TextInferenceEngine::~TextInferenceEngine() {} namespace { constexpr size_t N_BATCH = 512; // # per batch inference. constexpr size_t N_CTX = 4096; // # max kv history. - struct Request { Request(size_t request_id, std::vector input_token_ids) : id(request_id), @@ -31,7 +30,6 @@ struct Request { size_t i_batch = -1; size_t n_past = 0; - int32_t multibyte_pending = 0; std::string generated_text; }; @@ -230,27 +228,35 @@ class TextInferenceEngineImpl : public TextInferenceEngine { // FIXME: Hack for codellama to simplify tabby's implementation. const bool is_eos = next_token == eos_id || token_str == " "; - if (request.multibyte_pending > 0) { - request.multibyte_pending -= token_str.size(); - } else if (token_str.size() == 1) { - const char c = token_str[0]; - // 2-byte characters: 110xxxxx 10xxxxxx - if ((c & 0xE0) == 0xC0) { - request.multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - } - else if ((c & 0xF0) == 0xE0) { - request.multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - } else if ((c & 0xF8) == 0xF0) { - request.multibyte_pending = 3; - } - else { - request.multibyte_pending = 0; - } + bool incomplete = false; + for (size_t i = 1; i < 5 && i <= request.generated_text.size(); ++i) + { + const char c = request.generated_text[request.generated_text.size() - i]; + if ((c & 0xC0) == 0x80) + { + // continuation byte: 10xxxxxx + continue; + } + else if ((c & 0xE0) == 0xC0) + { + // 2-byte character: 110xxxxx ... + incomplete = i < 2; + } + else if ((c & 0xF0) == 0xE0) + { + // 3-byte character: 1110xxxx ... + incomplete = i < 3; + } + else if ((c & 0xF8) == 0xF0) + { + // 4-byte character: 11110xxx ... + incomplete = i < 4; + } + // else 1-byte character or invalid byte + break; } - if (request.multibyte_pending == 0) { + if (!incomplete) { rust::String generated_text; try { generated_text = is_eos ? "" : request.generated_text;