fix: llama.cpp requires kv cache to be N_CTX * parallelism (#714)
parent
9344c32b31
commit
eb7ae96157
|
|
@ -79,6 +79,17 @@ std::vector<llama_token> llama_tokenize(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a std::string from a printf-style format string.
///
/// @param format  printf-style format string; it is handed to std::snprintf
///                unchanged, so its conversion specifiers must match `args`.
/// @param args    values consumed by the conversion specifiers in `format`.
/// @return        the formatted text, without a trailing '\0'.
/// @throws std::runtime_error if std::snprintf reports a formatting error.
template<typename ... Args>
std::string string_format(const std::string& format, Args ... args)
{
  // Measuring pass: with a null buffer and size 0, snprintf only reports how
  // many characters the output needs (excluding the terminating '\0').
  const int length = std::snprintf(nullptr, 0, format.c_str(), args ...);

  // Check for failure before doing any arithmetic on the result, so that a
  // length of INT_MAX can never overflow a signed `+ 1`.
  if (length < 0) { throw std::runtime_error("Error during formatting."); }

  // Format straight into the string's own storage instead of a temporary heap
  // buffer. The string already keeps a '\0' at index size(), which is exactly
  // where snprintf writes its terminator.
  std::string result(static_cast<size_t>(length), '\0');
  const int written = std::snprintf(&result[0], static_cast<size_t>(length) + 1, format.c_str(), args ...);
  if (written < 0) { throw std::runtime_error("Error during formatting."); }

  return result;
}
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||||
|
|
||||||
|
|
@ -202,7 +213,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
throw std::runtime_error("Failed to eval");
|
throw std::runtime_error(string_format("llama_decode failed with code: %d", ret));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto eos_id = llama_token_eos(llama_get_model(ctx));
|
const auto eos_id = llama_token_eos(llama_get_model(ctx));
|
||||||
|
|
@ -311,7 +322,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
ctx_params.n_ctx = N_CTX;
|
ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
|
||||||
ctx_params.n_batch = N_BATCH;
|
ctx_params.n_batch = N_BATCH;
|
||||||
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue