feat: llama.cpp for metal support [TAB-146] (#391)

* feat: init commit adding llama-cpp-bindings

* add llama.cpp submodule

* add LlamaEngine to hold llama context / llama model

* add cxxbridge

* add basic greedy sampling

* move files

* make compile success

* connect TextGeneration with LlamaEngine

* experimental support llama.cpp

* add metal device

* add Accelerate

* fix namespace for llama-cpp-bindings

* fix lint

* move stepping logic to rust

* add stop words package

* use stop-words in ctranslate2-bindings

* use raw string for regex

* use Arc<Tokenizer> for sharing tokenizers

* refactor: remove useless stop_words_encoding_offset

* switch to tokenizers 0.13.4-rc.3

* fix lints in cpp

* simplify implementation of greedy decoding

* feat: split metal feature for llama backend

* add ci

* update ci

* build tabby bin in ci build
release-v0.1
Meng Zhang 2023-09-03 09:59:07 +08:00 committed by GitHub
parent 1c1cf44639
commit 3573d4378e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 500 additions and 132 deletions

View File

@ -43,8 +43,10 @@ jobs:
include: include:
- os: macos-11 - os: macos-11
target: aarch64-apple-darwin target: aarch64-apple-darwin
flags: "--features metal"
- os: ubuntu-latest - os: ubuntu-latest
target: x86_64-unknown-linux-gnu target: x86_64-unknown-linux-gnu
flags: ""
env: env:
SCCACHE_GHA_ENABLED: true SCCACHE_GHA_ENABLED: true
@ -81,7 +83,7 @@ jobs:
~/.cargo/git ~/.cargo/git
- run: bash ./ci/prepare_build_environment.sh - run: bash ./ci/prepare_build_environment.sh
- name: Bulid release binary - name: Bulid release binary
run: cargo build --no-default-features --release --target ${{ matrix.target }} run: cargo build --bin tabby --no-default-features ${{ matrix.flags }} --release --target ${{ matrix.target }}
- name: Rename release binary - name: Rename release binary
run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }} run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}

3
.gitmodules vendored
View File

@ -1,3 +1,6 @@
[submodule "crates/ctranslate2-bindings/CTranslate2"] [submodule "crates/ctranslate2-bindings/CTranslate2"]
path = crates/ctranslate2-bindings/CTranslate2 path = crates/ctranslate2-bindings/CTranslate2
url = https://github.com/OpenNMT/CTranslate2.git url = https://github.com/OpenNMT/CTranslate2.git
[submodule "crates/llama-cpp-bindings/llama.cpp"]
path = crates/llama-cpp-bindings/llama.cpp
url = https://github.com/ggerganov/llama.cpp

96
Cargo.lock generated
View File

