diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index 30014f4..3c2e980 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -29,6 +29,7 @@ struct Request { std::vector tokens; size_t i_batch = -1; size_t n_past = 0; + size_t n_draft = 0; int32_t multibyte_pending = 0; std::string generated_text; @@ -213,50 +214,58 @@ class TextInferenceEngineImpl : public TextInferenceEngine { continue; } - auto logits = llama_get_logits_ith(ctx, i_batch); - auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab)); + // FIXME: ensure batching logic always put i_batch - request.n_draft in this batch. + for (int k = -request.n_draft; k < 1; ++k) { + auto logits = llama_get_logits_ith(ctx, i_batch); + auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab)); - request.n_past += request.tokens.size(); + request.n_past += request.tokens.size(); - request.tokens.clear(); - request.tokens.push_back(next_token); + request.tokens.clear(); + request.tokens.push_back(next_token); - const auto token_str = llama_token_to_piece(ctx, next_token); - request.generated_text += token_str; + const auto token_str = llama_token_to_piece(ctx, next_token); + request.generated_text += token_str; - // FIXME: Hack for codellama to simplify tabby's implementation. - const bool is_eos = next_token == eos_id || token_str == " "; + // FIXME: Hack for codellama to simplify tabby's implementation. + const bool is_eos = next_token == eos_id || token_str == " "; - if (request.multibyte_pending > 0) { - request.multibyte_pending -= token_str.size(); - } else if (token_str.size() == 1) { - const char c = token_str[0]; - // 2-byte characters: 110xxxxx 10xxxxxx - if ((c & 0xE0) == 0xC0) { - request.multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - } - else if ((c & 0xF0) == 0xE0) { - request.multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - } else if ((c & 0xF8) == 0xF0) { - request.multibyte_pending = 3; - } - else { - request.multibyte_pending = 0; - } - } - - if (request.multibyte_pending == 0) { - rust::String generated_text; - try { - generated_text = is_eos ? "" : request.generated_text; - } catch (const std::invalid_argument& e) { - fprintf(stderr, "%s:%d [%s] - ignoring non utf-8/utf-16 output\n", __FILE__, __LINE__, __func__); + if (request.multibyte_pending > 0) { + request.multibyte_pending -= token_str.size(); + } else if (token_str.size() == 1) { + const char c = token_str[0]; + // 2-byte characters: 110xxxxx 10xxxxxx + if ((c & 0xE0) == 0xC0) { + request.multibyte_pending = 1; + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + } + else if ((c & 0xF0) == 0xE0) { + request.multibyte_pending = 2; + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + } else if ((c & 0xF8) == 0xF0) { + request.multibyte_pending = 3; + } + else { + request.multibyte_pending = 0; + } } - result.push_back({request.id, generated_text}); - request.generated_text.clear(); + if (request.multibyte_pending == 0) { + rust::String generated_text; + try { + generated_text = is_eos ? "" : request.generated_text; + } catch (const std::invalid_argument& e) { + fprintf(stderr, "%s:%d [%s] - ignoring non utf-8/utf-16 output\n", __FILE__, __LINE__, __func__); + } + + result.push_back({request.id, generated_text}); + request.generated_text.clear(); + } + + if (k < 0 && next_token != request.tokens[request.tokens.size() + k]) { + // FIXME: shift kv cache + break; + } } } }