refactor: cleanup llama cpp implementations to fix warnings (#495)

release-0.2
Meng Zhang 2023-09-30 08:37:36 -07:00 committed by GitHub
parent aea8c74bdc
commit 2171ba72ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 24 deletions

View File

@@ -9,12 +9,12 @@ class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
virtual uint32_t step() const = 0;
virtual void end() const = 0;
virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
virtual uint32_t step() = 0;
virtual void end() = 0;
virtual uint32_t eos_token() const = 0;
};
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
} // namespace

View File

@@ -20,9 +20,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
model_(std::move(model)),
ctx_(std::move(ctx)) {
batch_ = llama_batch_init(N_BATCH, 0);
}
void start(rust::Slice<const uint32_t> input_token_ids) const override {
void start(rust::Slice<const uint32_t> input_token_ids) override {
auto* ctx = ctx_.get();
llama_reset_timings(ctx);
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
@@ -33,13 +34,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}
}
uint32_t step() const override {
uint32_t step() override {
const llama_token id = sample();
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
return id;
}
void end() const override {
void end() override {
llama_print_timings(ctx_.get());
}
@@ -51,29 +52,43 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
uint32_t sample() const {
auto* ctx = ctx_.get();
auto logits = llama_get_logits(ctx);
auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Greedy sampling (always select the highest logit).
return std::distance(logits, std::max_element(logits, logits + n_vocab));
}
bool eval(llama_token* data, size_t size, bool reset) const {
bool eval(llama_token* data, size_t size, bool reset) {
if (reset) {
n_past_ = 0;
}
batch_.n_tokens = size;
for (size_t i = 0; i < size; ++i) {
batch_.token[i] = data[i];
batch_.pos[i] = n_past_ + i;
batch_.seq_id[i] = 0;
batch_.logits[i] = false;
}
batch_.logits[size - 1] = true;
auto* ctx = ctx_.get();
if (llama_eval(
ctx,
data,
size,
reset ? 0 : llama_get_kv_cache_token_count(ctx))) {
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
if (llama_decode(ctx, batch_)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past_ += size;
return true;
}
size_t n_past_;
owned<llama_model> model_;
owned<llama_context> ctx_;
llama_batch batch_;
};
static int g_llama_cpp_log_level = 0;
@@ -100,7 +115,7 @@ struct BackendInitializer {
};
} // namespace
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
static BackendInitializer initializer;
llama_model_params model_params = llama_model_default_params();
@@ -117,7 +132,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
return std::make_shared<TextInferenceEngineImpl>(
return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free)
);

View File

@@ -15,11 +15,11 @@ mod ffi {
type TextInferenceEngine;
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn start(&self, input_token_ids: &[u32]);
fn step(&self) -> u32;
fn end(&self);
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
fn end(self: Pin<&mut TextInferenceEngine>);
fn eos_token(&self) -> u32;
}
@@ -35,7 +35,7 @@ pub struct LlamaEngineOptions {
}
pub struct LlamaEngine {
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>,
decoding_factory: DecodingFactory,
}
@@ -65,15 +65,16 @@ impl TextGeneration for LlamaEngine {
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let s = stream! {
let engine = self.engine.lock().await;
let mut engine = self.engine.lock().await;
let mut engine = engine.as_mut().unwrap();
let eos_token = engine.eos_token();
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.start(input_token_ids);
engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let next_token_id = engine.step();
let next_token_id = engine.as_mut().step();
if next_token_id == eos_token {
break;
}