@ -617,10 +617,9 @@ dependencies = [
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
"dashmap",
"derive_builder", "derive_builder",
"regex",
"rust-cxx-cmake-bridge", "rust-cxx-cmake-bridge",
"stop-words",
"tabby-inference", "tabby-inference",
"tokenizers", "tokenizers",
"tokio", "tokio",
@ -743,12 +742,12 @@ dependencies = [
[[package]] [[package]]
name = "dashmap" name = "dashmap"
version = "5.4.0" version = "5.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"hashbrown", "hashbrown 0.14.0",
"lock_api", "lock_api",
"once_cell", "once_cell",
"parking_lot_core", "parking_lot_core",
@ -1181,6 +1180,12 @@ dependencies = [
"ahash", "ahash",
] ]
[[package]]
name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.4.1" version = "0.4.1"
@ -1352,7 +1357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"hashbrown", "hashbrown 0.12.3",
"serde", "serde",
] ]
@ -1547,11 +1552,26 @@ version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
[[package]]
name = "llama-cpp-bindings"
version = "0.1.0"
dependencies = [
"async-trait",
"cmake",
"cxx",
"cxx-build",
"derive_builder",
"stop-words",
"tabby-inference",
"tokenizers",
"tokio",
]
[[package]] [[package]]
name = "lock_api" name = "lock_api"
version = "0.4.9" version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"scopeguard", "scopeguard",
@ -1586,7 +1606,7 @@ version = "0.7.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a"
dependencies = [ dependencies = [
"hashbrown", "hashbrown 0.12.3",
] ]
[[package]] [[package]]
@ -1617,7 +1637,7 @@ version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1"
dependencies = [ dependencies = [
"regex-automata", "regex-automata 0.1.10",
] ]
[[package]] [[package]]
@ -1626,7 +1646,7 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [ dependencies = [
"regex-automata", "regex-automata 0.1.10",
] ]
[[package]] [[package]]
@ -1647,9 +1667,9 @@ dependencies = [
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.5.0" version = "2.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" checksum = "5486aed0026218e61b8a01d5fbd5a0a134649abb71a0e53b7bc088529dced86e"
[[package]] [[package]]
name = "memmap2" name = "memmap2"
@ -1887,9 +1907,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.17.1" version = "1.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]] [[package]]
name = "oneshot" name = "oneshot"
@ -2073,15 +2093,15 @@ dependencies = [
[[package]] [[package]]
name = "parking_lot_core" name = "parking_lot_core"
version = "0.9.7" version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"redox_syscall 0.2.16", "redox_syscall 0.3.5",
"smallvec", "smallvec",
"windows-sys 0.45.0", "windows-targets 0.48.0",
] ]
[[package]] [[package]]
@ -2388,13 +2408,14 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.8.4" version = "1.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47"
dependencies = [ dependencies = [
"aho-corasick 1.0.1", "aho-corasick 1.0.1",
"memchr", "memchr",
"regex-syntax 0.7.2", "regex-automata 0.3.8",
"regex-syntax 0.7.5",
] ]
[[package]] [[package]]
@ -2406,6 +2427,17 @@ dependencies = [
"regex-syntax 0.6.29", "regex-syntax 0.6.29",
] ]
[[package]]
name = "regex-automata"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795"
dependencies = [
"aho-corasick 1.0.1",
"memchr",
"regex-syntax 0.7.5",
]
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.6.29" version = "0.6.29"
@ -2414,9 +2446,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.7.2" version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
@ -2807,6 +2839,15 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stop-words"
version = "0.1.0"
dependencies = [
"dashmap",
"regex",
"tokenizers",
]
[[package]] [[package]]
name = "strfmt" name = "strfmt"
version = "0.2.4" version = "0.2.4"
@ -2907,6 +2948,7 @@ dependencies = [
"ctranslate2-bindings", "ctranslate2-bindings",
"hyper", "hyper",
"lazy_static", "lazy_static",
"llama-cpp-bindings",
"mime_guess", "mime_guess",
"nvml-wrapper", "nvml-wrapper",
"opentelemetry", "opentelemetry",
@ -3217,9 +3259,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokenizers" name = "tokenizers"
version = "0.13.3" version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cf49017523bf0bc01c9966f172c5f120bbb7b96cccd1708772dd42e767fb9f5" checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f"
dependencies = [ dependencies = [
"aho-corasick 0.7.20", "aho-corasick 0.7.20",
"cached-path", "cached-path",
@ -3240,7 +3282,7 @@ dependencies = [
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
"regex-syntax 0.6.29", "regex-syntax 0.7.5",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -7,6 +7,8 @@ members = [
"crates/tabby-inference", "crates/tabby-inference",
"crates/ctranslate2-bindings", "crates/ctranslate2-bindings",
"crates/rust-cxx-cmake-bridge", "crates/rust-cxx-cmake-bridge",
"crates/llama-cpp-bindings",
"crates/stop-words",
] ]
[workspace.package] [workspace.package]
@ -28,3 +30,5 @@ serde-jsonlines = "0.4.0"
tantivy = "0.19.2" tantivy = "0.19.2"
async-trait = "0.1.72" async-trait = "0.1.72"
reqwest = { version = "0.11.18" } reqwest = { version = "0.11.18" }
derive_builder = "0.12.0"
tokenizers = "0.13.4-rc3"

View File

@ -5,14 +5,13 @@ edition = "2021"
[dependencies] [dependencies]
cxx = "1.0" cxx = "1.0"
dashmap = "5.4.0" derive_builder = { workspace = true }
derive_builder = "0.12.0" tokenizers = { workspace = true }
regex = "1.8.4"
tokenizers = "0.13.3"
tokio = { workspace = true, features = ["rt"] } tokio = { workspace = true, features = ["rt"] }
tokio-util = { workspace = true } tokio-util = { workspace = true }
tabby-inference = { path = "../tabby-inference" } tabby-inference = { path = "../tabby-inference" }
async-trait = { workspace = true } async-trait = { workspace = true }
stop-words = { path = "../stop-words" }
[build-dependencies] [build-dependencies]
cxx-build = "1.0" cxx-build = "1.0"

View File

@ -1,7 +1,8 @@
use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use dashmap::DashMap;
use derive_builder::Builder; use derive_builder::Builder;
use regex::Regex; use stop_words::{StopWords, StopWordsCondition};
use tabby_inference::{TextGeneration, TextGenerationOptions}; use tabby_inference::{TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@ -63,31 +64,26 @@ pub struct CTranslate2EngineOptions {
num_replicas_per_device: usize, num_replicas_per_device: usize,
compute_type: String, compute_type: String,
stop_words_encoding_offset: Option<usize>,
} }
pub struct InferenceContext { pub struct InferenceContext {
stop_re: Option<Regex>, stop_condition: StopWordsCondition,
cancel: CancellationToken, cancel: CancellationToken,
reversed_output_text: String,
} }
impl InferenceContext { impl InferenceContext {
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self { fn new(stop_condition: StopWordsCondition, cancel: CancellationToken) -> Self {
InferenceContext { InferenceContext {
stop_re, stop_condition,
cancel, cancel,
reversed_output_text: "".to_owned(),
} }
} }
} }
pub struct CTranslate2Engine { pub struct CTranslate2Engine {
engine: cxx::SharedPtr<ffi::TextInferenceEngine>, engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer, stop_words: StopWords,
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, tokenizer: Arc<Tokenizer>,
stop_words_encoding_offset: Option<usize>,
} }
impl CTranslate2Engine { impl CTranslate2Engine {
@ -103,9 +99,8 @@ impl CTranslate2Engine {
return Self { return Self {
engine, engine,
stop_regex_cache: DashMap::new(), stop_words: StopWords::default(),
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
stop_words_encoding_offset: options.stop_words_encoding_offset,
}; };
} }
} }
@ -120,25 +115,10 @@ impl TextGeneration for CTranslate2Engine {
let cancel_for_inference = cancel.clone(); let cancel_for_inference = cancel.clone();
let _guard = cancel.drop_guard(); let _guard = cancel.drop_guard();
let stop_re: Option<Regex> = if options.stop_words.is_empty() { let stop_condition = self
None .stop_words
} else { .create_condition(self.tokenizer.clone(), options.stop_words);
let mut re = self.stop_regex_cache.get(options.stop_words); let context = InferenceContext::new(stop_condition, cancel_for_inference);
if re.is_none() {
self.stop_regex_cache.insert(
options.stop_words,
create_stop_regex(
&self.tokenizer,
options.stop_words,
self.stop_words_encoding_offset,
),
);
re = self.stop_regex_cache.get(options.stop_words);
}
re.map(|x| x.value().clone())
};
let context = InferenceContext::new(stop_re, cancel_for_inference);
let output_ids = tokio::task::spawn_blocking(move || { let output_ids = tokio::task::spawn_blocking(move || {
let context = Box::new(context); let context = Box::new(context);
engine.inference( engine.inference(
@ -151,63 +131,19 @@ impl TextGeneration for CTranslate2Engine {
}) })
.await .await
.expect("Inference failed"); .expect("Inference failed");
self.tokenizer.decode(output_ids, true).unwrap() self.tokenizer.decode(&output_ids, true).unwrap()
} }
} }
fn inference_callback( fn inference_callback(
context: &mut InferenceContext, context: &mut InferenceContext,
_step: usize, _step: usize,
_token_id: u32, token_id: u32,
token: String, _token: String,
) -> bool { ) -> bool {
if context.cancel.is_cancelled() { if context.cancel.is_cancelled() {
true true
} else if let Some(re) = &context.stop_re {
let mut new_token = reverse(&token);
new_token.push_str(&context.reversed_output_text);
context.reversed_output_text = new_token;
re.find(&context.reversed_output_text).is_some()
} else { } else {
false context.stop_condition.next_token(token_id)
} }
} }
fn reverse(s: &String) -> String {
// Special treatment for byte fallback token.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
// Keep byte fallback tokens like <0x0A> as is, do not reverse it.
// This won't really affect stop words regex logic, but brings more readability when
// debugging decoding steps.
s.to_owned()
} else {
s.chars().rev().collect()
}
}
fn create_stop_regex(
tokenizer: &Tokenizer,
stop_words: &[&str],
stop_words_encoding_offset: Option<usize>,
) -> Regex {
let encodings = tokenizer
.encode_batch(stop_words.to_owned(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings
.iter()
.map(|x| {
x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..]
.iter()
.rev()
.map(reverse)
.collect::<Vec<String>>()
.join("")
})
.collect();
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
Regex::new(&regex_string).unwrap()
}

View File

@ -0,0 +1,17 @@
[package]
name = "llama-cpp-bindings"
version = "0.1.0"
edition = "2021"
[build-dependencies]
cxx-build = "1.0"
cmake = "0.1"
[dependencies]
cxx = "1.0"
async-trait = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
tabby-inference = { path = "../tabby-inference" }
derive_builder = { workspace = true }
tokenizers = { workspace = true }
stop-words = { version = "0.1.0", path = "../stop-words" }

View File

@ -0,0 +1,23 @@
use cmake::Config;
fn main() {
let dst = Config::new("llama.cpp").define("LLAMA_METAL", "ON").build();
println!("cargo:rerun-if-changed=cc/*.h");
println!("cargo:rerun-if-changed=cc/*.cc");
println!("cargo:rustc-link-search=native={}/build", dst.display());
println!("cargo:rustc-link-lib=llama");
println!("cargo:rustc-link-lib=ggml_static");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=Accelerate");
println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=MetalKit");
cxx_build::bridge("src/lib.rs")
.file("src/engine.cc")
.flag_if_supported("-Iinclude")
.flag_if_supported("-Illama.cpp")
.flag_if_supported("-std=c++14")
.compile("cxxbridge");
}

View File

@ -0,0 +1,17 @@
#pragma once
#include "rust/cxx.h"
#include <memory>
namespace llama {
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual uint32_t start(const rust::Str prompt) const = 0;
virtual uint32_t step(uint32_t next_token_id) const = 0;
};
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
} // namespace

@ -0,0 +1 @@
Subproject commit bce1fef328941499dc0acb76cc7fd7ac90449c2f

View File

@ -0,0 +1,112 @@
#include "engine.h"
#include <functional>
#include <vector>
#include <ggml.h>
#include <llama.h>
namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}
namespace {
template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}
class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
model_(std::move(model)),
ctx_(std::move(ctx)) {
}
uint32_t start(const rust::Str prompt) const override {
auto* ctx = ctx_.get();
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
eval(tokens_list, /* reset = */ true);
return sample();
}
uint32_t step(uint32_t next_token_id) const override {
eval({ static_cast<llama_token>(next_token_id) }, /* reset = */ false);
return sample();
}
private:
uint32_t sample() const {
auto* ctx = ctx_.get();
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
// Greedy sampling (always select the highest logit).
return std::distance(logits, std::max_element(logits, logits + n_vocab));
}
bool eval(const std::vector<llama_token>& tokens_list, bool reset) const {
auto* ctx = ctx_.get();
if (llama_eval(
ctx,
tokens_list.data(),
tokens_list.size(),
reset ? 0 : llama_get_kv_cache_token_count(ctx),
/* n_threads = */ 1)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
return true;
}
owned<llama_model> model_;
owned<llama_context> ctx_;
};
struct BackendInitializer {
BackendInitializer() {
llama_backend_init(false);
}
~BackendInitializer() {
llama_backend_free();
}
};
} // namespace
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
static BackendInitializer initializer;
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_gpu_layers = 4;
llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
if (!model) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return nullptr;
}
llama_context* ctx = llama_new_context_with_model(model, ctx_params);
return std::make_shared<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free)
);
}
} // namespace tabby

