refactor: cleanup llama cpp implementations to fix warnings (#495)
parent
aea8c74bdc
commit
2171ba72ff
|
|
@ -9,12 +9,12 @@ class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
|
|
||||||
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
|
virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
|
||||||
virtual uint32_t step() const = 0;
|
virtual uint32_t step() = 0;
|
||||||
virtual void end() const = 0;
|
virtual void end() = 0;
|
||||||
|
|
||||||
virtual uint32_t eos_token() const = 0;
|
virtual uint32_t eos_token() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
||||||
model_(std::move(model)),
|
model_(std::move(model)),
|
||||||
ctx_(std::move(ctx)) {
|
ctx_(std::move(ctx)) {
|
||||||
|
batch_ = llama_batch_init(N_BATCH, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void start(rust::Slice<const uint32_t> input_token_ids) const override {
|
void start(rust::Slice<const uint32_t> input_token_ids) override {
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
llama_reset_timings(ctx);
|
llama_reset_timings(ctx);
|
||||||
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
||||||
|
|
@ -33,13 +34,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t step() const override {
|
uint32_t step() override {
|
||||||
const llama_token id = sample();
|
const llama_token id = sample();
|
||||||
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void end() const override {
|
void end() override {
|
||||||
llama_print_timings(ctx_.get());
|
llama_print_timings(ctx_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -51,29 +52,43 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
uint32_t sample() const {
|
uint32_t sample() const {
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
|
|
||||||
auto logits = llama_get_logits(ctx);
|
auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
|
||||||
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
|
||||||
// Greedy sampling (always select the highest logit).
|
// Greedy sampling (always select the highest logit).
|
||||||
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool eval(llama_token* data, size_t size, bool reset) const {
|
bool eval(llama_token* data, size_t size, bool reset) {
|
||||||
|
if (reset) {
|
||||||
|
n_past_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_.n_tokens = size;
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
batch_.token[i] = data[i];
|
||||||
|
batch_.pos[i] = n_past_ + i;
|
||||||
|
batch_.seq_id[i] = 0;
|
||||||
|
batch_.logits[i] = false;
|
||||||
|
}
|
||||||
|
batch_.logits[size - 1] = true;
|
||||||
|
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
if (llama_eval(
|
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
|
||||||
ctx,
|
if (llama_decode(ctx, batch_)) {
|
||||||
data,
|
|
||||||
size,
|
|
||||||
reset ? 0 : llama_get_kv_cache_token_count(ctx))) {
|
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
n_past_ += size;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t n_past_;
|
||||||
owned<llama_model> model_;
|
owned<llama_model> model_;
|
||||||
owned<llama_context> ctx_;
|
owned<llama_context> ctx_;
|
||||||
|
|
||||||
|
llama_batch batch_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static int g_llama_cpp_log_level = 0;
|
static int g_llama_cpp_log_level = 0;
|
||||||
|
|
@ -100,7 +115,7 @@ struct BackendInitializer {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||||
static BackendInitializer initializer;
|
static BackendInitializer initializer;
|
||||||
|
|
||||||
llama_model_params model_params = llama_model_default_params();
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
|
@ -117,7 +132,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||||
ctx_params.n_batch = N_BATCH;
|
ctx_params.n_batch = N_BATCH;
|
||||||
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
||||||
|
|
||||||
return std::make_shared<TextInferenceEngineImpl>(
|
return std::make_unique<TextInferenceEngineImpl>(
|
||||||
owned<llama_model>(model, llama_free_model),
|
owned<llama_model>(model, llama_free_model),
|
||||||
owned<llama_context>(ctx, llama_free)
|
owned<llama_context>(ctx, llama_free)
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,11 @@ mod ffi {
|
||||||
|
|
||||||
type TextInferenceEngine;
|
type TextInferenceEngine;
|
||||||
|
|
||||||
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn start(&self, input_token_ids: &[u32]);
|
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
|
||||||
fn step(&self) -> u32;
|
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
|
||||||
fn end(&self);
|
fn end(self: Pin<&mut TextInferenceEngine>);
|
||||||
|
|
||||||
fn eos_token(&self) -> u32;
|
fn eos_token(&self) -> u32;
|
||||||
}
|
}
|
||||||
|
|
@ -35,7 +35,7 @@ pub struct LlamaEngineOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LlamaEngine {
|
pub struct LlamaEngine {
|
||||||
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
||||||
tokenizer: Arc<Tokenizer>,
|
tokenizer: Arc<Tokenizer>,
|
||||||
decoding_factory: DecodingFactory,
|
decoding_factory: DecodingFactory,
|
||||||
}
|
}
|
||||||
|
|
@ -65,15 +65,16 @@ impl TextGeneration for LlamaEngine {
|
||||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
|
|
||||||
let s = stream! {
|
let s = stream! {
|
||||||
let engine = self.engine.lock().await;
|
let mut engine = self.engine.lock().await;
|
||||||
|
let mut engine = engine.as_mut().unwrap();
|
||||||
let eos_token = engine.eos_token();
|
let eos_token = engine.eos_token();
|
||||||
|
|
||||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||||
engine.start(input_token_ids);
|
engine.as_mut().start(input_token_ids);
|
||||||
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
||||||
let mut n_remains = options.max_decoding_length ;
|
let mut n_remains = options.max_decoding_length ;
|
||||||
while n_remains > 0 {
|
while n_remains > 0 {
|
||||||
let next_token_id = engine.step();
|
let next_token_id = engine.as_mut().step();
|
||||||
if next_token_id == eos_token {
|
if next_token_id == eos_token {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue