feat: add rocm support (#913)
* Added build configurations for Intel and AMD hardware
* Improved rocm build
* Added options for OneAPI and ROCm
* Build llama using icx
* [autofix.ci] apply automated fixes
* Fixed rocm image
* Build ROCm
* Tried to adjust compile flags for SYCL
* Removed references to oneAPI
* Provide info about the used device for ROCm
* Added ROCm documentation
* Addressed review comments
* Refactored to expose generic accelerator information
* Pull request cleanup
* cleanup
* cleanup
* Delete .github/workflows/docker-cuda.yml
* Delete .github/workflows/docker-rocm.yml
* Delete crates/tabby-common/src/api/accelerator.rs
* update
* cleanup
* update
* update
* update
* update
---------
Co-authored-by: Cromefire_ <cromefire+git@pm.me>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
add-prompt-lookup
parent
2b131ad1d2
commit
9c905e4849
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
|||
|
||||
[features]
|
||||
cuda = []
|
||||
rocm = []
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use std::path::Path;
|
||||
use std::{env, path::Path};
|
||||
|
||||
use cmake::Config;
|
||||
|
||||
|
|
@ -32,6 +32,41 @@ fn main() {
|
|||
println!("cargo:rustc-link-lib=cublas");
|
||||
println!("cargo:rustc-link-lib=cublasLt");
|
||||
}
|
||||
if cfg!(feature = "rocm") {
|
||||
let amd_gpu_targets: Vec<&str> = vec![
|
||||
"gfx803",
|
||||
"gfx900",
|
||||
"gfx906:xnack-",
|
||||
"gfx908:xnack-",
|
||||
"gfx90a:xnack+",
|
||||
"gfx90a:xnack-",
|
||||
"gfx940",
|
||||
"gfx941",
|
||||
"gfx942",
|
||||
"gfx1010",
|
||||
"gfx1012",
|
||||
"gfx1030",
|
||||
"gfx1100",
|
||||
"gfx1101",
|
||||
"gfx1102",
|
||||
];
|
||||
|
||||
let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string());
|
||||
config.define("LLAMA_HIPBLAS", "ON");
|
||||
config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root));
|
||||
config.define(
|
||||
"CMAKE_CXX_COMPILER",
|
||||
format!("{}/llvm/bin/clang++", rocm_root),
|
||||
);
|
||||
config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";"));
|
||||
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
|
||||
println!("cargo:rustc-link-search=native={}/hip/lib", rocm_root);
|
||||
println!("cargo:rustc-link-search=native={}/rocblas/lib", rocm_root);
|
||||
println!("cargo:rustc-link-search=native={}/hipblas/lib", rocm_root);
|
||||
println!("cargo:rustc-link-lib=amdhip64");
|
||||
println!("cargo:rustc-link-lib=rocblas");
|
||||
println!("cargo:rustc-link-lib=hipblas");
|
||||
}
|
||||
|
||||
let dst = config.build();
|
||||
println!("cargo:rustc-link-search=native={}/build", dst.display());
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ edition = "2021"
|
|||
default = ["ee"]
|
||||
ee = ["dep:tabby-webserver"]
|
||||
cuda = ["llama-cpp-bindings/cuda"]
|
||||
rocm = ["llama-cpp-bindings/rocm"]
|
||||
experimental-http = ["dep:http-api-bindings"]
|
||||
|
||||
[dependencies]
|
||||
|
|
|
|||
|
|
@ -69,6 +69,10 @@ pub enum Device {
|
|||
#[strum(serialize = "cuda")]
|
||||
Cuda,
|
||||
|
||||
#[cfg(feature = "rocm")]
|
||||
#[strum(serialize = "rocm")]
|
||||
Rocm,
|
||||
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
#[strum(serialize = "metal")]
|
||||
Metal,
|
||||
|
|
@ -89,7 +93,16 @@ impl Device {
|
|||
*self == Device::Cuda
|
||||
}
|
||||
|
||||
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
|
||||
/// Reports whether this device selection is the ROCm backend.
///
/// Returns `true` exactly when `self` is [`Device::Rocm`], and `false` for
/// every other variant. This definition is only compiled when the `rocm`
/// Cargo feature is enabled.
#[cfg(feature = "rocm")]
pub fn ggml_use_gpu(&self) -> bool {
    matches!(self, Device::Rocm)
}
|
||||
|
||||
#[cfg(not(any(
|
||||
all(target_os = "macos", target_arch = "aarch64"),
|
||||
feature = "cuda",
|
||||
feature = "rocm",
|
||||
)))]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue