feat: llama.cpp for metal support [TAB-146] (#391)
* feat: init commit adding llama-cpp-bindings * add llama.cpp submodule * add LlamaEngine to hold llama context / llama model * add cxxbridge * add basic greedy sampling * move files * make compile success * connect TextGeneration with LlamaEngine * experimental support llama.cpp * add metal device * add Accelerate * fix namespace for llama-cpp-bindings * fix lint * move stepping logic to rust * add stop words package * use stop-words in ctranslate2-bindings * use raw string for regex * use Arc<Tokenizer> for sharing tokenizers * refactor: remove useless stop_words_encoding_offset * switch to tokenizers 0.13.4-rc.3 * fix lints in cpp * simplify implementation of greedy decoding * feat: split metal feature for llama backend * add ci * update ci * build tabby bin in ci buildrelease-v0.1
parent
1c1cf44639
commit
3573d4378e
|
|
@ -43,8 +43,10 @@ jobs:
|
|||
include:
|
||||
- os: macos-11
|
||||
target: aarch64-apple-darwin
|
||||
flags: "--features metal"
|
||||
- os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
flags: ""
|
||||
|
||||
env:
|
||||
SCCACHE_GHA_ENABLED: true
|
||||
|
|
@ -81,7 +83,7 @@ jobs:
|
|||
~/.cargo/git
|
||||
- run: bash ./ci/prepare_build_environment.sh
|
||||
- name: Build release binary
|
||||
run: cargo build --no-default-features --release --target ${{ matrix.target }}
|
||||
run: cargo build --bin tabby --no-default-features ${{ matrix.flags }} --release --target ${{ matrix.target }}
|
||||
|
||||
- name: Rename release binary
|
||||
run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
[submodule "crates/ctranslate2-bindings/CTranslate2"]
|
||||
path = crates/ctranslate2-bindings/CTranslate2
|
||||
url = https://github.com/OpenNMT/CTranslate2.git
|
||||
[submodule "crates/llama-cpp-bindings/llama.cpp"]
|
||||
path = crates/llama-cpp-bindings/llama.cpp
|
||||
url = https://github.com/ggerganov/llama.cpp
|
||||
|
|
|
|||
|
|
@ -617,10 +617,9 @@ dependencies = [
|
|||
"cmake",
|
||||
"cxx",
|
||||
"cxx-build",
|
||||
"dashmap",
|
||||
"derive_builder",
|
||||
"regex",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"stop-words",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
|
@ -743,12 +742,12 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.4.0"
|
||||
version = "5.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
|
||||
checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"hashbrown",
|
||||
"hashbrown 0.14.0",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
|
|
@ -1181,6 +1180,12 @@ dependencies = [
|
|||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
|
|
@ -1352,7 +1357,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"hashbrown",
|
||||
"hashbrown 0.12.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
|
|
@ -1547,11 +1552,26 @@ version = "0.3.8"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
|
||||
|
||||
[[package]]
|
||||
name = "llama-cpp-bindings"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"cmake",
|
||||
"cxx",
|
||||
"cxx-build",
|
||||
"derive_builder",
|
||||
"stop-words",
|
||||
"tabby-inference",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.9"
|
||||
version = "0.4.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df"
|
||||
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
|
|
@ -1586,7 +1606,7 @@ version = "0.7.8"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a"
|
||||
dependencies = [
|
||||
"hashbrown",
|
||||
"hashbrown 0.12.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1617,7 +1637,7 @@ version = "0.0.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1"
|
||||
dependencies = [
|
||||
"regex-automata",
|
||||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1626,7 +1646,7 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
|
||||
dependencies = [
|
||||
"regex-automata",
|
||||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1647,9 +1667,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
version = "2.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||
checksum = "5486aed0026218e61b8a01d5fbd5a0a134649abb71a0e53b7bc088529dced86e"
|
||||
|
||||
[[package]]
|
||||
name = "memmap2"
|
||||
|
|
@ -1887,9 +1907,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.17.1"
|
||||
version = "1.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
|
||||
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
|
||||
|
||||
[[package]]
|
||||
name = "oneshot"
|
||||
|
|
@ -2073,15 +2093,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.7"
|
||||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521"
|
||||
checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall 0.2.16",
|
||||
"redox_syscall 0.3.5",
|
||||
"smallvec",
|
||||
"windows-sys 0.45.0",
|
||||
"windows-targets 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2388,13 +2408,14 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.8.4"
|
||||
version = "1.9.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f"
|
||||
checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47"
|
||||
dependencies = [
|
||||
"aho-corasick 1.0.1",
|
||||
"memchr",
|
||||
"regex-syntax 0.7.2",
|
||||
"regex-automata 0.3.8",
|
||||
"regex-syntax 0.7.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2406,6 +2427,17 @@ dependencies = [
|
|||
"regex-syntax 0.6.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795"
|
||||
dependencies = [
|
||||
"aho-corasick 1.0.1",
|
||||
"memchr",
|
||||
"regex-syntax 0.7.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.29"
|
||||
|
|
@ -2414,9 +2446,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
|||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.2"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78"
|
||||
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
|
|
@ -2807,6 +2839,15 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "stop-words"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"dashmap",
|
||||
"regex",
|
||||
"tokenizers",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strfmt"
|
||||
version = "0.2.4"
|
||||
|
|
@ -2907,6 +2948,7 @@ dependencies = [
|
|||
"ctranslate2-bindings",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"llama-cpp-bindings",
|
||||
"mime_guess",
|
||||
"nvml-wrapper",
|
||||
"opentelemetry",
|
||||
|
|
@ -3217,9 +3259,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
|||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.3"
|
||||
version = "0.13.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cf49017523bf0bc01c9966f172c5f120bbb7b96cccd1708772dd42e767fb9f5"
|
||||
checksum = "aea68938177975ab09da68552b720eac941779ff386baceaf77e0f5f9cea645f"
|
||||
dependencies = [
|
||||
"aho-corasick 0.7.20",
|
||||
"cached-path",
|
||||
|
|
@ -3240,7 +3282,7 @@ dependencies = [
|
|||
"rayon",
|
||||
"rayon-cond",
|
||||
"regex",
|
||||
"regex-syntax 0.6.29",
|
||||
"regex-syntax 0.7.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ members = [
|
|||
"crates/tabby-inference",
|
||||
"crates/ctranslate2-bindings",
|
||||
"crates/rust-cxx-cmake-bridge",
|
||||
"crates/llama-cpp-bindings",
|
||||
"crates/stop-words",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
|
@ -28,3 +30,5 @@ serde-jsonlines = "0.4.0"
|
|||
tantivy = "0.19.2"
|
||||
async-trait = "0.1.72"
|
||||
reqwest = { version = "0.11.18" }
|
||||
derive_builder = "0.12.0"
|
||||
tokenizers = "0.13.4-rc3"
|
||||
|
|
|
|||
|
|
@ -5,14 +5,13 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
cxx = "1.0"
|
||||
dashmap = "5.4.0"
|
||||
derive_builder = "0.12.0"
|
||||
regex = "1.8.4"
|
||||
tokenizers = "0.13.3"
|
||||
derive_builder = { workspace = true }
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tokio-util = { workspace = true }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
async-trait = { workspace = true }
|
||||
stop-words = { path = "../stop-words" }
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use derive_builder::Builder;
|
||||
use regex::Regex;
|
||||
use stop_words::{StopWords, StopWordsCondition};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
|
@ -63,31 +64,26 @@ pub struct CTranslate2EngineOptions {
|
|||
num_replicas_per_device: usize,
|
||||
|
||||
compute_type: String,
|
||||
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
}
|
||||
|
||||
pub struct InferenceContext {
|
||||
stop_re: Option<Regex>,
|
||||
stop_condition: StopWordsCondition,
|
||||
cancel: CancellationToken,
|
||||
reversed_output_text: String,
|
||||
}
|
||||
|
||||
impl InferenceContext {
|
||||
fn new(stop_re: Option<Regex>, cancel: CancellationToken) -> Self {
|
||||
fn new(stop_condition: StopWordsCondition, cancel: CancellationToken) -> Self {
|
||||
InferenceContext {
|
||||
stop_re,
|
||||
stop_condition,
|
||||
cancel,
|
||||
reversed_output_text: "".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CTranslate2Engine {
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
tokenizer: Tokenizer,
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
stop_words: StopWords,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
}
|
||||
|
||||
impl CTranslate2Engine {
|
||||
|
|
@ -103,9 +99,8 @@ impl CTranslate2Engine {
|
|||
|
||||
return Self {
|
||||
engine,
|
||||
stop_regex_cache: DashMap::new(),
|
||||
tokenizer: Tokenizer::from_file(&options.tokenizer_path).unwrap(),
|
||||
stop_words_encoding_offset: options.stop_words_encoding_offset,
|
||||
stop_words: StopWords::default(),
|
||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
@ -120,25 +115,10 @@ impl TextGeneration for CTranslate2Engine {
|
|||
let cancel_for_inference = cancel.clone();
|
||||
let _guard = cancel.drop_guard();
|
||||
|
||||
let stop_re: Option<Regex> = if options.stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut re = self.stop_regex_cache.get(options.stop_words);
|
||||
if re.is_none() {
|
||||
self.stop_regex_cache.insert(
|
||||
options.stop_words,
|
||||
create_stop_regex(
|
||||
&self.tokenizer,
|
||||
options.stop_words,
|
||||
self.stop_words_encoding_offset,
|
||||
),
|
||||
);
|
||||
re = self.stop_regex_cache.get(options.stop_words);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
};
|
||||
|
||||
let context = InferenceContext::new(stop_re, cancel_for_inference);
|
||||
let stop_condition = self
|
||||
.stop_words
|
||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
||||
let context = InferenceContext::new(stop_condition, cancel_for_inference);
|
||||
let output_ids = tokio::task::spawn_blocking(move || {
|
||||
let context = Box::new(context);
|
||||
engine.inference(
|
||||
|
|
@ -151,63 +131,19 @@ impl TextGeneration for CTranslate2Engine {
|
|||
})
|
||||
.await
|
||||
.expect("Inference failed");
|
||||
self.tokenizer.decode(output_ids, true).unwrap()
|
||||
self.tokenizer.decode(&output_ids, true).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn inference_callback(
|
||||
context: &mut InferenceContext,
|
||||
_step: usize,
|
||||
_token_id: u32,
|
||||
token: String,
|
||||
token_id: u32,
|
||||
_token: String,
|
||||
) -> bool {
|
||||
if context.cancel.is_cancelled() {
|
||||
true
|
||||
} else if let Some(re) = &context.stop_re {
|
||||
let mut new_token = reverse(&token);
|
||||
new_token.push_str(&context.reversed_output_text);
|
||||
context.reversed_output_text = new_token;
|
||||
re.find(&context.reversed_output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
context.stop_condition.next_token(token_id)
|
||||
}
|
||||
}
|
||||
|
||||
fn reverse(s: &String) -> String {
|
||||
// Special treatment for byte fallback token.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
|
||||
if s.len() == 6 && s.starts_with("<0x") && s.ends_with('>') {
|
||||
// Keep byte fallback tokens like <0x0A> as is, do not reverse it.
|
||||
// This won't really affect stop words regex logic, but brings more readability when
|
||||
// debugging decoding steps.
|
||||
s.to_owned()
|
||||
} else {
|
||||
s.chars().rev().collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex(
|
||||
tokenizer: &Tokenizer,
|
||||
stop_words: &[&str],
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
) -> Regex {
|
||||
let encodings = tokenizer
|
||||
.encode_batch(stop_words.to_owned(), false)
|
||||
.unwrap();
|
||||
let stop_tokens: Vec<String> = encodings
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.get_tokens()[stop_words_encoding_offset.unwrap_or(0)..]
|
||||
.iter()
|
||||
.rev()
|
||||
.map(reverse)
|
||||
.collect::<Vec<String>>()
|
||||
.join("")
|
||||
})
|
||||
.collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &stop_tokens.join("|");
|
||||
Regex::new(&regex_string).unwrap()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "llama-cpp-bindings"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
cmake = "0.1"
|
||||
|
||||
[dependencies]
|
||||
cxx = "1.0"
|
||||
async-trait = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
derive_builder = { workspace = true }
|
||||
tokenizers = { workspace = true }
|
||||
stop-words = { version = "0.1.0", path = "../stop-words" }
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
use cmake::Config;
|
||||
|
||||
fn main() {
|
||||
let dst = Config::new("llama.cpp").define("LLAMA_METAL", "ON").build();
|
||||
|
||||
println!("cargo:rerun-if-changed=cc/*.h");
|
||||
println!("cargo:rerun-if-changed=cc/*.cc");
|
||||
|
||||
println!("cargo:rustc-link-search=native={}/build", dst.display());
|
||||
println!("cargo:rustc-link-lib=llama");
|
||||
println!("cargo:rustc-link-lib=ggml_static");
|
||||
println!("cargo:rustc-link-lib=framework=Foundation");
|
||||
println!("cargo:rustc-link-lib=framework=Accelerate");
|
||||
println!("cargo:rustc-link-lib=framework=Metal");
|
||||
println!("cargo:rustc-link-lib=framework=MetalKit");
|
||||
|
||||
cxx_build::bridge("src/lib.rs")
|
||||
.file("src/engine.cc")
|
||||
.flag_if_supported("-Iinclude")
|
||||
.flag_if_supported("-Illama.cpp")
|
||||
.flag_if_supported("-std=c++14")
|
||||
.compile("cxxbridge");
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
|
||||
#include "rust/cxx.h"
|
||||
#include <memory>
|
||||
|
||||
namespace llama {
|
||||
|
||||
class TextInferenceEngine {
|
||||
public:
|
||||
virtual ~TextInferenceEngine();
|
||||
|
||||
virtual uint32_t start(const rust::Str prompt) const = 0;
|
||||
virtual uint32_t step(uint32_t next_token_id) const = 0;
|
||||
};
|
||||
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
|
||||
} // namespace
|
||||
|
|
@ -0,0 +1 @@
|
|||
Subproject commit bce1fef328941499dc0acb76cc7fd7ac90449c2f
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
#include "engine.h"
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include <ggml.h>
|
||||
#include <llama.h>
|
||||
|
||||
namespace llama {
|
||||
TextInferenceEngine::~TextInferenceEngine() {}
|
||||
|
||||
namespace {
|
||||
template<class T>
|
||||
using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||
|
||||
std::vector<llama_token> tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
|
||||
// upper limit for the number of tokens
|
||||
int n_tokens = text.length() + add_bos;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
public:
|
||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
|
||||
model_(std::move(model)),
|
||||
ctx_(std::move(ctx)) {
|
||||
}
|
||||
|
||||
uint32_t start(const rust::Str prompt) const override {
|
||||
auto* ctx = ctx_.get();
|
||||
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
|
||||
eval(tokens_list, /* reset = */ true);
|
||||
return sample();
|
||||
}
|
||||
|
||||
uint32_t step(uint32_t next_token_id) const override {
|
||||
eval({ static_cast<llama_token>(next_token_id) }, /* reset = */ false);
|
||||
return sample();
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t sample() const {
|
||||
auto* ctx = ctx_.get();
|
||||
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
// Greedy sampling (always select the highest logit).
|
||||
return std::distance(logits, std::max_element(logits, logits + n_vocab));
|
||||
}
|
||||
|
||||
bool eval(const std::vector<llama_token>& tokens_list, bool reset) const {
|
||||
auto* ctx = ctx_.get();
|
||||
if (llama_eval(
|
||||
ctx,
|
||||
tokens_list.data(),
|
||||
tokens_list.size(),
|
||||
reset ? 0 : llama_get_kv_cache_token_count(ctx),
|
||||
/* n_threads = */ 1)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
owned<llama_model> model_;
|
||||
owned<llama_context> ctx_;
|
||||
};
|
||||
|
||||
struct BackendInitializer {
|
||||
BackendInitializer() {
|
||||
llama_backend_init(false);
|
||||
}
|
||||
|
||||
~BackendInitializer() {
|
||||
llama_backend_free();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Loads the model stored at `model_path` and creates an inference context for
// it. Returns nullptr, after logging to stderr, when either the model cannot
// be loaded or the context cannot be created.
std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
  // Function-local static: constructed on the first call only, so the backend
  // is initialized once no matter how many engines are created.
  static BackendInitializer initializer;

  llama_context_params ctx_params = llama_context_default_params();
  ctx_params.n_gpu_layers = 4;

  // Take ownership immediately so every early return below releases whatever
  // has been acquired so far.
  owned<llama_model> model(
      llama_load_model_from_file(std::string(model_path).c_str(), ctx_params),
      llama_free_model);
  if (!model) {
    fprintf(stderr , "%s: error: unable to load model\n" , __func__);
    return nullptr;
  }

  owned<llama_context> ctx(
      llama_new_context_with_model(model.get(), ctx_params),
      llama_free);
  if (!ctx) {
    // Without this check the engine would wrap a null context and fail on
    // first use instead of at creation time.
    fprintf(stderr , "%s: error: unable to create context\n" , __func__);
    return nullptr;
  }

  return std::make_shared<TextInferenceEngineImpl>(std::move(model), std::move(ctx));
}
|
||||
|
||||
} // namespace llama
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use derive_builder::Builder;
|
||||
use ffi::create_engine;
|
||||
use stop_words::StopWords;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions};
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
#[cxx::bridge(namespace = "llama")]
|
||||
mod ffi {
|
||||
unsafe extern "C++" {
|
||||
include!("llama-cpp-bindings/include/engine.h");
|
||||
|
||||
type TextInferenceEngine;
|
||||
|
||||
fn create_engine(model_path: &str) -> SharedPtr<TextInferenceEngine>;
|
||||
|
||||
fn start(&self, prompt: &str) -> u32;
|
||||
fn step(&self, next_token_id: u32) -> u32;
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for ffi::TextInferenceEngine {}
|
||||
unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct LlamaEngineOptions {
|
||||
model_path: String,
|
||||
tokenizer_path: String,
|
||||
}
|
||||
|
||||
pub struct LlamaEngine {
|
||||
engine: Mutex<cxx::SharedPtr<ffi::TextInferenceEngine>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_words: StopWords,
|
||||
}
|
||||
|
||||
impl LlamaEngine {
|
||||
pub fn create(options: LlamaEngineOptions) -> Self {
|
||||
LlamaEngine {
|
||||
engine: Mutex::new(create_engine(&options.model_path)),
|
||||
tokenizer: Arc::new(Tokenizer::from_file(&options.tokenizer_path).unwrap()),
|
||||
stop_words: StopWords::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TextGeneration for LlamaEngine {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
|
||||
let engine = self.engine.lock().unwrap();
|
||||
let mut next_token_id = engine.start(prompt);
|
||||
let mut n_remains = options.max_decoding_length - 1;
|
||||
let mut output_ids = vec![next_token_id];
|
||||
|
||||
let mut stop_condition = self
|
||||
.stop_words
|
||||
.create_condition(self.tokenizer.clone(), options.stop_words);
|
||||
|
||||
// FIXME(meng): supports cancellation.
|
||||
while n_remains > 0 {
|
||||
next_token_id = engine.step(next_token_id);
|
||||
if stop_condition.next_token(next_token_id) {
|
||||
break;
|
||||
}
|
||||
output_ids.push(next_token_id);
|
||||
n_remains -= 1;
|
||||
}
|
||||
|
||||
self.tokenizer.decode(&output_ids, true).unwrap()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
[package]
|
||||
name = "stop-words"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
dashmap = "5.5.3"
|
||||
regex = "1.9.5"
|
||||
tokenizers.workspace = true
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use regex::Regex;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
|
||||
pub struct StopWords {
|
||||
stop_regex_cache: DashMap<&'static Vec<&'static str>, Regex>,
|
||||
}
|
||||
|
||||
/// Returns `s` with its characters (Unicode scalar values) in reverse order.
fn reverse(s: &&str) -> String {
    let mut reversed = String::with_capacity(s.len());
    reversed.extend(s.chars().rev());
    reversed
}
|
||||
|
||||
impl Default for StopWords {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
stop_regex_cache: DashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StopWords {
|
||||
pub fn create_condition(
|
||||
&self,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_words: &'static Vec<&'static str>,
|
||||
) -> StopWordsCondition {
|
||||
let re = if stop_words.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let mut re = self.stop_regex_cache.get(stop_words);
|
||||
if re.is_none() {
|
||||
self.stop_regex_cache
|
||||
.insert(stop_words, create_stop_regex(stop_words));
|
||||
re = self.stop_regex_cache.get(stop_words);
|
||||
}
|
||||
re.map(|x| x.value().clone())
|
||||
};
|
||||
|
||||
StopWordsCondition::new(tokenizer, re)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_stop_regex(stop_words: &[&str]) -> Regex {
|
||||
let tokens: Vec<String> = stop_words.iter().map(reverse).collect();
|
||||
|
||||
// (?m) enables multi-line matching mode.
|
||||
// \A means absolute begins of string.
|
||||
let regex_string = r"(?m)\A".to_owned() + &tokens.join("|");
|
||||
Regex::new(®ex_string).unwrap()
|
||||
}
|
||||
|
||||
pub struct StopWordsCondition {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
stop_re: Option<Regex>,
|
||||
reversed_output_text: String,
|
||||
}
|
||||
|
||||
impl StopWordsCondition {
|
||||
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
stop_re,
|
||||
reversed_output_text: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_token(&mut self, token_id: u32) -> bool {
|
||||
if let Some(re) = &self.stop_re {
|
||||
let token = self.tokenizer.decode(&[token_id], false).unwrap();
|
||||
let mut new_token = reverse(&token.as_str());
|
||||
new_token.push_str(&self.reversed_output_text);
|
||||
self.reversed_output_text = new_token;
|
||||
re.find(&self.reversed_output_text).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -85,4 +85,8 @@ impl ModelDir {
|
|||
pub fn ctranslate2_dir(&self) -> String {
|
||||
self.path_string("ctranslate2")
|
||||
}
|
||||
|
||||
pub fn ggml_model_file(&self) -> String {
|
||||
self.path_string("ggml/default.gguf")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,6 @@ pub struct TextGenerationOptions {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait TextGeneration {
|
||||
pub trait TextGeneration: Sync + Send {
|
||||
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ tantivy = { workspace = true }
|
|||
anyhow = { workspace = true }
|
||||
sysinfo = "0.29.8"
|
||||
nvml-wrapper = "0.9.0"
|
||||
|
||||
llama-cpp-bindings = { path = "../llama-cpp-bindings", optional = true }
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
@ -49,6 +49,7 @@ features = [
|
|||
default = ["scheduler"]
|
||||
link_shared = ["ctranslate2-bindings/link_shared"]
|
||||
scheduler = ["tabby-scheduler"]
|
||||
metal = ["llama-cpp-bindings"]
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ pub async fn completion(
|
|||
}
|
||||
|
||||
pub struct CompletionState {
|
||||
engine: CTranslate2Engine,
|
||||
engine: Box<dyn TextGeneration>,
|
||||
prompt_builder: prompt::PromptBuilder,
|
||||
}
|
||||
|
||||
|
|
@ -129,21 +129,8 @@ impl CompletionState {
|
|||
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
|
||||
let model_dir = get_model_dir(&args.model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
let engine = create_engine(args, &model_dir, &metadata);
|
||||
|
||||
let device = format!("{}", args.device);
|
||||
let compute_type = format!("{}", args.compute_type);
|
||||
let options = CTranslate2EngineOptionsBuilder::default()
|
||||
.model_path(model_dir.ctranslate2_dir())
|
||||
.tokenizer_path(model_dir.tokenizer_file())
|
||||
.device(device)
|
||||
.model_type(metadata.auto_model)
|
||||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.compute_type(compute_type)
|
||||
.stop_words_encoding_offset(metadata.stop_words_encoding_offset)
|
||||
.build()
|
||||
.unwrap();
|
||||
let engine = CTranslate2Engine::create(options);
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: prompt::PromptBuilder::new(
|
||||
|
|
@ -154,6 +141,59 @@ impl CompletionState {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "metal"))]
|
||||
fn create_engine(
|
||||
args: &crate::serve::ServeArgs,
|
||||
model_dir: &ModelDir,
|
||||
metadata: &Metadata,
|
||||
) -> Box<dyn TextGeneration> {
|
||||
create_ctranslate2_engine(args, model_dir, metadata)
|
||||
}
|
||||
|
||||
/// Picks the inference backend when the `metal` feature is compiled in:
/// the llama.cpp engine for `Device::Metal`, CTranslate2 for every other device.
#[cfg(feature = "metal")]
fn create_engine(
    args: &crate::serve::ServeArgs,
    model_dir: &ModelDir,
    metadata: &Metadata,
) -> Box<dyn TextGeneration> {
    if args.device == super::Device::Metal {
        create_llama_engine(model_dir)
    } else {
        create_ctranslate2_engine(args, model_dir, metadata)
    }
}
|
||||
|
||||
fn create_ctranslate2_engine(
|
||||
args: &crate::serve::ServeArgs,
|
||||
model_dir: &ModelDir,
|
||||
metadata: &Metadata,
|
||||
) -> Box<dyn TextGeneration> {
|
||||
let device = format!("{}", args.device);
|
||||
let compute_type = format!("{}", args.compute_type);
|
||||
let options = CTranslate2EngineOptionsBuilder::default()
|
||||
.model_path(model_dir.ctranslate2_dir())
|
||||
.tokenizer_path(model_dir.tokenizer_file())
|
||||
.device(device)
|
||||
.model_type(metadata.auto_model.clone())
|
||||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.compute_type(compute_type)
|
||||
.build()
|
||||
.unwrap();
|
||||
Box::new(CTranslate2Engine::create(options))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
|
||||
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
|
||||
.model_path(model_dir.ggml_model_file())
|
||||
.tokenizer_path(model_dir.tokenizer_file())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Box::new(llama_cpp_bindings::LlamaEngine::create(options))
|
||||
}
|
||||
|
||||
fn get_model_dir(model: &str) -> ModelDir {
|
||||
if Path::new(model).exists() {
|
||||
ModelDir::from(model)
|
||||
|
|
@ -166,7 +206,6 @@ fn get_model_dir(model: &str) -> ModelDir {
|
|||
struct Metadata {
|
||||
auto_model: String,
|
||||
prompt_template: Option<String>,
|
||||
stop_words_encoding_offset: Option<usize>,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||
|
|
|
|||
|
|
@ -54,6 +54,10 @@ pub enum Device {
|
|||
|
||||
#[strum(serialize = "cuda")]
|
||||
Cuda,
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
#[strum(serialize = "metal")]
|
||||
Metal,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||
|
|
|
|||
Loading…
Reference in New Issue