From 3573d4378e081c5392edd58c20580be105424351 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 3 Sep 2023 09:59:07 +0800 Subject: [PATCH] feat: llama.cpp for metal support [TAB-146] (#391) * feat: init commit adding llama-cpp-bindings * add llama.cpp submodule * add LlamaEngine to hold llama context / llama model * add cxxbridge * add basic greedy sampling * move files * make compile success * connect TextGeneration with LlamaEngine * experimental support llama.cpp * add metal device * add Accelerate * fix namespace for llama-cpp-bindings * fix lint * move stepping logic to rust * add stop words package * use stop-words in ctranslate2-bindings * use raw string for regex * use Arc for sharing tokenizers * refactor: remove useless stop_words_encoding_offset * switch to tokenizers 0.13.4-rc.3 * fix lints in cpp * simplify implementation of greedy decoding * feat: split metal feature for llama backend * add ci * update ci * build tabby bin in ci build --- .github/workflows/ci.yml | 4 +- .gitmodules | 3 + Cargo.lock | 96 +++++++++++++----- Cargo.toml | 4 + crates/ctranslate2-bindings/Cargo.toml | 7 +- crates/ctranslate2-bindings/src/lib.rs | 100 ++++-------------- crates/llama-cpp-bindings/Cargo.toml | 17 ++++ crates/llama-cpp-bindings/build.rs | 23 +++++ crates/llama-cpp-bindings/include/engine.h | 17 ++++ crates/llama-cpp-bindings/llama.cpp | 1 + crates/llama-cpp-bindings/src/engine.cc | 112 +++++++++++++++++++++ crates/llama-cpp-bindings/src/lib.rs | 73 ++++++++++++++ crates/stop-words/Cargo.toml | 11 ++ crates/stop-words/src/lib.rs | 80 +++++++++++++++ crates/tabby-common/src/path.rs | 4 + crates/tabby-inference/src/lib.rs | 2 +- crates/tabby/Cargo.toml | 3 +- crates/tabby/src/serve/completions.rs | 71 ++++++++++--- crates/tabby/src/serve/mod.rs | 4 + 19 files changed, 500 insertions(+), 132 deletions(-) create mode 100644 crates/llama-cpp-bindings/Cargo.toml create mode 100644 crates/llama-cpp-bindings/build.rs create mode 100644 crates/llama-cpp-bindings/include/engine.h create mode 160000 crates/llama-cpp-bindings/llama.cpp create mode 100644 crates/llama-cpp-bindings/src/engine.cc create mode 100644 crates/llama-cpp-bindings/src/lib.rs create mode 100644 crates/stop-words/Cargo.toml create mode 100644 crates/stop-words/src/lib.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5fabaa9..ef92e5f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -43,8 +43,10 @@ jobs: include: - os: macos-11 target: aarch64-apple-darwin + flags: "--features metal" - os: ubuntu-latest target: x86_64-unknown-linux-gnu + flags: "" env: SCCACHE_GHA_ENABLED: true @@ -81,7 +83,7 @@ jobs: ~/.cargo/git - run: bash ./ci/prepare_build_environment.sh - name: Bulid release binary - run: cargo build --no-default-features --release --target ${{ matrix.target }} + run: cargo build --bin tabby --no-default-features ${{ matrix.flags }} --release --target ${{ matrix.target }} - name: Rename release binary run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }} diff --git a/.gitmodules b/.gitmodules index 17f7250..f71ffb6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "crates/ctranslate2-bindings/CTranslate2"] path = crates/ctranslate2-bindings/CTranslate2 url = https://github.com/OpenNMT/CTranslate2.git +[submodule "crates/llama-cpp-bindings/llama.cpp"] + path = crates/llama-cpp-bindings/llama.cpp + url = https://github.com/ggerganov/llama.cpp diff --git a/Cargo.lock b/Cargo.lock index 105aee6..c1e821c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -617,10 +617,9 @@ dependencies = [ "cmake", "cxx", "cxx-build", - "dashmap", "derive_builder", - "regex", "rust-cxx-cmake-bridge", + "stop-words", "tabby-inference", "tokenizers", "tokio", @@ -743,12 +742,12 @@ dependencies = [ [[package]] name = "dashmap" -version = "5.4.0" +version = "5.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown", + "hashbrown 0.14.0", "lock_api", "once_cell", "parking_lot_core", @@ -1181,6 +1180,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heck" version = "0.4.1" @@ -1352,7 +1357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", "serde", ] @@ -1547,11 +1552,26 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +[[package]] +name = "llama-cpp-bindings" +version = "0.1.0" +dependencies = [ + "async-trait", + "cmake", + "cxx", + "cxx-build", + "derive_builder", + "stop-words", + "tabby-inference", + "tokenizers", + "tokio", +] + [[package]] name = "lock_api" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" dependencies = [ "autocfg", "scopeguard", @@ -1586,7 +1606,7 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" dependencies = [ - "hashbrown", + "hashbrown 0.12.3", ] [[package]] @@ -1617,7 +1637,7 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" dependencies = [ - "regex-automata", + "regex-automata 0.1.10", ] [[package]] @@ -1626,7 +1646,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" dependencies = [ - "regex-automata", + "regex-automata 0.1.10", ] [[package]] @@ -1647,9 +1667,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "5486aed0026218e61b8a01d5fbd5a0a134649abb71a0e53b7bc088529dced86e" [[package]] name = "memmap2" @@ -1887,9 +1907,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "oneshot" @@ -2073,15 +2093,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.2.16", + "redox_syscall 0.3.5", "smallvec", - "windows-sys 0.45.0", + "windows-targets 0.48.0", ] [[package]] @@ -2388,13 +2408,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick 1.0.1", "memchr", - "regex-syntax 0.7.2", + "regex-automata 0.3.8", + "regex-syntax 0.7.5", ] [[package]] @@ -2406,6 +2427,17 @@ dependencies = [ "regex-syntax 0.6.29", ] +[[package]] +name = "regex-automata" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +dependencies = [ + "aho-corasick 1.0.1", + "memchr", + "regex-syntax 0.7.5", +] + [[package]] name = "regex-syntax" version = "0.6.29" @@ -2414,9 +2446,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "reqwest" @@ -2807,6 +2839,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stop-words" +version = "0.1.0" +dependencies = [ + "dashmap", + "regex", + "tokenizers", +] + [[package]] name = "strfmt" version = "0.2.4" @@ -2907,6 +2948,7 @@ dependencies = [ "ctranslate2-bindings", "hyper", "lazy_static", + "llama-cpp-bindings", "mime_guess", "nvml-wrapper", "opentelemetry", @@ -3217,9 +3259,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf49017523bf0bc01c9966f172c5f120bbb7b96cccd1708772dd42e767fb9f5" +checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f" dependencies = [ "aho-corasick 0.7.20", "cached-path", @@ -3240,7 +3282,7 @@ dependencies = [ "rayon", "rayon-cond", "regex", - "regex-syntax 0.6.29", + "regex-syntax 0.7.5", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index f9d681a..be7706a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ members = [ "crates/tabby-inference", "crates/ctranslate2-bindings", "crates/rust-cxx-cmake-bridge", + "crates/llama-cpp-bindings", + "crates/stop-words", ] [workspace.package] @@ -28,3 +30,5 @@ serde-jsonlines = "0.4.0" tantivy = "0.19.2" async-trait = "0.1.72" reqwest = { version = "0.11.18" } +derive_builder = "0.12.0" +tokenizers = "0.13.4-rc3" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index d6484a5..6753902 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -5,14 +5,13 @@ edition = "2021" [dependencies] cxx = "1.0" -dashmap = "5.4.0" -derive_builder = "0.12.0" -regex = "1.8.4" -tokenizers = "0.13.3" +derive_builder = { workspace = true } +tokenizers = { workspace = true } tokio = { workspace = true, features = ["rt"] } tokio-util = { workspace = true } tabby-inference = { path = "../tabby-inference" } async-trait = { workspace = true } +stop-words = { path = "../stop-words" } [build-dependencies] cxx-build = "1.0" diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 87be81b..ab49996 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -1,7 +1,8 @@ +use std::sync::Arc; + use async_trait::async_trait; -use dashmap::DashMap; use derive_builder::Builder; -use regex::Regex; +use stop_words::{StopWords, StopWordsCondition}; use tabby_inference::{TextGeneration, TextGenerationOptions}; use tokenizers::tokenizer::Tokenizer; use tokio_util::sync::CancellationToken; @@ -63,31 +64,26 @@ pub struct CTranslate2EngineOptions { num_replicas_per_device: usize, compute_type: String, - - stop_words_encoding_offset: Option, } pub struct InferenceContext { - stop_re: Option, + stop_condition: StopWordsCondition, cancel: CancellationToken, - reversed_output_text: String, } impl InferenceContext { - fn new(stop_re: Option, cancel: CancellationToken) -> Self { + fn new(stop_condition: StopWordsCondition, cancel: CancellationToken) -> Self { InferenceContext { - stop_re, + stop_condition, cancel, - reversed_output_text: "".to_owned(), } } } pub struct CTranslate2Engine { engine: cxx::SharedPtr, - tokenizer: Tokenizer, - stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, - stop_words_encoding_offset: Option, + stop_words: StopWords, + tokenizer: Arc, } impl CTranslate2Engine { @@ -103,9 +99,8 @@ impl CTranslate2Engine { return Self { engine, - stop_regex_cache: DashMap::new(), - tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(), - stop_words_encoding_offset: options.stop_words_encoding_offset, + stop_words: StopWords::default(), + tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), }; } } @@ -120,25 +115,10 @@ impl TextGeneration for CTranslate2Engine { let cancel_for_inference = cancel.clone(); let _guard = cancel.drop_guard(); - let stop_re: Option = if options.stop_words.is_empty() { - None - } else { - let mut re = self.stop_regex_cache.get(options.stop_words); - if re.is_none() { - self.stop_regex_cache.insert( - options.stop_words, - create_stop_regex( - &self.tokenizer, - options.stop_words, - self.stop_words_encoding_offset, - ), - ); - re = self.stop_regex_cache.get(options.stop_words); - } - re.map(|x| x.value().clone()) - }; - - let context = InferenceContext::new(stop_re, cancel_for_inference); + let stop_condition = self + .stop_words + .create_condition(self.tokenizer.clone(), options.stop_words); + let context = InferenceContext::new(stop_condition, cancel_for_inference); let output_ids = tokio::task::spawn_blocking(move || { let context = Box::new(context); engine.inference( @@ -151,63 +131,19 @@ impl TextGeneration for CTranslate2Engine { }) .await .expect("Inference failed"); - self.tokenizer.decode(output_ids, true).unwrap() + self.tokenizer.decode(&output_ids, true).unwrap() } } fn inference_callback( context: &mut InferenceContext, _step: usize, - _token_id: u32, - token: String, + token_id: u32, + _token: String, ) -> bool { if context.cancel.is_cancelled() { true - } else if let Some(re) = &context.stop_re { - let mut new_token = reverse(&token); - new_token.push_str(&context.reversed_output_text); - context.reversed_output_text = new_token; - re.find(&context.reversed_output_text).is_some() } else { - false + context.stop_condition.next_token(token_id) } } - -fn reverse(s: &String) -> String { - // Special treatment for byte fallback token. - // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs - if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') { - // Keep byte fallback tokens like <0x0A> as is, do not reverse it. - // This won't really affect stop words regex logic, but brings more readability when - // debugging decoding steps. - s.to_owned() - } else { - s.chars().rev().collect() - } -} - -fn create_stop_regex( - tokenizer: &Tokenizer, - stop_words: &[&str], - stop_words_encoding_offset: Option, -) -> Regex { - let encodings = tokenizer - .encode_batch(stop_words.to_owned(), false) - .unwrap(); - let stop_tokens: Vec = encodings - .iter() - .map(|x| { - x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..] - .iter() - .rev() - .map(reverse) - .collect::>() - .join("") - }) - .collect(); - - // (?m) enables multi-line matching mode. - // \A means absolute begins of string. - let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|"); - Regex::new(®ex_string).unwrap() -} diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml new file mode 100644 index 0000000..395829e --- /dev/null +++ b/crates/llama-cpp-bindings/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "llama-cpp-bindings" +version = "0.1.0" +edition = "2021" + +[build-dependencies] +cxx-build = "1.0" +cmake = "0.1" + +[dependencies] +cxx = "1.0" +async-trait = { workspace = true } +tokio = { workspace = true, features = ["rt"] } +tabby-inference = { path = "../tabby-inference" } +derive_builder = { workspace = true } +tokenizers = { workspace = true } +stop-words = { version = "0.1.0", path = "../stop-words" } diff --git a/crates/llama-cpp-bindings/build.rs b/crates/llama-cpp-bindings/build.rs new file mode 100644 index 0000000..c4ac850 --- /dev/null +++ b/crates/llama-cpp-bindings/build.rs @@ -0,0 +1,23 @@ +use cmake::Config; + +fn main() { + let dst = Config::new("llama.cpp").define("LLAMA_METAL", "ON").build(); + + println!("cargo:rerun-if-changed=cc/*.h"); + println!("cargo:rerun-if-changed=cc/*.cc"); + + println!("cargo:rustc-link-search=native={}/build", dst.display()); + println!("cargo:rustc-link-lib=llama"); + println!("cargo:rustc-link-lib=ggml_static"); + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=Accelerate"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + + cxx_build::bridge("src/lib.rs") + .file("src/engine.cc") + .flag_if_supported("-Iinclude") + .flag_if_supported("-Illama.cpp") + .flag_if_supported("-std=c++14") + .compile("cxxbridge"); +} diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h new file mode 100644 index 0000000..45a4f9f --- /dev/null +++ b/crates/llama-cpp-bindings/include/engine.h @@ -0,0 +1,17 @@ +#pragma once + +#include "rust/cxx.h" +#include + +namespace llama { + +class TextInferenceEngine { + public: + virtual ~TextInferenceEngine(); + + virtual uint32_t start(const rust::Str prompt) const = 0; + virtual uint32_t step(uint32_t next_token_id) const = 0; +}; + +std::shared_ptr create_engine(rust::Str model_path); +} // namespace diff --git a/crates/llama-cpp-bindings/llama.cpp b/crates/llama-cpp-bindings/llama.cpp new file mode 160000 index 0000000..bce1fef --- /dev/null +++ b/crates/llama-cpp-bindings/llama.cpp @@ -0,0 +1 @@ +Subproject commit bce1fef328941499dc0acb76cc7fd7ac90449c2f diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc new file mode 100644 index 0000000..b1b0dfd --- /dev/null +++ b/crates/llama-cpp-bindings/src/engine.cc @@ -0,0 +1,112 @@ +#include "engine.h" + +#include +#include + +#include +#include + +namespace llama { +TextInferenceEngine::~TextInferenceEngine() {} + +namespace { +template +using owned = std::unique_ptr>; + +std::vector tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { + // upper limit for the number of tokens + int n_tokens = text.length() + add_bos; + std::vector result(n_tokens); + n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +class TextInferenceEngineImpl : public TextInferenceEngine { + public: + TextInferenceEngineImpl(owned model, owned ctx) : + model_(std::move(model)), + ctx_(std::move(ctx)) { + } + + uint32_t start(const rust::Str prompt) const override { + auto* ctx = ctx_.get(); + std::vector tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true); + eval(tokens_list, /* reset = */ true); + return sample(); + } + + uint32_t step(uint32_t next_token_id) const override { + eval({ static_cast(next_token_id) }, /* reset = */ false); + return sample(); + } + + private: + uint32_t sample() const { + auto* ctx = ctx_.get(); + + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); + + // Greedy sampling (always select the highest logit). + return std::distance(logits, std::max_element(logits, logits + n_vocab)); + } + + bool eval(const std::vector& tokens_list, bool reset) const { + auto* ctx = ctx_.get(); + if (llama_eval( + ctx, + tokens_list.data(), + tokens_list.size(), + reset ? 0 : llama_get_kv_cache_token_count(ctx), + /* n_threads = */ 1)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + + return true; + } + + owned model_; + owned ctx_; +}; + +struct BackendInitializer { + BackendInitializer() { + llama_backend_init(false); + } + + ~BackendInitializer() { + llama_backend_free(); + } +}; +} // namespace + +std::shared_ptr create_engine(rust::Str model_path) { + static BackendInitializer initializer; + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_gpu_layers = 4; + + llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params); + + if (!model) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return nullptr; + } + + llama_context* ctx = llama_new_context_with_model(model, ctx_params); + + return std::make_shared( + owned(model, llama_free_model), + owned(ctx, llama_free) + ); +} + +} // namespace tabby diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs new file mode 100644 index 0000000..d953f2e --- /dev/null +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -0,0 +1,73 @@ +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use derive_builder::Builder; +use ffi::create_engine; +use stop_words::StopWords; +use tabby_inference::{TextGeneration, TextGenerationOptions}; +use tokenizers::tokenizer::Tokenizer; + +#[cxx::bridge(namespace = "llama")] +mod ffi { + unsafe extern "C++" { + include!("llama-cpp-bindings/include/engine.h"); + + type TextInferenceEngine; + + fn create_engine(model_path: &str) -> SharedPtr; + + fn start(&self, prompt: &str) -> u32; + fn step(&self, next_token_id: u32) -> u32; + } +} + +unsafe impl Send for ffi::TextInferenceEngine {} +unsafe impl Sync for ffi::TextInferenceEngine {} + +#[derive(Builder, Debug)] +pub struct LlamaEngineOptions { + model_path: String, + tokenizer_path: String, +} + +pub struct LlamaEngine { + engine: Mutex>, + tokenizer: Arc, + stop_words: StopWords, +} + +impl LlamaEngine { + pub fn create(options: LlamaEngineOptions) -> Self { + LlamaEngine { + engine: Mutex::new(create_engine(&options.model_path)), + tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()), + stop_words: StopWords::default(), + } + } +} + +#[async_trait] +impl TextGeneration for LlamaEngine { + async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { + let engine = self.engine.lock().unwrap(); + let mut next_token_id = engine.start(prompt); + let mut n_remains = options.max_decoding_length - 1; + let mut output_ids = vec![next_token_id]; + + let mut stop_condition = self + .stop_words + .create_condition(self.tokenizer.clone(), options.stop_words); + + // FIXME(meng): supports cancellation. + while n_remains > 0 { + next_token_id = engine.step(next_token_id); + if stop_condition.next_token(next_token_id) { + break; + } + output_ids.push(next_token_id); + n_remains -= 1; + } + + self.tokenizer.decode(&output_ids, true).unwrap() + } +} diff --git a/crates/stop-words/Cargo.toml b/crates/stop-words/Cargo.toml new file mode 100644 index 0000000..fea97ae --- /dev/null +++ b/crates/stop-words/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "stop-words" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +dashmap = "5.5.3" +regex = "1.9.5" +tokenizers.workspace = true diff --git a/crates/stop-words/src/lib.rs b/crates/stop-words/src/lib.rs new file mode 100644 index 0000000..b18a362 --- /dev/null +++ b/crates/stop-words/src/lib.rs @@ -0,0 +1,80 @@ +use std::sync::Arc; + +use dashmap::DashMap; +use regex::Regex; +use tokenizers::tokenizer::Tokenizer; + +pub struct StopWords { + stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>, +} + +fn reverse(s: &&str) -> String { + s.chars().rev().collect() +} + +impl Default for StopWords { + fn default() -> Self { + Self { + stop_regex_cache: DashMap::new(), + } + } +} + +impl StopWords { + pub fn create_condition( + &self, + tokenizer: Arc, + stop_words: &'static Vec<&'static str>, + ) -> StopWordsCondition { + let re = if stop_words.is_empty() { + None + } else { + let mut re = self.stop_regex_cache.get(stop_words); + if re.is_none() { + self.stop_regex_cache + .insert(stop_words, create_stop_regex(stop_words)); + re = self.stop_regex_cache.get(stop_words); + } + re.map(|x| x.value().clone()) + }; + + StopWordsCondition::new(tokenizer, re) + } +} + +fn create_stop_regex(stop_words: &[&str]) -> Regex { + let tokens: Vec = stop_words.iter().map(reverse).collect(); + + // (?m) enables multi-line matching mode. + // \A means absolute begins of string. + let regex_string = r"(?m)\A".to_owned() + &tokens.join("|"); + Regex::new(®ex_string).unwrap() +} + +pub struct StopWordsCondition { + tokenizer: Arc, + stop_re: Option, + reversed_output_text: String, +} + +impl StopWordsCondition { + pub fn new(tokenizer: Arc, stop_re: Option) -> Self { + Self { + tokenizer, + stop_re, + reversed_output_text: String::new(), + } + } + + pub fn next_token(&mut self, token_id: u32) -> bool { + if let Some(re) = &self.stop_re { + let token = self.tokenizer.decode(&[token_id], false).unwrap(); + let mut new_token = reverse(&token.as_str()); + new_token.push_str(&self.reversed_output_text); + self.reversed_output_text = new_token; + re.find(&self.reversed_output_text).is_some() + } else { + false + } + } +} diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs index 0a5d094..0c06328 100644 --- a/crates/tabby-common/src/path.rs +++ b/crates/tabby-common/src/path.rs @@ -85,4 +85,8 @@ impl ModelDir { pub fn ctranslate2_dir(&self) -> String { self.path_string("ctranslate2") } + + pub fn ggml_model_file(&self) -> String { + self.path_string("ggml/default.gguf") + } } diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 8804a58..4d1befa 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -13,6 +13,6 @@ pub struct TextGenerationOptions { } #[async_trait] -pub trait TextGeneration { +pub trait TextGeneration: Sync + Send { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String; } diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 290c304..802f252 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -35,7 +35,7 @@ tantivy = { workspace = true } anyhow = { workspace = true } sysinfo = "0.29.8" nvml-wrapper = "0.9.0" - +llama-cpp-bindings = { path = "../llama-cpp-bindings", optional = true } [dependencies.uuid] version = "1.3.3" @@ -49,6 +49,7 @@ features = [ default = ["scheduler"] link_shared = ["ctranslate2-bindings/link_shared"] scheduler = ["tabby-scheduler"] +metal = ["llama-cpp-bindings"] [build-dependencies] vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 62014d0..8307cab 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -121,7 +121,7 @@ pub async fn completion( } pub struct CompletionState { - engine: CTranslate2Engine, + engine: Box, prompt_builder: prompt::PromptBuilder, } @@ -129,21 +129,8 @@ impl CompletionState { pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self { let model_dir = get_model_dir(&args.model); let metadata = read_metadata(&model_dir); + let engine = create_engine(args, &model_dir, &metadata); - let device = format!("{}", args.device); - let compute_type = format!("{}", args.compute_type); - let options = CTranslate2EngineOptionsBuilder::default() - .model_path(model_dir.ctranslate2_dir()) - .tokenizer_path(model_dir.tokenizer_file()) - .device(device) - .model_type(metadata.auto_model) - .device_indices(args.device_indices.clone()) - .num_replicas_per_device(args.num_replicas_per_device) - .compute_type(compute_type) - .stop_words_encoding_offset(metadata.stop_words_encoding_offset) - .build() - .unwrap(); - let engine = CTranslate2Engine::create(options); Self { engine, prompt_builder: prompt::PromptBuilder::new( @@ -154,6 +141,59 @@ impl CompletionState { } } +#[cfg(not(feature = "metal"))] +fn create_engine( + args: &crate::serve::ServeArgs, + model_dir: &ModelDir, + metadata: &Metadata, +) -> Box { + create_ctranslate2_engine(args, model_dir, metadata) +} + +#[cfg(feature = "metal")] +fn create_engine( + args: &crate::serve::ServeArgs, + model_dir: &ModelDir, + metadata: &Metadata, +) -> Box { + if args.device != super::Device::Metal { + create_ctranslate2_engine(args, model_dir, metadata) + } else { + create_llama_engine(model_dir) + } +} + +fn create_ctranslate2_engine( + args: &crate::serve::ServeArgs, + model_dir: &ModelDir, + metadata: &Metadata, +) -> Box { + let device = format!("{}", args.device); + let compute_type = format!("{}", args.compute_type); + let options = CTranslate2EngineOptionsBuilder::default() + .model_path(model_dir.ctranslate2_dir()) + .tokenizer_path(model_dir.tokenizer_file()) + .device(device) + .model_type(metadata.auto_model.clone()) + .device_indices(args.device_indices.clone()) + .num_replicas_per_device(args.num_replicas_per_device) + .compute_type(compute_type) + .build() + .unwrap(); + Box::new(CTranslate2Engine::create(options)) +} + +#[cfg(feature = "metal")] +fn create_llama_engine(model_dir: &ModelDir) -> Box { + let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default() + .model_path(model_dir.ggml_model_file()) + .tokenizer_path(model_dir.tokenizer_file()) + .build() + .unwrap(); + + Box::new(llama_cpp_bindings::LlamaEngine::create(options)) +} + fn get_model_dir(model: &str) -> ModelDir { if Path::new(model).exists() { ModelDir::from(model) @@ -166,7 +206,6 @@ fn get_model_dir(model: &str) -> ModelDir { struct Metadata { auto_model: String, prompt_template: Option, - stop_words_encoding_offset: Option, } fn read_metadata(model_dir: &ModelDir) -> Metadata { diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index a7337b2..a8b9638 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -54,6 +54,10 @@ pub enum Device { #[strum(serialize = "cuda")] Cuda, + + #[cfg(feature = "metal")] + #[strum(serialize = "metal")] + Metal, } #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]