diff --git a/crates/llama-cpp-bindings/Cargo.toml b/crates/llama-cpp-bindings/Cargo.toml index 4e8af38..a054d36 100644 --- a/crates/llama-cpp-bindings/Cargo.toml +++ b/crates/llama-cpp-bindings/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [features] cuda = [] +rocm = [] [build-dependencies] cxx-build = "1.0" diff --git a/crates/llama-cpp-bindings/build.rs b/crates/llama-cpp-bindings/build.rs index c621947..f9b825a 100644 --- a/crates/llama-cpp-bindings/build.rs +++ b/crates/llama-cpp-bindings/build.rs @@ -1,4 +1,4 @@ -use std::path::Path; +use std::{env, path::Path}; use cmake::Config; @@ -32,6 +32,41 @@ fn main() { println!("cargo:rustc-link-lib=cublas"); println!("cargo:rustc-link-lib=cublasLt"); } + if cfg!(feature = "rocm") { + let amd_gpu_targets: Vec<&str> = vec![ + "gfx803", + "gfx900", + "gfx906:xnack-", + "gfx908:xnack-", + "gfx90a:xnack+", + "gfx90a:xnack-", + "gfx940", + "gfx941", + "gfx942", + "gfx1010", + "gfx1012", + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + ]; + + let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string()); + config.define("LLAMA_HIPBLAS", "ON"); + config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root)); + config.define( + "CMAKE_CXX_COMPILER", + format!("{}/llvm/bin/clang++", rocm_root), + ); + config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";")); + println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); + println!("cargo:rustc-link-search=native={}/hip/lib", rocm_root); + println!("cargo:rustc-link-search=native={}/rocblas/lib", rocm_root); + println!("cargo:rustc-link-search=native={}/hipblas/lib", rocm_root); + println!("cargo:rustc-link-lib=amdhip64"); + println!("cargo:rustc-link-lib=rocblas"); + println!("cargo:rustc-link-lib=hipblas"); + } let dst = config.build(); println!("cargo:rustc-link-search=native={}/build", dst.display()); diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 577a118..9854933 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" default = ["ee"] ee = ["dep:tabby-webserver"] cuda = ["llama-cpp-bindings/cuda"] +rocm = ["llama-cpp-bindings/rocm"] experimental-http = ["dep:http-api-bindings"] [dependencies] diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index fe7c4a4..e13f138 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -69,6 +69,10 @@ pub enum Device { #[strum(serialize = "cuda")] Cuda, + #[cfg(feature = "rocm")] + #[strum(serialize = "rocm")] + Rocm, + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] #[strum(serialize = "metal")] Metal, @@ -89,7 +93,16 @@ impl Device { *self == Device::Cuda } - #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))] + #[cfg(feature = "rocm")] + pub fn ggml_use_gpu(&self) -> bool { + *self == Device::Rocm + } + + #[cfg(not(any( + all(target_os = "macos", target_arch = "aarch64"), + feature = "cuda", + feature = "rocm", + )))] pub fn ggml_use_gpu(&self) -> bool { false }