feat: implement input truncation for llama-cpp-bindings (#416)
* feat: implement input truncation for llama-cpp-bindings
* set max input length to 1024
* fix: batching tokens with n_batches
* fix batching

release-0.2
parent
87b6b34120
commit
ad3b974d5c
|
|
@ -9,7 +9,7 @@ class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
|
|
||||||
virtual uint32_t start(const rust::Str prompt) const = 0;
|
virtual uint32_t start(const rust::Str prompt, size_t max_input_length) const = 0;
|
||||||
virtual uint32_t step(uint32_t next_token_id) const = 0;
|
virtual uint32_t step(uint32_t next_token_id) const = 0;
|
||||||
virtual void end() const = 0;
|
virtual void end() const = 0;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,18 +10,27 @@ namespace llama {
|
||||||
TextInferenceEngine::~TextInferenceEngine() {}
|
TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
static size_t N_BATCH = 512;
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||||
|
|
||||||
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, size_t max_input_length, bool add_bos) {
|
||||||
// upper limit for the number of tokens
|
// upper limit for the number of tokens
|
||||||
int n_tokens = text.length() + add_bos;
|
int n_tokens = max_input_length;
|
||||||
std::vector<llama_token> result(n_tokens);
|
std::vector<llama_token> result(n_tokens);
|
||||||
n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
||||||
if (n_tokens < 0) {
|
if (n_tokens < 0) {
|
||||||
result.resize(-n_tokens);
|
result.resize(-n_tokens);
|
||||||
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
||||||
GGML_ASSERT(check == -n_tokens);
|
GGML_ASSERT(check == -n_tokens);
|
||||||
|
|
||||||
|
int start = check - max_input_length;
|
||||||
|
GGML_ASSERT(start >= 0);
|
||||||
|
result = std::vector<llama_token>(result.begin() + start, result.end());
|
||||||
|
if (add_bos) {
|
||||||
|
result[0] = llama_token_bos(ctx);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
result.resize(n_tokens);
|
result.resize(n_tokens);
|
||||||
}
|
}
|
||||||
|
|
@ -35,16 +44,21 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
ctx_(std::move(ctx)) {
|
ctx_(std::move(ctx)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t start(const rust::Str prompt) const override {
|
uint32_t start(const rust::Str prompt, size_t max_input_length) const override {
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
llama_reset_timings(ctx);
|
llama_reset_timings(ctx);
|
||||||
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
|
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), max_input_length, /* add_bos = */ true);
|
||||||
eval(tokens_list, /* reset = */ true);
|
|
||||||
|
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
|
||||||
|
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
|
||||||
|
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
|
||||||
|
}
|
||||||
return sample();
|
return sample();
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t step(uint32_t next_token_id) const override {
|
uint32_t step(uint32_t next_token_id) const override {
|
||||||
eval({ static_cast<llama_token>(next_token_id) }, /* reset = */ false);
|
const llama_token id = next_token_id;
|
||||||
|
eval(&id, 1, /* reset = */ false);
|
||||||
return sample();
|
return sample();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -67,12 +81,12 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool eval(const std::vector<llama_token>& tokens_list, bool reset) const {
|
bool eval(const llama_token* data, size_t size, bool reset) const {
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
if (llama_eval(
|
if (llama_eval(
|
||||||
ctx,
|
ctx,
|
||||||
tokens_list.data(),
|
data,
|
||||||
tokens_list.size(),
|
size,
|
||||||
reset ? 0 : llama_get_kv_cache_token_count(ctx),
|
reset ? 0 : llama_get_kv_cache_token_count(ctx),
|
||||||
/* n_threads = */ 4)) {
|
/* n_threads = */ 4)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
|
@ -102,6 +116,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||||
|
|
||||||
llama_context_params ctx_params = llama_context_default_params();
|
llama_context_params ctx_params = llama_context_default_params();
|
||||||
ctx_params.n_ctx = 2048;
|
ctx_params.n_ctx = 2048;
|
||||||
|
ctx_params.n_batch = N_BATCH;
|
||||||
ctx_params.n_gpu_layers = 1;
|
ctx_params.n_gpu_layers = 1;
|
||||||
|
|
||||||
llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
|
llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ mod ffi {
|
||||||
|
|
||||||
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn start(&self, prompt: &str) -> u32;
|
fn start(&self, prompt: &str, max_input_length: usize) -> u32;
|
||||||
fn step(&self, next_token_id: u32) -> u32;
|
fn step(&self, next_token_id: u32) -> u32;
|
||||||
fn end(&self);
|
fn end(&self);
|
||||||
|
|
||||||
|
|
@ -67,7 +67,7 @@ impl TextGeneration for LlamaEngine {
|
||||||
let engine = engine.lock().unwrap();
|
let engine = engine.lock().unwrap();
|
||||||
let eos_token = engine.eos_token();
|
let eos_token = engine.eos_token();
|
||||||
|
|
||||||
let mut next_token_id = engine.start(&prompt);
|
let mut next_token_id = engine.start(&prompt, options.max_input_length);
|
||||||
if next_token_id == eos_token {
|
if next_token_id == eos_token {
|
||||||
return Vec::new();
|
return Vec::new();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue