diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h
index 1dad36d..834a1d7 100644
--- a/crates/llama-cpp-bindings/include/engine.h
+++ b/crates/llama-cpp-bindings/include/engine.h
@@ -9,12 +9,12 @@ class TextInferenceEngine {
  public:
   virtual ~TextInferenceEngine();
 
-  virtual void start(rust::Slice<const uint32_t> input_token_ids) const = 0;
-  virtual uint32_t step() const = 0;
-  virtual void end() const = 0;
+  virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
+  virtual uint32_t step() = 0;
+  virtual void end() = 0;
 
   virtual uint32_t eos_token() const = 0;
 };
 
-std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
+std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
 }  // namespace
diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 0f1a8cd..2aeedab 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -20,9 +20,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
     model_(std::move(model)),
     ctx_(std::move(ctx)) {
+      batch_ = llama_batch_init(N_BATCH, 0);
   }
 
-  void start(rust::Slice<const uint32_t> input_token_ids) const override {
+  void start(rust::Slice<const uint32_t> input_token_ids) override {
     auto* ctx = ctx_.get();
     llama_reset_timings(ctx);
     std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
@@ -33,13 +34,13 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     }
   }
 
-  uint32_t step() const override {
+  uint32_t step() override {
     const llama_token id = sample();
     eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
     return id;
   }
 
-  void end() const override {
+  void end() override {
     llama_print_timings(ctx_.get());
   }
 
@@ -51,29 +52,43 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   uint32_t sample() const {
     auto* ctx = ctx_.get();
 
-    auto logits = llama_get_logits(ctx);
+    auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
     auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
     // Greedy sampling (always select the highest logit).
     return std::distance(logits, std::max_element(logits, logits + n_vocab));
   }
 
-  bool eval(llama_token* data, size_t size, bool reset) const {
+  bool eval(llama_token* data, size_t size, bool reset) {
+    if (reset) {
+      n_past_ = 0;
+    }
+
+    batch_.n_tokens = size;
+    for (size_t i = 0; i < size; ++i) {
+      batch_.token[i] = data[i];
+      batch_.pos[i] = n_past_ + i;
+      batch_.seq_id[i] = 0;
+      batch_.logits[i] = false;
+    }
+    batch_.logits[size - 1] = true;
+
     auto* ctx = ctx_.get();
-    if (llama_eval(
-          ctx,
-          data,
-          size,
-          reset ? 0 : llama_get_kv_cache_token_count(ctx))) {
+    llama_kv_cache_tokens_rm(ctx, n_past_, -1);
+    if (llama_decode(ctx, batch_)) {
       fprintf(stderr, "%s : failed to eval\n", __func__);
       return false;
     }
 
+    n_past_ += size;
     return true;
   }
 
+  size_t n_past_;
   owned<llama_model> model_;
   owned<llama_context> ctx_;
+
+  llama_batch batch_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -100,7 +115,7 @@ struct BackendInitializer {
 };
 }  // namespace
 
-std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
+std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
   static BackendInitializer initializer;
 
   llama_model_params model_params = llama_model_default_params();
@@ -117,7 +132,7 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
   ctx_params.n_batch = N_BATCH;
   llama_context* ctx = llama_new_context_with_model(model, ctx_params);
 
-  return std::make_shared<TextInferenceEngineImpl>(
+  return std::make_unique<TextInferenceEngineImpl>(
       owned<llama_model>(model, llama_free_model),
       owned<llama_context>(ctx, llama_free)
   );
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 02a171f..da91aa2 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -15,11 +15,11 @@ mod ffi {
         type TextInferenceEngine;
 
-        fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
+        fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
 
-        fn start(&self, input_token_ids: &[u32]);
-        fn step(&self) -> u32;
-        fn end(&self);
+        fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
+        fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
+        fn end(self: Pin<&mut TextInferenceEngine>);
 
         fn eos_token(&self) -> u32;
     }
 
@@ -35,7 +35,7 @@ pub struct LlamaEngineOptions {
 }
 
 pub struct LlamaEngine {
-    engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
+    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
     tokenizer: Arc<Tokenizer>,
     decoding_factory: DecodingFactory,
 }
@@ -65,15 +65,16 @@ impl TextGeneration for LlamaEngine {
         let encoding = self.tokenizer.encode(prompt, true).unwrap();
 
         let s = stream! {
-            let engine = self.engine.lock().await;
+            let mut engine = self.engine.lock().await;
+            let mut engine = engine.as_mut().unwrap();
             let eos_token = engine.eos_token();
 
             let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
-            engine.start(input_token_ids);
+            engine.as_mut().start(input_token_ids);
             let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
             let mut n_remains = options.max_decoding_length ;
             while n_remains > 0 {
-                let next_token_id = engine.step();
+                let next_token_id = engine.as_mut().step();
                 if next_token_id == eos_token {
                     break;
                 }