feat: support continuous batching in llama.cpp backend (#659)

* refactor: switch back to llama batch interface
* feat: support continuous batching

parent 14d03b6826
commit 7bd99d14c0
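This commit replaces the single-sequence start/step/end engine interface with a request-based one: callers enqueue prompts via add_request, drive decoding with step(), which batches every active request into a single llama_decode pass and returns flattened (request_id, token_id) pairs, and cancel via stop_request. A minimal driver sketch against the new C++ interface (hypothetical harness code; rust::Slice and rust::Vec are the cxx bridge types, and the engine and token ids are assumed to come from create_engine and an external tokenizer):

    #include <cstdint>

    #include "engine.h"  // TextInferenceEngine / create_engine (include path assumed)

    // Hypothetical single-request driver; a real server would keep calling
    // add_request() for new clients between step() iterations.
    void drive(llama::TextInferenceEngine& engine,
               rust::Slice<const uint32_t> prompt_tokens) {
      const uint32_t eos = engine.eos_token_id();
      engine.add_request(/*request_id=*/0, prompt_tokens);

      for (;;) {
        // One batched decoding iteration across all active requests.
        rust::Vec<uint32_t> out = engine.step();
        if (out.empty()) break;  // nothing left to serve
        // Layout: request_id, token_id, request_id, token_id, ...
        for (size_t i = 0; i + 1 < out.size(); i += 2) {
          const uint32_t request_id = out[i];
          const uint32_t token_id = out[i + 1];
          if (token_id == eos) {
            engine.stop_request(request_id);  // frees the sequence's KV cache
          }
        }
      }
    }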
@@ -9,11 +9,11 @@ class TextInferenceEngine {
  public:
   virtual ~TextInferenceEngine();
 
-  virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
-  virtual uint32_t step() = 0;
-  virtual void end() = 0;
+  virtual void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) = 0;
+  virtual void stop_request(uint32_t request_id) = 0;
+  virtual rust::Vec<uint32_t> step() = 0;
 
-  virtual uint32_t eos_token() const = 0;
+  virtual uint32_t eos_token_id() const = 0;
 };
 
 std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
@@ -2,6 +2,8 @@
 
 #include <functional>
 #include <vector>
+#include <deque>
+#include <unordered_set>
 
 #include <ggml.h>
 #include <llama.h>
@@ -10,8 +12,34 @@ namespace llama {
 TextInferenceEngine::~TextInferenceEngine() {}
 
 namespace {
-static size_t N_BATCH = 512;  // # per batch inference.
-static size_t N_CTX = 4096;   // # max kv history.
+int get_parallelism() {
+  const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
+  if (parallelism) {
+    return std::stoi(parallelism);
+  } else {
+    return 4;
+  }
+}
+
+static size_t N_CONCURRENT_REQUESTS = get_parallelism();
+
+constexpr size_t N_BATCH = 512;  // Max tokens decoded per llama_decode call.
+constexpr size_t N_CTX = 4096;   // Max KV history per sequence.
+
+struct Request {
+  Request(size_t request_id, rust::Slice<const uint32_t> input_token_ids) :
+    id(request_id),
+    tokens(input_token_ids.begin(), input_token_ids.end()) {
+  }
+
+  size_t id = -1;
+  llama_seq_id seq_id = -1;
+
+  std::vector<llama_token> tokens;  // Tokens not yet decoded: the prompt, then the last sampled token.
+  size_t i_batch = -1;              // Index of this request's logits row in the shared batch.
+  size_t n_past = 0;                // Tokens already in the KV cache for this sequence.
+};
+
 
 template<class T>
 using owned = std::unique_ptr<T, std::function<void(T*)>>;
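To make the per-request bookkeeping concrete, here is a hypothetical trace of one Request across two step() calls (field names are from the struct above; the request id and the 3-token prompt are invented for illustration):

    Request r(/*request_id=*/7, prompt);  // prompt holds 3 token ids
    // r.tokens = {t0, t1, t2}, r.n_past = 0
    //
    // step() #1: all 3 prompt tokens enter the shared batch at positions
    // 0..2 for sequence 7; only the last one asks for logits. After the
    // decode, greedy sampling yields t3 and the fields become:
    //   r.n_past = 3, r.tokens = {t3}
    //
    // step() #2: the single token t3 enters the batch at position
    // r.n_past = 3. After the decode:
    //   r.n_past = 4, r.tokens = {t4}
    //
    // ...and so on: one new token per request per step(), until eos is
    // sampled or stop_request() is called.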
@@ -21,61 +49,136 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
     model_(std::move(model)),
     ctx_(std::move(ctx)) {
+    batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
   }
 
-  void start(rust::Slice<const uint32_t> input_token_ids) override {
+  ~TextInferenceEngineImpl() {
+    llama_batch_free(batch_);
+  }
+
+  void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) override {
+    pending_requests_.push_back(Request(request_id, input_token_ids));
+  }
+
+  void stop_request(uint32_t request_id) override {
+    stopped_requests_.insert(request_id);
+  }
+
+  rust::Vec<uint32_t> step() override {
     auto* ctx = ctx_.get();
-    llama_reset_timings(ctx);
-    std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
+    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
-      const size_t size = std::min(N_BATCH, tokens_list.size() - i);
-      eval(tokens_list.data() + i, size, /* reset = */ i == 0);
+    // Remove stopped requests.
+    if (!stopped_requests_.empty()) {
+      std::vector<Request> requests;
+      for (auto& request : requests_) {
+        if (stopped_requests_.count(request.id) > 0) {
+          // Release KV cache.
+          llama_kv_cache_seq_rm(ctx_.get(), request.id, -1, -1);
+        } else {
+          requests.emplace_back(request);
+        }
+      }
+
+      requests_ = requests;
     }
-  }
 
-  uint32_t step() override {
-    const llama_token id = sample();
-    eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
-    return id;
-  }
-
-  void end() override {
-    llama_print_timings(ctx_.get());
-  }
-
-  uint32_t eos_token() const override {
+    // Add pending requests.
+    while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
+      Request request = std::move(pending_requests_.front());
+      pending_requests_.pop_front();
+
+      // Ignore stopped pending requests.
+      if (stopped_requests_.count(request.id) > 0) {
+        continue;
+      }
+
+      requests_.push_back(request);
+    }
+
+    // Clear stopped requests.
+    stopped_requests_.clear();
+
+    if (requests_.size() == 0) {
+      return {};
+    }
+
+    // Clear the batch.
+    batch_.n_tokens = 0;
+
+    // Insert tokens from ongoing requests to batch.
+    for (auto& request : requests_) {
+      const size_t n_tokens = batch_.n_tokens;
+      for (size_t i = 0; i < request.tokens.size(); ++i) {
+        batch_.token[n_tokens + i] = request.tokens[i];
+        batch_.pos[n_tokens + i] = request.n_past + i;
+        batch_.n_seq_id[n_tokens + i] = 1;
+        batch_.seq_id[n_tokens + i][0] = request.id;
+        batch_.logits[n_tokens + i] = false;
+      }
+      batch_.n_tokens += request.tokens.size();
+
+      // Only the last token of each request needs logits.
+      batch_.logits[batch_.n_tokens - 1] = true;
+      request.i_batch = batch_.n_tokens - 1;
+    }
+
+    rust::Vec<uint32_t> result;
+    result.reserve(requests_.size() * 2);
+
+    // Decode tokens in chunks.
+    for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
+      const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
+      llama_batch batch_view = {
+        n_tokens,
+        batch_.token + i,
+        nullptr,
+        batch_.pos + i,
+        batch_.n_seq_id + i,
+        batch_.seq_id + i,
+        batch_.logits + i,
+        0, 0, 0,  // unused
+      };
+
+      const int ret = llama_decode(ctx, batch_view);
+      if (ret != 0) {
+        throw std::runtime_error("Failed to eval");
+      }
+
+      for (auto& request : requests_) {
+        if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
+          continue;
+        }
+
+        int32_t i_batch = request.i_batch - i;
+        auto logits = llama_get_logits_ith(ctx, i_batch);
+        // Greedy sampling: always select the highest logit.
+        auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
+
+        request.n_past += request.tokens.size();
+
+        request.tokens.clear();
+        request.tokens.push_back(next_token);
+
+        // Flattened (request_id, token_id) pairs.
+        result.push_back(request.id);
+        result.push_back(next_token);
+      }
+    }
+
+    return result;
+  }
+
+  uint32_t eos_token_id() const override {
     return llama_token_eos(llama_get_model(ctx_.get()));
   }
 
  private:
-  uint32_t sample() const {
-    auto* ctx = ctx_.get();
-
-    auto logits = llama_get_logits_ith(ctx, 0);
-    auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-
-    // Greedy sampling (always select the highest logit).
-    return std::distance(logits, std::max_element(logits, logits + n_vocab));
-  }
-
-  void eval(llama_token* data, size_t size, bool reset) {
-    if (reset) {
-      n_past_ = 0;
-    }
-
-    auto* ctx = ctx_.get();
-    llama_kv_cache_tokens_rm(ctx, n_past_, -1);
-    if (llama_decode(ctx, llama_batch_get_one(data, size, n_past_, 0))) {
-      throw std::runtime_error("Failed to eval");
-    }
-
-    n_past_ += size;
-  }
-
-  size_t n_past_;
   owned<llama_model> model_;
   owned<llama_context> ctx_;
+
+  llama_batch batch_;
+
+  std::vector<Request> requests_;
+  std::deque<Request> pending_requests_;
+  std::unordered_set<uint32_t> stopped_requests_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -100,6 +203,7 @@ struct BackendInitializer {
     llama_backend_free();
   }
 };
+
 }  // namespace
 
 std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
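One detail of step() above worth calling out: each request maps onto a llama.cpp KV-cache sequence keyed by its request id, so cancelling a request is a single call (sketch below; the -1 bounds follow llama.cpp's convention of this era, where a negative position bound means the entire range):

    // Remove every cached position belonging to one sequence, freeing its
    // share of the context so a pending request can take over the slot.
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/request.id, /*p0=*/-1, /*p1=*/-1);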
@@ -1,12 +1,20 @@
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 use async_stream::stream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use derive_builder::Builder;
 use ffi::create_engine;
 use futures::{lock::Mutex, stream::BoxStream};
-use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
+use tabby_inference::{
+    decoding::{DecodingFactory, IncrementalDecoding},
+    helpers, TextGeneration, TextGenerationOptions,
+};
 use tokenizers::tokenizer::Tokenizer;
+use tokio::{
+    sync::mpsc::{channel, Sender},
+    task::yield_now,
+};
 
 #[cxx::bridge(namespace = "llama")]
 mod ffi {
@@ -17,46 +25,168 @@ mod ffi {
 
         fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
 
-        fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
-        fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
-        fn end(self: Pin<&mut TextInferenceEngine>);
+        fn add_request(
+            self: Pin<&mut TextInferenceEngine>,
+            request_id: u32,
+            input_token_ids: &[u32],
+        );
+        fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
+        fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<u32>>;
 
-        fn eos_token(&self) -> u32;
+        fn eos_token_id(&self) -> u32;
     }
 }
 
 unsafe impl Send for ffi::TextInferenceEngine {}
 unsafe impl Sync for ffi::TextInferenceEngine {}
 
+struct InferenceRequest {
+    tx: Sender<String>,
+    decoding: IncrementalDecoding,
+}
+
+struct AsyncTextInferenceEngine {
+    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
+    tokenizer: Arc<Tokenizer>,
+    decoding_factory: DecodingFactory,
+    requests: Mutex<HashMap<u32, InferenceRequest>>,
+
+    next_request_id: Mutex<u32>,
+    eos_token_id: u32,
+}
+
+impl AsyncTextInferenceEngine {
+    fn create(engine: UniquePtr<ffi::TextInferenceEngine>, tokenizer: Tokenizer) -> Self {
+        Self {
+            eos_token_id: engine.eos_token_id(),
+            engine: Mutex::new(engine),
+            tokenizer: Arc::new(tokenizer),
+            decoding_factory: DecodingFactory::default(),
+            requests: Mutex::new(HashMap::new()),
+            next_request_id: Mutex::new(0),
+        }
+    }
+
+    async fn background_job(&self) {
+        let mut requests = self.requests.lock().await;
+        if requests.len() == 0 {
+            return;
+        }
+
+        let mut engine = self.engine.lock().await;
+
+        let Ok(result) = engine.as_mut().unwrap().step() else {
+            panic!("Failed to evaluate");
+        };
+
+        // `result` is a flat list of (request_id, token_id) pairs.
+        for i in (0..result.len()).step_by(2) {
+            let request_id = result[i];
+            let token_id = result[i + 1];
+
+            let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap();
+            let mut stopped = false;
+
+            if tx.is_closed() || token_id == self.eos_token_id {
+                // Cancelled by the client side, or hit eos.
+                stopped = true;
+            } else if let Some(new_text) = decoding.next_token(token_id) {
+                tx.send(new_text).await.expect("send failed");
+            } else {
+                // Stopped due to a stop word.
+                stopped = true;
+            }
+
+            if stopped {
+                requests.remove(&request_id);
+                engine.as_mut().unwrap().stop_request(request_id);
+            }
+        }
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        options: TextGenerationOptions,
+    ) -> BoxStream<String> {
+        let encoding = self.tokenizer.encode(prompt, true).unwrap();
+        let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
+        let decoding = self.decoding_factory.create_incremental_decoding(
+            self.tokenizer.clone(),
+            input_token_ids,
+            options.language,
+        );
+
+        let (tx, mut rx) = channel::<String>(4);
+        {
+            let mut engine = self.engine.lock().await;
+            let engine = engine.as_mut().unwrap();
+
+            let mut request_id = self.next_request_id.lock().await;
+            self.requests
+                .lock()
+                .await
+                .insert(*request_id, InferenceRequest { tx, decoding });
+            engine.add_request(*request_id, input_token_ids);
+
+            // 2048 should be large enough to avoid collision.
+            *request_id = (*request_id + 1) % 2048;
+        }
+
+        let s = stream! {
+            let mut length = 0;
+            while let Some(new_text) = rx.recv().await {
+                yield new_text;
+                length += 1;
+                if length >= options.max_decoding_length {
+                    break;
+                }
+            }
+
+            rx.close();
+        };
+
+        Box::pin(s)
+    }
+}
+
 #[derive(Builder, Debug)]
-pub struct LlamaEngineOptions {
+pub struct LlamaTextGenerationOptions {
     model_path: String,
     tokenizer_path: String,
     use_gpu: bool,
 }
 
-pub struct LlamaEngine {
-    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
-    tokenizer: Arc<Tokenizer>,
-    decoding_factory: DecodingFactory,
+pub struct LlamaTextGeneration {
+    engine: Arc<AsyncTextInferenceEngine>,
 }
 
-impl LlamaEngine {
-    pub fn create(options: LlamaEngineOptions) -> Self {
+impl LlamaTextGeneration {
+    pub fn create(options: LlamaTextGenerationOptions) -> Self {
         let engine = create_engine(options.use_gpu, &options.model_path);
         if engine.is_null() {
             panic!("Unable to load model: {}", options.model_path);
         }
-        LlamaEngine {
-            engine: Mutex::new(engine),
-            tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
-            decoding_factory: DecodingFactory::default(),
-        }
+        let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap();
+        let ret = LlamaTextGeneration {
+            engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)),
+        };
+        ret.start_background_job();
+        ret
     }
+
+    pub fn start_background_job(&self) {
+        let engine = self.engine.clone();
+        tokio::spawn(async move {
+            loop {
+                engine.background_job().await;
+                yield_now().await;
+            }
+        });
+    }
 }
 
 #[async_trait]
-impl TextGeneration for LlamaEngine {
+impl TextGeneration for LlamaTextGeneration {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
         let s = self.generate_stream(prompt, options).await;
         helpers::stream_to_string(s).await
@@ -67,38 +197,7 @@ impl TextGeneration for LlamaEngine {
         prompt: &str,
         options: TextGenerationOptions,
     ) -> BoxStream<String> {
-        let encoding = self.tokenizer.encode(prompt, true).unwrap();
-
-        let s = stream! {
-            let mut engine = self.engine.lock().await;
-            let mut engine = engine.as_mut().unwrap();
-            let eos_token = engine.eos_token();
-
-            let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
-            engine.as_mut().start(input_token_ids);
-            let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
-            let mut n_remains = options.max_decoding_length;
-            while n_remains > 0 {
-                let Ok(next_token_id) = engine.as_mut().step() else {
-                    panic!("Failed to eval");
-                };
-                if next_token_id == eos_token {
-                    break;
-                }
-
-                if let Some(new_text) = decoding.next_token(next_token_id) {
-                    yield new_text;
-                } else {
-                    break;
-                }
-
-                n_remains -= 1;
-            }
-
-            engine.end();
-        };
-
-        Box::pin(s)
+        self.engine.generate_stream(prompt, options).await
     }
 }
@@ -1,6 +1,6 @@
 use clap::Args;
 use tabby_download::Downloader;
-use tracing::{info, log::warn};
+use tracing::info;
 
 use crate::fatal;
@@ -39,14 +39,14 @@ pub struct EngineInfo {
 }
 
 fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
-    let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
+    let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_dir.ggml_q8_0_v2_file())
         .tokenizer_path(model_dir.tokenizer_file())
         .use_gpu(device.ggml_use_gpu())
         .build()
         .unwrap();
 
-    Box::new(llama_cpp_bindings::LlamaEngine::create(options))
+    Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
 }
 
 fn get_model_dir(model: &str) -> ModelDir {