From 0000312460f7ef22194708d413c4d0703db02d9e Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sat, 27 May 2023 14:51:12 -0700 Subject: [PATCH] Add download_model.py. Adjust ctranslate2 backend model structure (#153) * adjust * update * update --- .github/workflows/docker.rust.yml | 9 ++++++++ .github/workflows/pre-commit.yml | 14 ------------ crates/tabby/python/download_model.py | 31 +++++++++++++++++++++++++++ crates/tabby/src/serve/mod.rs | 1 + 4 files changed, 41 insertions(+), 14 deletions(-) delete mode 100644 .github/workflows/pre-commit.yml create mode 100755 crates/tabby/python/download_model.py diff --git a/.github/workflows/docker.rust.yml b/.github/workflows/docker.rust.yml index f09364a..8556058 100644 --- a/.github/workflows/docker.rust.yml +++ b/.github/workflows/docker.rust.yml @@ -8,7 +8,15 @@ on: branches: ["main" ] jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.0 + release-docker: + needs: pre-commit runs-on: ubuntu-latest permissions: contents: read @@ -68,6 +76,7 @@ jobs: cache-to: ${{ steps.cache.outputs.cache-to }} release-binary: + needs: pre-commit runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index c2f7e71..0000000 --- a/.github/workflows/pre-commit.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: pre-commit - -on: - pull_request: - push: - branches: [main] - -jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.0 diff --git a/crates/tabby/python/download_model.py b/crates/tabby/python/download_model.py new file mode 100755 index 0000000..f8ca744 --- /dev/null +++ b/crates/tabby/python/download_model.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +from dataclasses import dataclass, field + +from huggingface_hub import snapshot_download +from transformers import HfArgumentParser + + +@dataclass +class Arguments: + repo_id: str = field( + metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"} + ) + device: str = field(metadata={"help": "Device type for inference: cpu / cuda"}) + output_dir: str = field(metadata={"help": "Output directory"}) + + +def parse_args(): + parser = HfArgumentParser(Arguments) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + print(f"Loading {args.repo_id}, this will take a while...") + snapshot_download( + local_dir=args.output_dir, + repo_id=args.repo_id, + allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"], + ) + print(f"Loaded {args.repo_id} !") diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index a8630cb..a3f61dc 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -88,6 +88,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> { let options = TextInferenceEngineCreateOptionsBuilder::default() .model_path( Path::new(&args.model) + .join("ctranslate2") .join(device.clone()) .display() .to_string(),