refactor: cleanup llama cpp implementations to fix warnings (#495)

release-0.2
Meng Zhang 2023-09-30 08:37:36 -07:00 committed by GitHub
parent aea8c74bdc
commit 2171ba72ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 24 deletions

View File

@ -9,12 +9,12 @@ class TextInferenceEngine {
public: public:
virtual ~TextInferenceEngine(); virtual ~TextInferenceEngine();
virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0; virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
virtual uint32_t step() const = 0; virtual uint32_t step() = 0;
virtual void end() const = 0; virtual void end() = 0;
virtual uint32_t eos_token() const = 0; virtual uint32_t eos_token() const = 0;
}; };
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path); std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
} // namespace } // namespace

View File

@ -20,9 +20,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) : TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
model_(std::move(model)), model_(std::move(model)),
ctx_(std::move(ctx)) { ctx_(std::move(ctx)) {
batch_ = llama_batch_init(N_BATCH, 0);
} }
void start(rust::Slice<const uint32_t> input_token_ids) const override { void start(rust::Slice<const uint32_t> input_token_ids) override {
auto* ctx = ctx_.get(); auto* ctx = ctx_.get();
llama_reset_timings(ctx); llama_reset_timings(ctx);
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end()); std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
@ -33,13 +34,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
} }
} }
uint32_t step() const override { uint32_t step() override {
const llama_token id = sample(); const llama_token id = sample();
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false); eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
return id; return id;
} }
void end() const override { void end() override {
llama_print_timings(ctx_.get()); llama_print_timings(ctx_.get());
} }
@ -51,29 +52,43 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
uint32_t sample() const { uint32_t sample() const {
auto* ctx = ctx_.get(); auto* ctx = ctx_.get();
auto logits = llama_get_logits(ctx); auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
auto n_vocab = llama_n_vocab(llama_get_model(ctx)); auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Greedy sampling (always select the highest logit). // Greedy sampling (always select the highest logit).
return std::distance(logits, std::max_element(logits, logits + n_vocab)); return std::distance(logits, std::max_element(logits, logits + n_vocab));
} }
bool eval(llama_token* data, size_t size, bool reset) const { bool eval(llama_token* data, size_t size, bool reset) {
if (reset) {
n_past_ = 0;
}
batch_.n_tokens = size;
for (size_t i = 0; i < size; ++i) {
batch_.token[i] = data[i];
batch_.pos[i] = n_past_ + i;
batch_.seq_id[i] = 0;
batch_.logits[i] = false;
}
batch_.logits[size - 1] = true;
auto* ctx = ctx_.get(); auto* ctx = ctx_.get();
if (llama_eval( llama_kv_cache_tokens_rm(ctx, n_past_, -1);
ctx, if (llama_decode(ctx, batch_)) {
data,
size,
reset ? 0 : llama_get_kv_cache_token_count(ctx))) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;
} }
n_past_ += size;
return true; return true;
} }
size_t n_past_;
owned<llama_model> model_; owned<llama_model> model_;
owned<llama_context> ctx_; owned<llama_context> ctx_;
llama_batch batch_;
}; };
static int g_llama_cpp_log_level = 0; static int g_llama_cpp_log_level = 0;
@ -100,7 +115,7 @@ struct BackendInitializer {
}; };
} // namespace } // namespace
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) { std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
static BackendInitializer initializer; static BackendInitializer initializer;
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
@ -117,7 +132,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
ctx_params.n_batch = N_BATCH; ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params); llama_context* ctx = llama_new_context_with_model(model, ctx_params);
return std::make_shared<TextInferenceEngineImpl>( return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model), owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free) owned<llama_context>(ctx, llama_free)
); );

View File

@ -15,11 +15,11 @@ mod ffi {
type TextInferenceEngine; type TextInferenceEngine;
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>; fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn start(&self, input_token_ids: &[u32]); fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
fn step(&self) -> u32; fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
fn end(&self); fn end(self: Pin<&mut TextInferenceEngine>);
fn eos_token(&self) -> u32; fn eos_token(&self) -> u32;
} }
@ -35,7 +35,7 @@ pub struct LlamaEngineOptions {
} }
pub struct LlamaEngine { pub struct LlamaEngine {
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>, engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>, tokenizer: Arc<Tokenizer>,
decoding_factory: DecodingFactory, decoding_factory: DecodingFactory,
} }
@ -65,15 +65,16 @@ impl TextGeneration for LlamaEngine {
let encoding = self.tokenizer.encode(prompt, true).unwrap(); let encoding = self.tokenizer.encode(prompt, true).unwrap();
let s = stream! { let s = stream! {
let engine = self.engine.lock().await; let mut engine = self.engine.lock().await;
let mut engine = engine.as_mut().unwrap();
let eos_token = engine.eos_token(); let eos_token = engine.eos_token();
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.start(input_token_ids); engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words); let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
let mut n_remains = options.max_decoding_length ; let mut n_remains = options.max_decoding_length ;
while n_remains > 0 { while n_remains > 0 {
let next_token_id = engine.step(); let next_token_id = engine.as_mut().step();
if next_token_id == eos_token { if next_token_id == eos_token {
break; break;
} }