feat: support continuous batching in llama.cpp backend (#659)
* refactor: switch back to llama batch interface
* feat: support cont batching
release-notes-05
parent
14d03b6826
commit
7bd99d14c0
|
|
@ -9,11 +9,11 @@ class TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
virtual ~TextInferenceEngine();
|
virtual ~TextInferenceEngine();
|
||||||
|
|
||||||
virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
|
virtual void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) = 0;
|
||||||
virtual uint32_t step() = 0;
|
virtual void stop_request(uint32_t request_id) = 0;
|
||||||
virtual void end() = 0;
|
virtual rust::Vec<uint32_t> step() = 0;
|
||||||
|
|
||||||
virtual uint32_t eos_token() const = 0;
|
virtual uint32_t eos_token_id() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
|
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <deque>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include <ggml.h>
|
#include <ggml.h>
|
||||||
#include <llama.h>
|
#include <llama.h>
|
||||||
|
|
@ -10,8 +12,34 @@ namespace llama {
|
||||||
TextInferenceEngine::~TextInferenceEngine() {}
|
TextInferenceEngine::~TextInferenceEngine() {}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
static size_t N_BATCH = 512; // # per batch inference.
|
// Returns how many requests may be decoded concurrently.
//
// The value is read from the LLAMA_CPP_PARALLELISM environment variable.
// The default of 4 is returned when the variable is unset, does not start
// with a number, does not fit in an int, or is not strictly positive.
int get_parallelism() {
  constexpr int kDefaultParallelism = 4;

  const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
  if (!parallelism) {
    return kDefaultParallelism;
  }

  try {
    const int value = std::stoi(parallelism);
    // The result is stored in a size_t request limit, so zero would admit no
    // requests and a negative value would wrap to a huge count.
    return value > 0 ? value : kDefaultParallelism;
  } catch (const std::exception&) {
    // std::stoi throws std::invalid_argument / std::out_of_range. The result
    // initialises a namespace-scope static, so the exception must not escape.
    return kDefaultParallelism;
  }
}
|
||||||
|
|
||||||
|
static size_t N_CONCURRENT_REQUESTS = get_parallelism();
|
||||||
|
|
||||||
|
constexpr size_t N_BATCH = 512; // # per batch inference.
|
||||||
|
constexpr size_t N_CTX = 4096; // # max kv history.
|
||||||
|
|
||||||
|
struct Request {
|
||||||
|
Request(size_t request_id, rust::Slice<const uint32_t> input_token_ids) :
|
||||||
|
id(request_id),
|
||||||
|
tokens(input_token_ids.begin(), input_token_ids.end()) {
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t id = -1;
|
||||||
|
llama_seq_id seq_id = -1;
|
||||||
|
|
||||||
|
std::vector<llama_token> tokens;
|
||||||
|
size_t i_batch = -1;
|
||||||
|
size_t n_past = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||||
|
|
@ -21,61 +49,136 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
||||||
model_(std::move(model)),
|
model_(std::move(model)),
|
||||||
ctx_(std::move(ctx)) {
|
ctx_(std::move(ctx)) {
|
||||||
|
batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void start(rust::Slice<const uint32_t> input_token_ids) override {
|
// Frees the batch buffers held in batch_.
~TextInferenceEngineImpl() {
  llama_batch_free(batch_);
}
|
||||||
|
|
||||||
|
// Queues a new request; it is picked up by a later step() once a slot among
// the concurrently running requests is available.
void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) override {
  pending_requests_.emplace_back(request_id, input_token_ids);
}
|
||||||
|
|
||||||
|
// Marks a request as stopped; the actual removal happens at the start of the
// next step().
void stop_request(uint32_t request_id) override {
  stopped_requests_.emplace(request_id);
}
|
||||||
|
|
||||||
|
rust::Vec<uint32_t> step() override {
|
||||||
auto* ctx = ctx_.get();
|
auto* ctx = ctx_.get();
|
||||||
llama_reset_timings(ctx);
|
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
|
// Remove stopped requests.
|
||||||
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
|
if (!stopped_requests_.empty()) {
|
||||||
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
|
std::vector<Request> requests;
|
||||||
|
for (auto& request : requests_) {
|
||||||
|
if (stopped_requests_.count(request.id) > 0) {
|
||||||
|
// Release KV cache.
|
||||||
|
llama_kv_cache_seq_rm(ctx_.get(), request.id, -1, -1);
|
||||||
|
} else {
|
||||||
|
requests.emplace_back(request);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
requests_ = requests;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add pending requests.
|
||||||
|
while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
|
||||||
|
Request request = std::move(pending_requests_.front());
|
||||||
|
pending_requests_.pop_front();
|
||||||
|
|
||||||
|
// Ignore stopped pending requests.
|
||||||
|
if (stopped_requests_.count(request.id) > 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
requests_.push_back(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear stopped requests.
|
||||||
|
stopped_requests_.clear();
|
||||||
|
|
||||||
|
if (requests_.size() == 0) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the batch.
|
||||||
|
batch_.n_tokens = 0;
|
||||||
|
|
||||||
|
// Insert tokens from ongoing requests to batch.
|
||||||
|
for (auto& request : requests_) {
|
||||||
|
const size_t n_tokens = batch_.n_tokens;
|
||||||
|
for (size_t i = 0; i < request.tokens.size(); ++i) {
|
||||||
|
batch_.token[n_tokens + i] = request.tokens[i];
|
||||||
|
batch_.pos[n_tokens + i] = request.n_past + i;
|
||||||
|
batch_.n_seq_id[n_tokens + i] = 1;
|
||||||
|
batch_.seq_id[n_tokens + i][0] = request.id;
|
||||||
|
batch_.logits[n_tokens + i] = false;
|
||||||
|
}
|
||||||
|
batch_.n_tokens += request.tokens.size();
|
||||||
|
|
||||||
|
batch_.logits[batch_.n_tokens - 1] = true;
|
||||||
|
request.i_batch = batch_.n_tokens - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
rust::Vec<uint32_t> result;
|
||||||
|
result.reserve(requests_.size() * 2);
|
||||||
|
|
||||||
|
// Decode tokens in chunks
|
||||||
|
for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
|
||||||
|
const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
|
||||||
|
llama_batch batch_view = {
|
||||||
|
n_tokens,
|
||||||
|
batch_.token + i,
|
||||||
|
nullptr,
|
||||||
|
batch_.pos + i,
|
||||||
|
batch_.n_seq_id + i,
|
||||||
|
batch_.seq_id + i,
|
||||||
|
batch_.logits + i,
|
||||||
|
0, 0, 0, // unused
|
||||||
|
};
|
||||||
|
|
||||||
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
if (ret != 0) {
|
||||||
|
throw std::runtime_error("Failed to eval");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& request : requests_) {
|
||||||
|
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t i_batch = request.i_batch - i;
|
||||||
|
auto logits = llama_get_logits_ith(ctx, i_batch);
|
||||||
|
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||||
|
|
||||||
|
request.n_past += request.tokens.size();
|
||||||
|
|
||||||
|
request.tokens.clear();
|
||||||
|
request.tokens.push_back(next_token);
|
||||||
|
|
||||||
|
result.push_back(request.id);
|
||||||
|
result.push_back(next_token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t step() override {
|
uint32_t eos_token_id() const override {
|
||||||
const llama_token id = sample();
|
|
||||||
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
|
|
||||||
return id;
|
|
||||||
}
|
|
||||||
|
|
||||||
void end() override {
|
|
||||||
llama_print_timings(ctx_.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t eos_token() const override {
|
|
||||||
return llama_token_eos(llama_get_model(ctx_.get()));
|
return llama_token_eos(llama_get_model(ctx_.get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint32_t sample() const {
|
|
||||||
auto* ctx = ctx_.get();
|
|
||||||
|
|
||||||
auto logits = llama_get_logits_ith(ctx, 0);
|
|
||||||
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
||||||
|
|
||||||
// Greedy sampling (always select the highest logit).
|
|
||||||
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
|
||||||
}
|
|
||||||
|
|
||||||
void eval(llama_token* data, size_t size, bool reset) {
|
|
||||||
if (reset) {
|
|
||||||
n_past_ = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto* ctx = ctx_.get();
|
|
||||||
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
|
|
||||||
if (llama_decode(ctx, llama_batch_get_one(data, size, n_past_, 0))) {
|
|
||||||
throw std::runtime_error("Failed to eval");
|
|
||||||
}
|
|
||||||
|
|
||||||
n_past_ += size;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t n_past_;
|
|
||||||
owned<llama_model> model_;
|
owned<llama_model> model_;
|
||||||
owned<llama_context> ctx_;
|
owned<llama_context> ctx_;
|
||||||
|
|
||||||
|
llama_batch batch_;
|
||||||
|
|
||||||
|
std::vector<Request> requests_;
|
||||||
|
std::deque<Request> pending_requests_;
|
||||||
|
std::unordered_set<uint32_t> stopped_requests_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static int g_llama_cpp_log_level = 0;
|
static int g_llama_cpp_log_level = 0;
|
||||||
|
|
@ -100,6 +203,7 @@ struct BackendInitializer {
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
|
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,20 @@
|
||||||
use std::sync::Arc;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use cxx::UniquePtr;
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use ffi::create_engine;
|
use ffi::create_engine;
|
||||||
use futures::{lock::Mutex, stream::BoxStream};
|
use futures::{lock::Mutex, stream::BoxStream};
|
||||||
use tabby_inference::{decoding::DecodingFactory, helpers, TextGeneration, TextGenerationOptions};
|
use tabby_inference::{
|
||||||
|
decoding::{DecodingFactory, IncrementalDecoding},
|
||||||
|
helpers, TextGeneration, TextGenerationOptions,
|
||||||
|
};
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokio::{
|
||||||
|
sync::mpsc::{channel, Sender},
|
||||||
|
task::yield_now,
|
||||||
|
};
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "llama")]
|
#[cxx::bridge(namespace = "llama")]
|
||||||
mod ffi {
|
mod ffi {
|
||||||
|
|
@ -17,46 +25,168 @@ mod ffi {
|
||||||
|
|
||||||
fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
|
fn add_request(
|
||||||
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
|
self: Pin<&mut TextInferenceEngine>,
|
||||||
fn end(self: Pin<&mut TextInferenceEngine>);
|
request_id: u32,
|
||||||
|
input_token_ids: &[u32],
|
||||||
|
);
|
||||||
|
fn stop_request(self: Pin<&mut TextInferenceEngine>, request_id: u32);
|
||||||
|
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<Vec<u32>>;
|
||||||
|
|
||||||
fn eos_token(&self) -> u32;
|
fn eos_token_id(&self) -> u32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl Send for ffi::TextInferenceEngine {}
|
unsafe impl Send for ffi::TextInferenceEngine {}
|
||||||
unsafe impl Sync for ffi::TextInferenceEngine {}
|
unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||||
|
|
||||||
|
struct InferenceRequest {
|
||||||
|
tx: Sender<String>,
|
||||||
|
decoding: IncrementalDecoding,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AsyncTextInferenceEngine {
|
||||||
|
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
||||||
|
tokenizer: Arc<Tokenizer>,
|
||||||
|
decoding_factory: DecodingFactory,
|
||||||
|
requests: Mutex<HashMap<u32, InferenceRequest>>,
|
||||||
|
|
||||||
|
next_request_id: Mutex<u32>,
|
||||||
|
eos_token_id: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncTextInferenceEngine {
|
||||||
|
fn create(engine: UniquePtr<ffi::TextInferenceEngine>, tokenizer: Tokenizer) -> Self {
|
||||||
|
Self {
|
||||||
|
eos_token_id: engine.eos_token_id(),
|
||||||
|
engine: Mutex::new(engine),
|
||||||
|
tokenizer: Arc::new(tokenizer),
|
||||||
|
decoding_factory: DecodingFactory::default(),
|
||||||
|
requests: Mutex::new(HashMap::new()),
|
||||||
|
next_request_id: Mutex::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn background_job(&self) {
|
||||||
|
let mut requests = self.requests.lock().await;
|
||||||
|
if requests.len() == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut engine = self.engine.lock().await;
|
||||||
|
|
||||||
|
let Ok(result) = engine.as_mut().unwrap().step() else {
|
||||||
|
panic!("Failed to evaluation");
|
||||||
|
};
|
||||||
|
|
||||||
|
for i in (0..result.len()).step_by(2) {
|
||||||
|
let request_id = result[i];
|
||||||
|
let token_id = result[i + 1];
|
||||||
|
|
||||||
|
let InferenceRequest { tx, decoding } = requests.get_mut(&request_id).unwrap();
|
||||||
|
let mut stopped = false;
|
||||||
|
|
||||||
|
if tx.is_closed() || token_id == self.eos_token_id {
|
||||||
|
// Cancelled by client side or hit eos.
|
||||||
|
stopped = true;
|
||||||
|
} else if let Some(new_text) = decoding.next_token(token_id) {
|
||||||
|
tx.send(new_text).await.expect("send failed");
|
||||||
|
} else {
|
||||||
|
// Stoop words stopped
|
||||||
|
stopped = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopped {
|
||||||
|
requests.remove(&request_id);
|
||||||
|
engine.as_mut().unwrap().stop_request(request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
options: TextGenerationOptions,
|
||||||
|
) -> BoxStream<String> {
|
||||||
|
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||||
|
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
||||||
|
let decoding = self.decoding_factory.create_incremental_decoding(
|
||||||
|
self.tokenizer.clone(),
|
||||||
|
input_token_ids,
|
||||||
|
options.language,
|
||||||
|
);
|
||||||
|
|
||||||
|
let (tx, mut rx) = channel::<String>(4);
|
||||||
|
{
|
||||||
|
let mut engine = self.engine.lock().await;
|
||||||
|
let engine = engine.as_mut().unwrap();
|
||||||
|
|
||||||
|
let mut request_id = self.next_request_id.lock().await;
|
||||||
|
self.requests
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(*request_id, InferenceRequest { tx, decoding });
|
||||||
|
engine.add_request(*request_id, input_token_ids);
|
||||||
|
|
||||||
|
// 2048 should be large enough to avoid collision.
|
||||||
|
*request_id = (*request_id + 1) % 2048;
|
||||||
|
}
|
||||||
|
|
||||||
|
let s = stream! {
|
||||||
|
let mut length = 0;
|
||||||
|
while let Some(new_text) = rx.recv().await {
|
||||||
|
yield new_text;
|
||||||
|
length += 1;
|
||||||
|
if length >= options.max_decoding_length {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rx.close();
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Builder, Debug)]
|
#[derive(Builder, Debug)]
|
||||||
pub struct LlamaEngineOptions {
|
pub struct LlamaTextGenerationOptions {
|
||||||
model_path: String,
|
model_path: String,
|
||||||
tokenizer_path: String,
|
tokenizer_path: String,
|
||||||
use_gpu: bool,
|
use_gpu: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LlamaEngine {
|
pub struct LlamaTextGeneration {
|
||||||
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
|
engine: Arc<AsyncTextInferenceEngine>,
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
decoding_factory: DecodingFactory,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlamaEngine {
|
impl LlamaTextGeneration {
|
||||||
pub fn create(options: LlamaEngineOptions) -> Self {
|
pub fn create(options: LlamaTextGenerationOptions) -> Self {
|
||||||
let engine = create_engine(options.use_gpu, &options.model_path);
|
let engine = create_engine(options.use_gpu, &options.model_path);
|
||||||
if engine.is_null() {
|
if engine.is_null() {
|
||||||
panic!("Unable to load model: {}", options.model_path);
|
panic!("Unable to load model: {}", options.model_path);
|
||||||
}
|
}
|
||||||
LlamaEngine {
|
let tokenizer = Tokenizer::from_file(&options.tokenizer_path).unwrap();
|
||||||
engine: Mutex::new(engine),
|
let ret = LlamaTextGeneration {
|
||||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
engine: Arc::new(AsyncTextInferenceEngine::create(engine, tokenizer)),
|
||||||
decoding_factory: DecodingFactory::default(),
|
};
|
||||||
}
|
ret.start_background_job();
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_background_job(&self) {
|
||||||
|
let engine = self.engine.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
engine.background_job().await;
|
||||||
|
yield_now().await;
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl TextGeneration for LlamaEngine {
|
impl TextGeneration for LlamaTextGeneration {
|
||||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||||
let s = self.generate_stream(prompt, options).await;
|
let s = self.generate_stream(prompt, options).await;
|
||||||
helpers::stream_to_string(s).await
|
helpers::stream_to_string(s).await
|
||||||
|
|
@ -67,38 +197,7 @@ impl TextGeneration for LlamaEngine {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
options: TextGenerationOptions,
|
options: TextGenerationOptions,
|
||||||
) -> BoxStream<String> {
|
) -> BoxStream<String> {
|
||||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
self.engine.generate_stream(prompt, options).await
|
||||||
|
|
||||||
let s = stream! {
|
|
||||||
let mut engine = self.engine.lock().await;
|
|
||||||
let mut engine = engine.as_mut().unwrap();
|
|
||||||
let eos_token = engine.eos_token();
|
|
||||||
|
|
||||||
let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
|
|
||||||
engine.as_mut().start(input_token_ids);
|
|
||||||
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
|
|
||||||
let mut n_remains = options.max_decoding_length ;
|
|
||||||
while n_remains > 0 {
|
|
||||||
let Ok(next_token_id) = engine.as_mut().step() else {
|
|
||||||
panic!("Failed to eval");
|
|
||||||
};
|
|
||||||
if next_token_id == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(new_text) = decoding.next_token(next_token_id) {
|
|
||||||
yield new_text;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
n_remains -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
engine.end();
|
|
||||||
};
|
|
||||||
|
|
||||||
Box::pin(s)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
use tabby_download::Downloader;
|
use tabby_download::Downloader;
|
||||||
use tracing::{info, log::warn};
|
use tracing::info;
|
||||||
|
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,14 +39,14 @@ pub struct EngineInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
|
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
|
||||||
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
|
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||||
.model_path(model_dir.ggml_q8_0_v2_file())
|
.model_path(model_dir.ggml_q8_0_v2_file())
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
.use_gpu(device.ggml_use_gpu())
|
.use_gpu(device.ggml_use_gpu())
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Box::new(llama_cpp_bindings::LlamaEngine::create(options))
|
Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_model_dir(model: &str) -> ModelDir {
|
fn get_model_dir(model: &str) -> ModelDir {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue