From ad3b974d5c93d4e5757c46bfc0452556db658e95 Mon Sep 17 00:00:00 2001
From: Meng Zhang <meng@tabbyml.com>
Date: Sat, 9 Sep 2023 00:20:51 +0800
Subject: [PATCH] feat: implement input truncation for llama-cpp-bindings
 (#416)

* feat: implement input truncation for llama-cpp-bindings

* set max input length to 1024

* fix: batching tokens with n_batches

* fix batching
---
 crates/llama-cpp-bindings/include/engine.h |  2 +-
 crates/llama-cpp-bindings/src/engine.cc    | 33 ++++++++++++++++------
 crates/llama-cpp-bindings/src/lib.rs       |  4 +--
 3 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h
index 2110c9e..0fae02f 100644
--- a/crates/llama-cpp-bindings/include/engine.h
+++ b/crates/llama-cpp-bindings/include/engine.h
@@ -9,7 +9,7 @@ class TextInferenceEngine {
  public:
   virtual ~TextInferenceEngine();
 
-  virtual uint32_t start(const rust::Str prompt) const = 0;
+  virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0;
   virtual uint32_t step(uint32_t next_token_id) const = 0;
   virtual void end() const = 0;
 
diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index a514895..f250f62 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -10,18 +10,27 @@ namespace llama {
 TextInferenceEngine::~TextInferenceEngine() {}
 
 namespace {
+static size_t N_BATCH = 512;
+
 template<class T>
 using owned = std::unique_ptr<T, std::function<void(T*)>>;
 
-std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
+std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) {
   // upper limit for the number of tokens
-  int n_tokens = text.length() + add_bos;
+  int n_tokens = max_input_length;
   std::vector<llama_token> result(n_tokens);
   n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
   if (n_tokens < 0) {
     result.resize(-n_tokens);
     int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
     GGML_ASSERT(check == -n_tokens);
+
+    int start = check - max_input_length;
+    GGML_ASSERT(start >= 0);
+    result = std::vector<llama_token>(result.begin() + start, result.end());
+    if (add_bos) {
+      result[0] = llama_token_bos(ctx);
+    }
   } else {
     result.resize(n_tokens);
   }
@@ -35,16 +44,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     ctx_(std::move(ctx)) {
   }
 
-  uint32_t start(const rust::Str prompt) const override {
+  uint32_t start(const rust::Str prompt, size_t max_input_length) const override {
     auto* ctx = ctx_.get();
     llama_reset_timings(ctx);
-    std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
-    eval(tokens_list, /* reset = */ true);
+    std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ true);
+
+    for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
+      const size_t size = std::min(N_BATCH, tokens_list.size() - i);
+      eval(tokens_list.data() + i, size, /* reset = */ i == 0);
+    }
     return sample();
   }
 
   uint32_t step(uint32_t next_token_id) const override {
-    eval({ static_cast<llama_token>(next_token_id) }, /* reset = */ false);
+    const llama_token id = next_token_id;
+    eval(&id, 1, /* reset = */ false);
     return sample();
   }
 
@@ -67,12 +81,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     return std::distance(logits, std::max_element(logits, logits + n_vocab));
   }
 
-  bool eval(const std::vector<llama_token>& tokens_list, bool reset) const {
+  bool eval(const llama_token* data, size_t size, bool reset) const {
     auto* ctx = ctx_.get();
     if (llama_eval(
           ctx,
-          tokens_list.data(),
-          tokens_list.size(),
+          data,
+          size,
           reset ? 0 : llama_get_kv_cache_token_count(ctx),
           /* n_threads = */ 4)) {
       fprintf(stderr, "%s : failed to eval\n", __func__);
@@ -102,6 +116,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
   llama_context_params ctx_params = llama_context_default_params();
   ctx_params.n_ctx = 2048;
+  ctx_params.n_batch = N_BATCH;
   ctx_params.n_gpu_layers = 1;
 
   llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 1f144ac..176d7c7 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -17,7 +17,7 @@ mod ffi {
 
         fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
 
-        fn start(&self, prompt: &str) -> u32;
+        fn start(&self, prompt: &str, max_input_length: usize) -> u32;
         fn step(&self, next_token_id: u32) -> u32;
         fn end(&self);
 
@@ -67,7 +67,7 @@ impl TextGeneration for LlamaEngine {
         let engine = engine.lock().unwrap();
         let eos_token = engine.eos_token();
 
-        let mut next_token_id = engine.start(&prompt);
+        let mut next_token_id = engine.start(&prompt, options.max_input_length);
         if next_token_id == eos_token {
             return Vec::new();
         }