From 7bd99d14c0d346bde2a78c4eacc0ed5d53bb666d Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sat, 28 Oct 2023 23:37:05 -0700 Subject: [PATCH] feat: support continuous batching in llama.cpp backend (#659) * refactor: switch back to llama batch interface * feat: support cont batching --- crates/llama-cpp-bindings/include/engine.h | 8 +- crates/llama-cpp-bindings/src/engine.cc | 192 +++++++++++++++----- crates/llama-cpp-bindings/src/lib.rs | 201 +++++++++++++++------ crates/tabby/src/download.rs | 2 +- crates/tabby/src/serve/engine.rs | 4 +- 5 files changed, 305 insertions(+), 102 deletions(-) diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h index fffc0a2..e6a9c4d 100644 --- a/crates/llama-cpp-bindings/include/engine.h +++ b/crates/llama-cpp-bindings/include/engine.h @@ -9,11 +9,11 @@ class TextInferenceEngine { public: virtual ~TextInferenceEngine(); - virtual void start(rust::Slice input_token_ids) = 0; - virtual uint32_t step() = 0; - virtual void end() = 0; + virtual void add_request(uint32_t request_id, rust::Slice input_token_ids) = 0; + virtual void stop_request(uint32_t request_id) = 0; + virtual rust::Vec step() = 0; - virtual uint32_t eos_token() const = 0; + virtual uint32_t eos_token_id() const = 0; }; std::unique_ptr create_engine(bool use_gpu, rust::Str model_path); diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc index 3b5caaa..9a93f36 100644 --- a/crates/llama-cpp-bindings/src/engine.cc +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -2,6 +2,8 @@ #include #include +#include +#include #include #include @@ -10,8 +12,34 @@ namespace llama { TextInferenceEngine::~TextInferenceEngine() {} namespace { -static size_t N_BATCH = 512; // # per batch inference. -static size_t N_CTX = 4096; // # max kv history. +int get_parallelism() { + const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM"); + if (parallelism) { + return std::stoi(parallelism); + } else { + return 4; + } +} + +static size_t N_CONCURRENT_REQUESTS = get_parallelism(); + +constexpr size_t N_BATCH = 512; // # per batch inference. +constexpr size_t N_CTX = 4096; // # max kv history. + +struct Request { + Request(size_t request_id, rust::Slice input_token_ids) : + id(request_id), + tokens(input_token_ids.begin(), input_token_ids.end()) { + } + + size_t id = -1; + llama_seq_id seq_id = -1; + + std::vector tokens; + size_t i_batch = -1; + size_t n_past = 0; +}; + template using owned = std::unique_ptr>; @@ -21,61 +49,136 @@ class TextInferenceEngineImpl : public TextInferenceEngine { TextInferenceEngineImpl(owned model, owned ctx) : model_(std::move(model)), ctx_(std::move(ctx)) { + batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1); } - void start(rust::Slice input_token_ids) override { + ~TextInferenceEngineImpl() { + llama_batch_free(batch_); + } + + void add_request(uint32_t request_id, rust::Slice input_token_ids) override { + pending_requests_.push_back(Request(request_id, input_token_ids)); + } + + void stop_request(uint32_t request_id) override { + stopped_requests_.insert(request_id); + } + + rust::Vec step() override { auto* ctx = ctx_.get(); - llama_reset_timings(ctx); - std::vector tokens_list(input_token_ids.begin(), input_token_ids.end()); + auto n_vocab = llama_n_vocab(llama_get_model(ctx)); - for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) { - const size_t size = std::min(N_BATCH, tokens_list.size() - i); - eval(tokens_list.data() + i, size, /* reset = */ i == 0); + // Remove stopped requests. + if (!stopped_requests_.empty()) { + std::vector requests; + for (auto& request : requests_) { + if (stopped_requests_.count(request.id) > 0) { + // Release KV cache. + llama_kv_cache_seq_rm(ctx_.get(), request.id, -1, -1); + } else { + requests.emplace_back(request); + } + } + + requests_ = requests; } + + // Add pending requests. + while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) { + Request request = std::move(pending_requests_.front()); + pending_requests_.pop_front(); + + // Ignore stopped pending requests. + if (stopped_requests_.count(request.id) > 0) { + continue; + } + + requests_.push_back(request); + } + + // Clear stopped requests. + stopped_requests_.clear(); + + if (requests_.size() == 0) { + return {}; + } + + // Clear the batch. + batch_.n_tokens = 0; + + // Insert tokens from ongoing requests to batch. + for (auto& request : requests_) { + const size_t n_tokens = batch_.n_tokens; + for (size_t i = 0; i < request.tokens.size(); ++i) { + batch_.token[n_tokens + i] = request.tokens[i]; + batch_.pos[n_tokens + i] = request.n_past + i; + batch_.n_seq_id[n_tokens + i] = 1; + batch_.seq_id[n_tokens + i][0] = request.id; + batch_.logits[n_tokens + i] = false; + } + batch_.n_tokens += request.tokens.size(); + + batch_.logits[batch_.n_tokens - 1] = true; + request.i_batch = batch_.n_tokens - 1; + } + + rust::Vec result; + result.reserve(requests_.size() * 2); + + // Decode tokens in chunks + for (size_t i = 0; i < static_cast(batch_.n_tokens); i += N_BATCH) { + const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i); + llama_batch batch_view = { + n_tokens, + batch_.token + i, + nullptr, + batch_.pos + i, + batch_.n_seq_id + i, + batch_.seq_id + i, + batch_.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + throw std::runtime_error("Failed to eval"); + } + + for (auto& request : requests_) { + if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) { + continue; + } + + int32_t i_batch = request.i_batch - i; + auto logits = llama_get_logits_ith(ctx, i_batch); + auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab)); + + request.n_past += request.tokens.size(); + + request.tokens.clear(); + request.tokens.push_back(next_token); + + result.push_back(request.id); + result.push_back(next_token); + } + } + + return result; } - uint32_t step() override { - const llama_token id = sample(); - eval(const_cast(&id), 1, /* reset = */ false); - return id; - } - - void end() override { - llama_print_timings(ctx_.get()); - } - - uint32_t eos_token() const override { + uint32_t eos_token_id() const override { return llama_token_eos(llama_get_model(ctx_.get())); } private: - uint32_t sample() const { - auto* ctx = ctx_.get(); - - auto logits = llama_get_logits_ith(ctx, 0); - auto n_vocab = llama_n_vocab(llama_get_model(ctx)); - - // Greedy sampling (always select the highest logit). - return std::distance(logits, std::max_element(logits, logits + n_vocab)); - } - - void eval(llama_token* data, size_t size, bool reset) { - if (reset) { - n_past_ = 0; - } - - auto* ctx = ctx_.get(); - llama_kv_cache_tokens_rm(ctx, n_past_, -1); - if (llama_decode(ctx, llama_batch_get_one(data, size, n_past_, 0))) { - throw std::runtime_error("Failed to eval"); - } - - n_past_ += size; - } - - size_t n_past_; owned model_; owned ctx_; + + llama_batch batch_; + + std::vector requests_; + std::deque pending_requests_; + std::unordered_set stopped_requests_; }; static int g_llama_cpp_log_level = 0; @@ -100,6 +203,7 @@ struct BackendInitializer { llama_backend_free(); } }; + } // namespace std::unique_ptr create_engine(bool use_gpu, rust::Str model_path) { diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index 53870fc..00e5879 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -1,12 +1,20 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use async_stream::stream; use async_trait::async_trait; +use cxx::UniquePtr; use derive_builder::Builder; use ffi::create_engine; use futures::{lock::Mutex, stream::BoxStream}; -use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions}; +use tabby_inference::{ + decoding::{DecodingFactory, IncrementalDecoding}, + helpers, TextGeneration, TextGenerationOptions, +}; use tokenizers::tokenizer::Tokenizer; +use tokio::{ + sync::mpsc::{channel, Sender}, + task::yield_now, +}; #[cxx::bridge(namespace = "llama")] mod ffi { @@ -17,46 +25,168 @@ mod ffi { fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr; - fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]); - fn step(self: Pin<&mut TextInferenceEngine>) -> Result; - fn end(self: Pin<&mut TextInferenceEngine>); + fn add_request( + self: Pin<&mut TextInferenceEngine>, + request_id: u32, + input_token_ids: &[u32], + ); + fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32); + fn step(self: Pin<&mut TextInferenceEngine>) -> Result>; - fn eos_token(&self) -> u32; + fn eos_token_id(&self) -> u32; } } unsafe impl Send for ffi::TextInferenceEngine {} unsafe impl Sync for ffi::TextInferenceEngine {} +struct InferenceRequest { + tx: Sender, + decoding: IncrementalDecoding, +} + +struct AsyncTextInferenceEngine { + engine: Mutex>, + tokenizer: Arc, + decoding_factory: DecodingFactory, + requests: Mutex>, + + next_request_id: Mutex, + eos_token_id: u32, +} + +impl AsyncTextInferenceEngine { + fn create(engine: UniquePtr, tokenizer: Tokenizer) -> Self { + Self { + eos_token_id: engine.eos_token_id(), + engine: Mutex::new(engine), + tokenizer: Arc::new(tokenizer), + decoding_factory: DecodingFactory::default(), + requests: Mutex::new(HashMap::new()), + next_request_id: Mutex::new(0), + } + } + + async fn background_job(&self) { + let mut requests = self.requests.lock().await; + if requests.len() == 0 { + return; + } + + let mut engine = self.engine.lock().await; + + let Ok(result) = engine.as_mut().unwrap().step() else { + panic!("Failed to evaluation"); + }; + + for i in (0..result.len()).step_by(2) { + let request_id = result[i]; + let token_id = result[i + 1]; + + let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap(); + let mut stopped = false; + + if tx.is_closed() || token_id == self.eos_token_id { + // Cancelled by client side or hit eos. + stopped = true; + } else if let Some(new_text) = decoding.next_token(token_id) { + tx.send(new_text).await.expect("send failed"); + } else { + // Stoop words stopped + stopped = true; + } + + if stopped { + requests.remove(&request_id); + engine.as_mut().unwrap().stop_request(request_id); + } + } + } + + async fn generate_stream( + &self, + prompt: &str, + options: TextGenerationOptions, + ) -> BoxStream { + let encoding = self.tokenizer.encode(prompt, true).unwrap(); + let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); + let decoding = self.decoding_factory.create_incremental_decoding( + self.tokenizer.clone(), + input_token_ids, + options.language, + ); + + let (tx, mut rx) = channel::(4); + { + let mut engine = self.engine.lock().await; + let engine = engine.as_mut().unwrap(); + + let mut request_id = self.next_request_id.lock().await; + self.requests + .lock() + .await + .insert(*request_id, InferenceRequest { tx, decoding }); + engine.add_request(*request_id, input_token_ids); + + // 2048 should be large enough to avoid collision. + *request_id = (*request_id + 1) % 2048; + } + + let s = stream! { + let mut length = 0; + while let Some(new_text) = rx.recv().await { + yield new_text; + length += 1; + if length >= options.max_decoding_length { + break; + } + } + + rx.close(); + }; + + Box::pin(s) + } +} + #[derive(Builder, Debug)] -pub struct LlamaEngineOptions { +pub struct LlamaTextGenerationOptions { model_path: String, tokenizer_path: String, use_gpu: bool, } -pub struct LlamaEngine { - engine: Mutex>, - tokenizer: Arc, - decoding_factory: DecodingFactory, +pub struct LlamaTextGeneration { + engine: Arc, } -impl LlamaEngine { - pub fn create(options: LlamaEngineOptions) -> Self { +impl LlamaTextGeneration { + pub fn create(options: LlamaTextGenerationOptions) -> Self { let engine = create_engine(options.use_gpu, &options.model_path); if engine.is_null() { panic!("Unable to load model: {}", options.model_path); } - LlamaEngine { - engine: Mutex::new(engine), - tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), - decoding_factory: DecodingFactory::default(), - } + let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap(); + let ret = LlamaTextGeneration { + engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)), + }; + ret.start_background_job(); + ret + } + + pub fn start_background_job(&self) { + let engine = self.engine.clone(); + tokio::spawn(async move { + loop { + engine.background_job().await; + yield_now().await; + } + }); } } #[async_trait] -impl TextGeneration for LlamaEngine { +impl TextGeneration for LlamaTextGeneration { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { let s = self.generate_stream(prompt, options).await; helpers::stream_to_string(s).await @@ -67,38 +197,7 @@ impl TextGeneration for LlamaEngine { prompt: &str, options: TextGenerationOptions, ) -> BoxStream { - let encoding = self.tokenizer.encode(prompt, true).unwrap(); - - let s = stream! { - let mut engine = self.engine.lock().await; - let mut engine = engine.as_mut().unwrap(); - let eos_token = engine.eos_token(); - - let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); - engine.as_mut().start(input_token_ids); - let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language); - let mut n_remains = options.max_decoding_length ; - while n_remains > 0 { - let Ok(next_token_id) = engine.as_mut().step() else { - panic!("Failed to eval"); - }; - if next_token_id == eos_token { - break; - } - - if let Some(new_text) = decoding.next_token(next_token_id) { - yield new_text; - } else { - break; - } - - n_remains -= 1; - } - - engine.end(); - }; - - Box::pin(s) + self.engine.generate_stream(prompt, options).await } } diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs index 6ee52f0..bfea8b8 100644 --- a/crates/tabby/src/download.rs +++ b/crates/tabby/src/download.rs @@ -1,6 +1,6 @@ use clap::Args; use tabby_download::Downloader; -use tracing::{info, log::warn}; +use tracing::info; use crate::fatal; diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index 4fb9dd1..0b89ea5 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -39,14 +39,14 @@ pub struct EngineInfo { } fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box { - let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default() + let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() .model_path(model_dir.ggml_q8_0_v2_file()) .tokenizer_path(model_dir.tokenizer_file()) .use_gpu(device.ggml_use_gpu()) .build() .unwrap(); - Box::new(llama_cpp_bindings::LlamaEngine::create(options)) + Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options)) } fn get_model_dir(model: &str) -> ModelDir {