fix: llama.cpp requires kv cache to be N_CTX * parallelism (#714)

refactor-extract-code
Meng Zhang 2023-11-06 22:16:36 -08:00 committed by GitHub
parent 9344c32b31
commit eb7ae96157
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 2 deletions

View File

@ -79,6 +79,17 @@ std::vector<llama_token> llama_tokenize(
return result; return result;
} }
template<typename ... Args>
std::string string_format(const std::string& format, Args ... args)
{
int size_s = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; // Extra space for '\0'
if (size_s <= 0) { throw std::runtime_error("Error during formatting."); }
auto size = static_cast<size_t>(size_s);
std::unique_ptr<char[]> buf(new char[size]);
std::snprintf(buf.get(), size, format.c_str(), args ...);
return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
}
template<class T> template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>; using owned = std::unique_ptr<T, std::function<void(T*)>>;
@ -202,7 +213,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
const int ret = llama_decode(ctx, batch_view); const int ret = llama_decode(ctx, batch_view);
if (ret != 0) { if (ret != 0) {
throw std::runtime_error("Failed to eval"); throw std::runtime_error(string_format("llama_decode failed with code: %d", ret));
} }
const auto eos_id = llama_token_eos(llama_get_model(ctx)); const auto eos_id = llama_token_eos(llama_get_model(ctx));
@ -311,7 +322,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
} }
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = N_CTX; ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
ctx_params.n_batch = N_BATCH; ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params); llama_context* ctx = llama_new_context_with_model(model, ctx_params);