View File

@ -0,0 +1,73 @@
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use derive_builder::Builder;
use ffi::create_engine;
use stop_words::StopWords;
use tabby_inference::{TextGeneration, TextGenerationOptions};
use tokenizers::tokenizer::Tokenizer;
#[cxx::bridge(namespace = "llama")]
mod ffi {
unsafe extern "C++" {
include!("llama-cpp-bindings/include/engine.h");
type TextInferenceEngine;
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
fn start(&self, prompt: &str) -> u32;
fn step(&self, next_token_id: u32) -> u32;
}
}
unsafe impl Send for ffi::TextInferenceEngine {}
unsafe impl Sync for ffi::TextInferenceEngine {}
#[derive(Builder, Debug)]
pub struct LlamaEngineOptions {
model_path: String,
tokenizer_path: String,
}
pub struct LlamaEngine {
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
tokenizer: Arc<Tokenizer>,
stop_words: StopWords,
}
impl LlamaEngine {
pub fn create(options: LlamaEngineOptions) -> Self {
LlamaEngine {
engine: Mutex::new(create_engine(&options.model_path)),
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
stop_words: StopWords::default(),
}
}
}
#[async_trait]
impl TextGeneration for LlamaEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let engine = self.engine.lock().unwrap();
let mut next_token_id = engine.start(prompt);
let mut n_remains = options.max_decoding_length - 1;
let mut output_ids = vec![next_token_id];
let mut stop_condition = self
.stop_words
.create_condition(self.tokenizer.clone(), options.stop_words);
// FIXME(meng): supports cancellation.
while n_remains > 0 {
next_token_id = engine.step(next_token_id);
if stop_condition.next_token(next_token_id) {
break;
}
output_ids.push(next_token_id);
n_remains -= 1;
}
self.tokenizer.decode(&output_ids, true).unwrap()
}
}

