2023-09-03 01:59:07 +00:00
|
|
|
#include "engine.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <ggml.h>
#include <llama.h>
|
|
|
|
|
|
|
|
|
|
namespace llama {
|
|
|
|
|
// Out-of-line (defaulted) definition: keeps the vtable/key-function emission
// in this translation unit while behaving identically to an empty body.
TextInferenceEngine::~TextInferenceEngine() = default;
|
|
|
|
|
|
|
|
|
|
namespace {
|
2023-10-02 15:39:15 +00:00
|
|
|
// Number of tokens submitted to llama_decode per batch.
// constexpr (was a mutable runtime global): never modified, so let the
// compiler enforce that and fold the value at compile time.
static constexpr size_t N_BATCH = 512;

// Maximum KV-cache history (context length), in tokens.
static constexpr size_t N_CTX = 4096;

// unique_ptr with a type-erased deleter; used to pair llama.cpp C handles
// (llama_model / llama_context) with their matching free functions.
template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
|
|
|
|
|
|
|
|
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
|
|
|
|
public:
|
|
|
|
|
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
|
|
|
|
model_(std::move(model)),
|
|
|
|
|
ctx_(std::move(ctx)) {
|
2023-09-30 15:37:36 +00:00
|
|
|
batch_ = llama_batch_init(N_BATCH, 0);
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
void start(rust::Slice<const uint32_t> input_token_ids) override {
|
2023-09-03 01:59:07 +00:00
|
|
|
auto* ctx = ctx_.get();
|
2023-09-05 02:14:29 +00:00
|
|
|
llama_reset_timings(ctx);
|
2023-09-29 13:06:47 +00:00
|
|
|
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
2023-09-08 16:20:51 +00:00
|
|
|
|
|
|
|
|
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
|
|
|
|
|
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
|
|
|
|
|
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
|
|
|
|
|
}
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
uint32_t step() override {
|
2023-09-29 13:06:47 +00:00
|
|
|
const llama_token id = sample();
|
2023-09-28 23:59:59 +00:00
|
|
|
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
2023-09-29 13:06:47 +00:00
|
|
|
return id;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
void end() override {
|
2023-09-05 02:14:29 +00:00
|
|
|
llama_print_timings(ctx_.get());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t eos_token() const override {
|
|
|
|
|
return llama_token_eos(ctx_.get());
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-03 01:59:07 +00:00
|
|
|
private:
|
|
|
|
|
uint32_t sample() const {
|
|
|
|
|
auto* ctx = ctx_.get();
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
|
2023-09-28 23:59:59 +00:00
|
|
|
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
2023-09-03 01:59:07 +00:00
|
|
|
|
|
|
|
|
// Greedy sampling (always select the highest logit).
|
|
|
|
|
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
|
|
|
|
}
|
|
|
|
|
|
2023-10-02 15:39:15 +00:00
|
|
|
void eval(llama_token* data, size_t size, bool reset) {
|
2023-09-30 15:37:36 +00:00
|
|
|
if (reset) {
|
|
|
|
|
n_past_ = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
batch_.n_tokens = size;
|
|
|
|
|
for (size_t i = 0; i < size; ++i) {
|
|
|
|
|
batch_.token[i] = data[i];
|
|
|
|
|
batch_.pos[i] = n_past_ + i;
|
|
|
|
|
batch_.seq_id[i] = 0;
|
|
|
|
|
batch_.logits[i] = false;
|
|
|
|
|
}
|
|
|
|
|
batch_.logits[size - 1] = true;
|
|
|
|
|
|
2023-09-03 01:59:07 +00:00
|
|
|
auto* ctx = ctx_.get();
|
2023-09-30 15:37:36 +00:00
|
|
|
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
|
|
|
|
|
if (llama_decode(ctx, batch_)) {
|
2023-10-02 15:39:15 +00:00
|
|
|
throw std::runtime_error("Failed to eval");
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
n_past_ += size;
|
2023-09-03 01:59:07 +00:00
|
|
|
}
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
size_t n_past_;
|
2023-09-03 01:59:07 +00:00
|
|
|
owned<llama_model> model_;
|
|
|
|
|
owned<llama_context> ctx_;
|
2023-09-30 15:37:36 +00:00
|
|
|
|
|
|
|
|
llama_batch batch_;
|
2023-09-03 01:59:07 +00:00
|
|
|
};
|
|
|
|
|
|
2023-09-12 14:41:39 +00:00
|
|
|
static int g_llama_cpp_log_level = 0;
|
2023-09-28 23:59:59 +00:00
|
|
|
static void llama_log_callback(ggml_log_level level, const char * text, void * user_data) {
|
2023-09-12 14:41:39 +00:00
|
|
|
(void)user_data;
|
|
|
|
|
if (level < g_llama_cpp_log_level) {
|
|
|
|
|
fputs(text, stderr);
|
|
|
|
|
fflush(stderr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-03 01:59:07 +00:00
|
|
|
struct BackendInitializer {
|
|
|
|
|
BackendInitializer() {
|
2023-09-12 14:41:39 +00:00
|
|
|
if (const char* level = std::getenv("LLAMA_CPP_LOG_LEVEL")) {
|
|
|
|
|
g_llama_cpp_log_level = std::stoi(level);
|
|
|
|
|
}
|
|
|
|
|
llama_log_set(llama_log_callback, nullptr);
|
2023-09-03 01:59:07 +00:00
|
|
|
llama_backend_init(false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~BackendInitializer() {
|
|
|
|
|
llama_backend_free();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2023-09-30 15:37:36 +00:00
|
|
|
// Load a gguf model from `model_path` and build an inference engine around
// it. Returns nullptr on failure (bad path, incompatible file, or context
// creation failure).
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
  // One-time global backend setup; C++11 magic static makes this
  // thread-safe and torn down at process exit.
  static BackendInitializer initializer;

  llama_model_params model_params = llama_model_default_params();
  model_params.n_gpu_layers = 1;
  llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), model_params);

  if (!model) {
    return nullptr;
  }

  llama_context_params ctx_params = llama_context_default_params();
  ctx_params.n_ctx = N_CTX;
  ctx_params.n_batch = N_BATCH;
  llama_context* ctx = llama_new_context_with_model(model, ctx_params);

  if (!ctx) {
    // Previously unchecked: a null context would be handed to the engine
    // (crashing later in start()) and the model would leak. Free the model
    // and report failure instead.
    llama_free_model(model);
    return nullptr;
  }

  return std::make_unique<TextInferenceEngineImpl>(
    owned<llama_model>(model, llama_free_model),
    owned<llama_context>(ctx, llama_free)
  );
}
|
|
|
|
|
|
|
|
|
|
} // namespace llama
|