feat: switch cuda backend to llama.cpp (#656)

* feat: switch cuda backend to llama.cpp

* fix

* fix
Meng Zhang 2023-10-27 13:41:22 -07:00 committed by GitHub
parent 308681efb0
commit 23bd542cec
9 changed files with 46 additions and 109 deletions


@@ -111,7 +111,7 @@ jobs:
       - run: bash ./ci/prepare_build_environment.sh
       - name: Bulid release binary
-        run: cargo build --no-default-features --release --target ${{ matrix.target }} --package tabby
+        run: cargo build --release --target ${{ matrix.target }} --package tabby
       - name: Rename release binary
         run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}


@@ -9,6 +9,7 @@
 * Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638
 * add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637
+* Switch cuda backend to llama.cpp: https://github.com/TabbyML/tabby/pull/656
 
 # v0.4.0

Cargo.lock generated

@@ -3153,7 +3153,6 @@ dependencies = [
  "axum-streams",
  "axum-tracing-opentelemetry",
  "clap",
- "ctranslate2-bindings",
  "futures",
  "http-api-bindings",
  "hyper",


@@ -1,8 +1,12 @@
-FROM ghcr.io/opennmt/ctranslate2:3.20.0-ubuntu20.04-cuda11.2 as source
-FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
+ARG UBUNTU_VERSION=22.04
+# This needs to generally match the container host's environment.
+ARG CUDA_VERSION=11.7.1
+# Target the CUDA build image
+ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
+# Target the CUDA runtime image
+ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
 
-ENV CTRANSLATE2_ROOT=/opt/ctranslate2
-COPY --from=source $CTRANSLATE2_ROOT $CTRANSLATE2_ROOT
+FROM ${BASE_CUDA_DEV_CONTAINER} as build
 
 ENV DEBIAN_FRONTEND=noninteractive
 RUN apt-get update && \
@@ -30,10 +34,10 @@ RUN mkdir -p target
 RUN --mount=type=cache,target=/usr/local/cargo/registry \
     --mount=type=cache,target=/root/workspace/target \
-    cargo build --features link_shared --release && \
+    cargo build --features cuda --release && \
     cp target/release/tabby /opt/tabby/bin/
 
-FROM ghcr.io/opennmt/ctranslate2:3.20.0-ubuntu20.04-cuda11.2
+FROM ${BASE_CUDA_RUN_CONTAINER} as runtime
 
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
@@ -51,7 +55,7 @@ RUN git config --system --add safe.directory "*"
 RUN ln -s /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1 \
     /usr/lib/x86_64-linux-gnu/libnvidia-ml.so
 
-COPY --from=builder /opt/tabby /opt/tabby
+COPY --from=build /opt/tabby /opt/tabby
 
 ENV TABBY_ROOT=/data
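Note on the Dockerfile change: the CTranslate2 base images are replaced by a standard two-stage CUDA build. The `build` stage uses the full `devel` image (CUDA toolkit and nvcc, needed to compile llama.cpp's CUDA kernels during `cargo build --features cuda`), while the final `runtime` stage starts from the much smaller `runtime` image, which ships only the CUDA libraries needed at inference time. Lifting `CUDA_VERSION` and `UBUNTU_VERSION` into `ARG`s makes it easy to rebuild against a different host driver setup, per the comment in the diff.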


@@ -3,6 +3,9 @@ name = "llama-cpp-bindings"
 version = "0.5.0-dev"
 edition = "2021"
 
+[features]
+cuda = []
+
 [build-dependencies]
 cxx-build = "1.0"
 cmake = "0.1"


@@ -1,25 +1,25 @@
 use cmake::Config;
 
 fn main() {
-    let mut config = Config::new("llama.cpp");
-    if cfg!(target_os = "macos") {
-        config.define("LLAMA_METAL", "ON");
-    }
-    let dst = config.build();
-
     println!("cargo:rerun-if-changed=cc/*.h");
     println!("cargo:rerun-if-changed=cc/*.cc");
 
-    println!("cargo:rustc-link-search=native={}/build", dst.display());
-    println!("cargo:rustc-link-lib=llama");
-    println!("cargo:rustc-link-lib=ggml_static");
+    let mut config = Config::new("llama.cpp");
     if cfg!(target_os = "macos") {
+        config.define("LLAMA_METAL", "ON");
         println!("cargo:rustc-link-lib=framework=Foundation");
         println!("cargo:rustc-link-lib=framework=Accelerate");
         println!("cargo:rustc-link-lib=framework=Metal");
         println!("cargo:rustc-link-lib=framework=MetalKit");
     }
+    if cfg!(feature = "cuda") {
+        config.define("LLAMA_CUBLAS", "ON");
+    }
+
+    let dst = config.build();
+    println!("cargo:rustc-link-search=native={}/build", dst.display());
+    println!("cargo:rustc-link-lib=llama");
+    println!("cargo:rustc-link-lib=ggml_static");
 
     cxx_build::bridge("src/lib.rs")
         .file("src/engine.cc")


@@ -3,6 +3,9 @@ name = "tabby"
 version = "0.5.0-dev"
 edition = "2021"
 
+[features]
+cuda = ["llama-cpp-bindings/cuda"]
+
 [dependencies]
 tabby-common = { path = "../tabby-common" }
 tabby-scheduler = { path = "../tabby-scheduler" }
@@ -43,7 +46,6 @@ textdistance = "1.0.2"
 regex.workspace = true
 thiserror.workspace = true
 llama-cpp-bindings = { path = "../llama-cpp-bindings" }
-ctranslate2-bindings = { path = "../ctranslate2-bindings", optional = true }
 
 [dependencies.uuid]
 version = "1.3.3"
@@ -53,10 +55,6 @@ features = [
     "macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
 ]
 
-[features]
-link_shared = ["ctranslate2-bindings/link_shared"]
-link_cuda_static = ["ctranslate2-bindings"]
-
 [build-dependencies]
 vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
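The top-level `cuda` feature carries no code of its own: `cuda = ["llama-cpp-bindings/cuda"]` simply forwards to the bindings crate, so the single `cargo build --features cuda` seen in the Dockerfile flips the whole dependency chain over to the cuBLAS build. The old `link_shared`/`link_cuda_static` features disappear along with the optional `ctranslate2-bindings` dependency.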


@@ -13,7 +13,7 @@ pub fn create_engine(
     if args.device != super::Device::ExperimentalHttp {
         let model_dir = get_model_dir(model);
         let metadata = read_metadata(&model_dir);
-        let engine = create_local_engine(args, &model_dir, &metadata);
+        let engine = create_ggml_engine(&args.device, &model_dir);
         (
             engine,
             EngineInfo {
@@ -38,48 +38,6 @@ pub struct EngineInfo {
     pub chat_template: Option<String>,
 }
 
-#[cfg(not(any(feature = "link_shared", feature = "link_cuda_static")))]
-fn create_local_engine(
-    args: &crate::serve::ServeArgs,
-    model_dir: &ModelDir,
-    _metadata: &Metadata,
-) -> Box<dyn TextGeneration> {
-    create_ggml_engine(&args.device, model_dir)
-}
-
-#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
-fn create_local_engine(
-    args: &crate::serve::ServeArgs,
-    model_dir: &ModelDir,
-    metadata: &Metadata,
-) -> Box<dyn TextGeneration> {
-    if args.device.use_ggml_backend() {
-        create_ggml_engine(&args.device, model_dir)
-    } else {
-        create_ctranslate2_engine(args, model_dir, metadata)
-    }
-}
-
-#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
-fn create_ctranslate2_engine(
-    args: &crate::serve::ServeArgs,
-    model_dir: &ModelDir,
-    metadata: &Metadata,
-) -> Box<dyn TextGeneration> {
-    use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
-
-    let device = format!("{}", args.device);
-    let options = CTranslate2EngineOptionsBuilder::default()
-        .model_path(model_dir.ctranslate2_dir())
-        .tokenizer_path(model_dir.tokenizer_file())
-        .device(device)
-        .model_type(metadata.auto_model.clone())
-        .device_indices(args.device_indices.clone())
-        .build()
-        .unwrap();
-    Box::new(CTranslate2Engine::create(options))
-}
-
 fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
         .model_path(model_dir.ggml_q8_0_v2_file())
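Everything removed here is the CTranslate2 dispatch layer: the two `cfg`-gated `create_local_engine` variants and `create_ctranslate2_engine` collapse into a single unconditional `create_ggml_engine` call, making llama.cpp the only local backend on every device.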


@@ -74,7 +74,7 @@ pub enum Device {
     #[strum(serialize = "cpu")]
     Cpu,
 
-    #[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
+    #[cfg(feature = "cuda")]
     Cuda,
 
     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
@@ -86,22 +86,17 @@ pub enum Device {
 }
 
 impl Device {
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-    fn use_ggml_backend(&self) -> bool {
-        *self == Device::Metal || *self == Device::Cpu
-    }
-
-    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
-    fn use_ggml_backend(&self) -> bool {
-        *self == Device::Cpu
-    }
-
     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
     fn ggml_use_gpu(&self) -> bool {
         *self == Device::Metal
     }
 
-    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+    #[cfg(feature = "cuda")]
+    fn ggml_use_gpu(&self) -> bool {
+        *self == Device::Cuda
+    }
+
+    #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
     fn ggml_use_gpu(&self) -> bool {
         false
     }
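`use_ggml_backend()` disappears because every device now uses the ggml backend; the only remaining per-device question is GPU offload, answered by exactly one `ggml_use_gpu` body per target/feature combination. A self-contained sketch of the same `cfg`-dispatch pattern (hypothetical crate, not tabby's code):

// One body of `gpu()` is compiled per build configuration; the cfg
// conditions are mutually exclusive, mirroring the diff above.
#[derive(PartialEq)]
enum Device {
    Cpu,
    #[cfg(feature = "cuda")] // variant only exists in CUDA-enabled builds
    Cuda,
}

impl Device {
    #[cfg(feature = "cuda")]
    fn gpu(&self) -> bool {
        *self == Device::Cuda
    }

    #[cfg(not(feature = "cuda"))]
    fn gpu(&self) -> bool {
        false // CPU-only build: never offload to the GPU
    }
}

fn main() {
    assert!(!Device::Cpu.gpu()); // CPU never uses the GPU in either build
}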
@@ -124,26 +119,19 @@ pub struct ServeArgs {
     #[clap(long, default_value_t=Device::Cpu)]
     device: Device,
 
-    /// GPU indices to run models, only applicable for CUDA.
-    #[clap(long, default_values_t=[0])]
+    /// DEPRECATED: Do not use.
+    #[deprecated(since = "0.5.0")]
+    #[clap(long, hide(true))]
     device_indices: Vec<i32>,
-
-    /// DEPRECATED: Do not use.
-    #[clap(long, hide(true))]
-    num_replicas_per_device: Option<usize>,
-
-    /// DEPRECATED: Do not use.
-    #[clap(long, hide(true))]
-    compute_type: Option<String>,
 }
 
 pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);
 
     if args.device != Device::ExperimentalHttp {
-        download_model(&args.model, &args.device).await;
+        download_model(&args.model).await;
         if let Some(chat_model) = &args.chat_model {
-            download_model(chat_model, &args.device).await;
+            download_model(chat_model).await;
         }
     } else {
         warn!("HTTP device is unstable and does not comply with semver expectations.")
@@ -261,17 +249,8 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
 }
 
 fn valid_args(args: &ServeArgs) {
-    if args.num_replicas_per_device.is_some() {
-        warn!("--num-replicas-per-device is deprecated and will be removed in future release.");
-    }
-
-    if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
-    {
-        fatal!("CPU device only supports device indices = [0]");
-    }
-
-    if args.compute_type.is_some() {
-        warn!("--compute-type is deprecated and will be removed in future release.");
+    if !args.device_indices.is_empty() {
+        warn!("--device-indices is deprecated and will be removed in future release.");
     }
 }
@@ -285,15 +264,10 @@ fn start_heartbeat(args: &ServeArgs) {
     });
 }
 
-async fn download_model(model: &str, device: &Device) {
+async fn download_model(model: &str) {
     let downloader = Downloader::new(model, /* prefer_local_file= */ true);
     let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
-    let download_result = if device.use_ggml_backend() {
-        downloader.download_ggml_files().await
-    } else {
-        downloader.download_ctranslate2_files().await
-    };
+    let download_result = downloader.download_ggml_files().await;
 
     download_result.unwrap_or_else(handler);
 }