View File

@ -0,0 +1,11 @@
[package]
name = "stop-words"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
dashmap = "5.5.3"
regex = "1.9.5"
tokenizers.workspace = true

View File

@ -0,0 +1,80 @@
use std::sync::Arc;
use dashmap::DashMap;
use regex::Regex;
use tokenizers::tokenizer::Tokenizer;
pub struct StopWords {
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
}
fn reverse(s: &&str) -> String {
s.chars().rev().collect()
}
impl Default for StopWords {
fn default() -> Self {
Self {
stop_regex_cache: DashMap::new(),
}
}
}
impl StopWords {
pub fn create_condition(
&self,
tokenizer: Arc<Tokenizer>,
stop_words: &'static Vec<&'static str>,
) -> StopWordsCondition {
let re = if stop_words.is_empty() {
None
} else {
let mut re = self.stop_regex_cache.get(stop_words);
if re.is_none() {
self.stop_regex_cache
.insert(stop_words, create_stop_regex(stop_words));
re = self.stop_regex_cache.get(stop_words);
}
re.map(|x| x.value().clone())
};
StopWordsCondition::new(tokenizer, re)
}
}
fn create_stop_regex(stop_words: &[&str]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(reverse).collect();
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
Regex::new(&regex_string).unwrap()
}
pub struct StopWordsCondition {
tokenizer: Arc<Tokenizer>,
stop_re: Option<Regex>,
reversed_output_text: String,
}
impl StopWordsCondition {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>) -> Self {
Self {
tokenizer,
stop_re,
reversed_output_text: String::new(),
}
}
pub fn next_token(&mut self, token_id: u32) -> bool {
if let Some(re) = &self.stop_re {
let token = self.tokenizer.decode(&[token_id], false).unwrap();
let mut new_token = reverse(&token.as_str());
new_token.push_str(&self.reversed_output_text);
self.reversed_output_text = new_token;
re.find(&self.reversed_output_text).is_some()
} else {
false
}
}
}

