refactor: cleanup llama cpp implementations to fix warnings (#495)
parent
aea8c74bdc
commit
2171ba72ff
|
|
@ -9,12 +9,12 @@ class TextInferenceEngine {
|
|||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
|
||||
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
|
||||
virtual uint32_t step() const = 0;
|
||||
virtual void end() const = 0;
|
||||
virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
|
||||
virtual uint32_t step() = 0;
|
||||
virtual void end() = 0;
|
||||
|
||||
virtual uint32_t eos_token() const = 0;
|
||||
};
|
||||
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -20,9 +20,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
||||
model_(std::move(model)),
|
||||
ctx_(std::move(ctx)) {
|
||||
batch_ = llama_batch_init(N_BATCH, 0);
|
||||
}
|
||||
|
||||
void start(rust::Slice<const uint32_t> input_token_ids) const override {
|
||||
void start(rust::Slice<const uint32_t> input_token_ids) override {
|
||||
auto* ctx = ctx_.get();
|
||||
llama_reset_timings(ctx);
|
||||
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
||||
|
|
@ -33,13 +34,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
}
|
||||
}
|
||||
|
||||
uint32_t step() const override {
|
||||
uint32_t step() override {
|
||||
const llama_token id = sample();
|
||||
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
||||
return id;
|
||||
}
|
||||
|
||||
void end() const override {
|
||||
void end() override {
|
||||
llama_print_timings(ctx_.get());
|
||||
}
|
||||
|
||||
|
|
@ -51,29 +52,43 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
uint32_t sample() const {
|
||||
auto* ctx = ctx_.get();
|
||||
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
|
||||
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
// Greedy sampling (always select the highest logit).
|
||||
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
}
|
||||
|
||||
bool eval(llama_token* data, size_t size, bool reset) const {
|
||||
bool eval(llama_token* data, size_t size, bool reset) {
|
||||
if (reset) {
|
||||
n_past_ = 0;
|
||||
}
|
||||
|
||||
batch_.n_tokens = size;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
batch_.token[i] = data[i];
|
||||
batch_.pos[i] = n_past_ + i;
|
||||
batch_.seq_id[i] = 0;
|
||||
batch_.logits[i] = false;
|
||||
}
|
||||
batch_.logits[size - 1] = true;
|
||||
|
||||
auto* ctx = ctx_.get();
|
||||
if (llama_eval(
|
||||
ctx,
|
||||
data,
|
||||
size,
|
||||
reset ? 0 : llama_get_kv_cache_token_count(ctx))) {
|
||||
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
|
||||
if (llama_decode(ctx, batch_)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
n_past_ += size;
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t n_past_;
|
||||
owned<llama_model> model_;
|
||||
owned<llama_context> ctx_;
|
||||
|
||||
llama_batch batch_;
|
||||
};
|
||||
|
||||
static int g_llama_cpp_log_level = 0;
|
||||
|
|
@ -100,7 +115,7 @@ struct BackendInitializer {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
||||
static BackendInitializer initializer;
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
|
@ -117,7 +132,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
|
|||
ctx_params.n_batch = N_BATCH;
|
||||
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
return std::make_shared<TextInferenceEngineImpl>(
|
||||
return std::make_unique<TextInferenceEngineImpl>(
|
||||
owned<llama_model>(model, llama_free_model),
|
||||
owned<llama_context>(ctx, llama_free)
|
||||
);
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ mod ffi {
|
|||
|
||||
type TextInferenceEngine;
|
||||
|
||||
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
||||
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
||||
|
||||
fn start(&self, input_token_ids: &[u32]);
|
||||
fn step(&self) -> u32;
|
||||
fn end(&self);
|
||||
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
|
||||
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
|
||||
fn end(self: Pin<&mut TextInferenceEngine>);
|
||||
|
||||
fn eos_token(&self) -> u32;
|
||||
}
|
||||
|
|
@ -35,7 +35,7 @@ pub struct LlamaEngineOptions {
|
|||
}
|
||||
|
||||
pub struct LlamaEngine {
|
||||
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
||||
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
decoding_factory: DecodingFactory,
|
||||
}
|
||||
|
|
@ -65,15 +65,16 @@ impl TextGeneration for LlamaEngine {
|
|||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
|
||||
let s = stream! {
|
||||
let engine = self.engine.lock().await;
|
||||
let mut engine = self.engine.lock().await;
|
||||
let mut engine = engine.as_mut().unwrap();
|
||||
let eos_token = engine.eos_token();
|
||||
|
||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||
engine.start(input_token_ids);
|
||||
engine.as_mut().start(input_token_ids);
|
||||
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
|
||||
let mut n_remains = options.max_decoding_length ;
|
||||
while n_remains > 0 {
|
||||
let next_token_id = engine.step();
|
||||
let next_token_id = engine.as_mut().step();
|
||||
if next_token_id == eos_token {
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue