From eb7ae96157d4d4d67d9c32eaa611cf0c8da7b892 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Mon, 6 Nov 2023 22:16:36 -0800
Subject: [PATCH] fix: llama.cpp requires kv cache to be N_CTX * parallelism
 (#714)

---
 crates/llama-cpp-bindings/src/engine.cc | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index ac9d8ea..fabec66 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -79,6 +79,17 @@ std::vector<llama_token> llama_tokenize(
   return result;
 }
 
+template<typename ... Args>
+std::string string_format(const std::string& format, Args ... args)
+{
+  int size_s = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; // Extra space for '\0'
+  if (size_s <= 0) { throw std::runtime_error("Error during formatting."); }
+  auto size = static_cast<size_t>(size_s);
+  std::unique_ptr<char[]> buf(new char[size]);
+  std::snprintf(buf.get(), size, format.c_str(), args ...);
+  return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
+}
+
 template<class T>
 using owned = std::unique_ptr<T, std::function<void(T*)>>;
 
@@ -202,7 +213,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
 
     const int ret = llama_decode(ctx, batch_view);
     if (ret != 0) {
-      throw std::runtime_error("Failed to eval");
+      throw std::runtime_error(string_format("llama_decode failed with code: %d", ret));
     }
 
     const auto eos_id = llama_token_eos(llama_get_model(ctx));
@@ -311,7 +322,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
   }
 
   llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_ctx = N_CTX;
+  ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
   ctx_params.n_batch = N_BATCH;
 
   llama_context* ctx = llama_new_context_with_model(model, ctx_params);