View File

@ -85,4 +85,8 @@ impl ModelDir {
pub fn ctranslate2_dir(&self) -> String { pub fn ctranslate2_dir(&self) -> String {
self.path_string("ctranslate2") self.path_string("ctranslate2")
} }
pub fn ggml_model_file(&self) -> String {
self.path_string("ggml/default.gguf")
}
} }

View File

@ -13,6 +13,6 @@ pub struct TextGenerationOptions {
} }
#[async_trait] #[async_trait]
pub trait TextGeneration { pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String; async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
} }

View File

@ -35,7 +35,7 @@ tantivy = { workspace = true }
anyhow = { workspace = true } anyhow = { workspace = true }
sysinfo = "0.29.8" sysinfo = "0.29.8"
nvml-wrapper = "0.9.0" nvml-wrapper = "0.9.0"
llama-cpp-bindings = { path = "../llama-cpp-bindings", optional = true }
[dependencies.uuid] [dependencies.uuid]
version = "1.3.3" version = "1.3.3"
@ -49,6 +49,7 @@ features = [
default = ["scheduler"] default = ["scheduler"]
link_shared = ["ctranslate2-bindings/link_shared"] link_shared = ["ctranslate2-bindings/link_shared"]
scheduler = ["tabby-scheduler"] scheduler = ["tabby-scheduler"]
metal = ["llama-cpp-bindings"]
[build-dependencies] [build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

View File

@ -121,7 +121,7 @@ pub async fn completion(
} }
pub struct CompletionState { pub struct CompletionState {
engine: CTranslate2Engine, engine: Box<dyn TextGeneration>,
prompt_builder: prompt::PromptBuilder, prompt_builder: prompt::PromptBuilder,
} }
@ -129,21 +129,8 @@ impl CompletionState {
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self { pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
let model_dir = get_model_dir(&args.model); let model_dir = get_model_dir(&args.model);
let metadata = read_metadata(&model_dir); let metadata = read_metadata(&model_dir);
let engine = create_engine(args, &model_dir, &metadata);
let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type);
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.auto_model)
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.compute_type(compute_type)
.stop_words_encoding_offset(metadata.stop_words_encoding_offset)
.build()
.unwrap();
let engine = CTranslate2Engine::create(options);
Self { Self {
engine, engine,
prompt_builder: prompt::PromptBuilder::new( prompt_builder: prompt::PromptBuilder::new(
@ -154,6 +141,59 @@ impl CompletionState {
} }
} }
#[cfg(not(feature = "metal"))]
fn create_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
create_ctranslate2_engine(args, model_dir, metadata)
}
#[cfg(feature = "metal")]
fn create_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
if args.device != super::Device::Metal {
create_ctranslate2_engine(args, model_dir, metadata)
} else {
create_llama_engine(model_dir)
}
}
fn create_ctranslate2_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
let device = format!("{}", args.device);
let compute_type = format!("{}", args.compute_type);
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.auto_model.clone())
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.compute_type(compute_type)
.build()
.unwrap();
Box::new(CTranslate2Engine::create(options))
}
#[cfg(feature = "metal")]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_model_file())
.tokenizer_path(model_dir.tokenizer_file())
.build()
.unwrap();
Box::new(llama_cpp_bindings::LlamaEngine::create(options))
}
fn get_model_dir(model: &str) -> ModelDir { fn get_model_dir(model: &str) -> ModelDir {
if Path::new(model).exists() { if Path::new(model).exists() {
ModelDir::from(model) ModelDir::from(model)
@ -166,7 +206,6 @@ fn get_model_dir(model: &str) -> ModelDir {
struct Metadata { struct Metadata {
auto_model: String, auto_model: String,
prompt_template: Option<String>, prompt_template: Option<String>,
stop_words_encoding_offset: Option<usize>,
} }
fn read_metadata(model_dir: &ModelDir) -> Metadata { fn read_metadata(model_dir: &ModelDir) -> Metadata {

View File

@ -54,6 +54,10 @@ pub enum Device {
#[strum(serialize = "cuda")] #[strum(serialize = "cuda")]
Cuda, Cuda,
#[cfg(feature = "metal")]
#[strum(serialize = "metal")]
Metal,
} }
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]