feat: turn on metal device by default on macosx / aarch64 devices (#398)

release-0.2
Meng Zhang 2023-09-05 13:03:49 +08:00 committed by GitHub
parent d85cd81139
commit a207520571
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 9 deletions

View File

@ -70,10 +70,8 @@ jobs:
include:
- os: macos-11
target: aarch64-apple-darwin
flags: "--features metal"
- os: ubuntu-latest
target: x86_64-unknown-linux-gnu
flags: ""
env:
SCCACHE_GHA_ENABLED: true
@ -112,7 +110,7 @@ jobs:
- run: bash ./ci/prepare_build_environment.sh
- name: Bulid release binary
run: cargo build --no-default-features ${{ matrix.flags }} --release --target ${{ matrix.target }}
run: cargo build --no-default-features --release --target ${{ matrix.target }}
- name: Rename release binary
run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}

View File

@ -35,7 +35,9 @@ tantivy = { workspace = true }
anyhow = { workspace = true }
sysinfo = "0.29.8"
nvml-wrapper = "0.9.0"
llama-cpp-bindings = { path = "../llama-cpp-bindings", optional = true }
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
[dependencies.uuid]
version = "1.3.3"
@ -49,7 +51,6 @@ features = [
default = ["scheduler"]
link_shared = ["ctranslate2-bindings/link_shared"]
scheduler = ["tabby-scheduler"]
metal = ["llama-cpp-bindings"]
[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

View File

@ -141,7 +141,7 @@ impl CompletionState {
}
}
#[cfg(not(feature = "metal"))]
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn create_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
@ -150,7 +150,7 @@ fn create_engine(
create_ctranslate2_engine(args, model_dir, metadata)
}
#[cfg(feature = "metal")]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
@ -183,7 +183,7 @@ fn create_ctranslate2_engine(
Box::new(CTranslate2Engine::create(options))
}
#[cfg(feature = "metal")]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_model_file())

View File

@ -55,7 +55,7 @@ pub enum Device {
#[strum(serialize = "cuda")]
Cuda,
#[cfg(feature = "metal")]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[strum(serialize = "metal")]
Metal,
}