Add download_model.py. Adjust ctranslate2 backend model structure (#153)

* adjust

* update

* update
add-tracing
Meng Zhang 2023-05-27 14:51:12 -07:00 committed by GitHub
parent 734957d1de
commit 0000312460
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 14 deletions

View File

@ -8,7 +8,15 @@ on:
branches: ["main" ]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.0
release-docker:
needs: pre-commit
runs-on: ubuntu-latest
permissions:
contents: read
@ -68,6 +76,7 @@ jobs:
cache-to: ${{ steps.cache.outputs.cache-to }}
release-binary:
needs: pre-commit
runs-on: ${{ matrix.os }}
strategy:
matrix:

View File

@ -1,14 +0,0 @@
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.0

View File

@ -0,0 +1,31 @@
#!/usr/bin/env python3
from dataclasses import dataclass, field
from huggingface_hub import snapshot_download
from transformers import HfArgumentParser
@dataclass
class Arguments:
repo_id: str = field(
metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"}
)
device: str = field(metadata={"help": "Device type for inference: cpu / cuda"})
output_dir: str = field(metadata={"help": "Output directory"})
def parse_args():
parser = HfArgumentParser(Arguments)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
print(f"Loading {args.repo_id}, this will take a while...")
snapshot_download(
local_dir=args.output_dir,
repo_id=args.repo_id,
allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"],
)
print(f"Loaded {args.repo_id} !")

View File

@ -88,6 +88,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
let options = TextInferenceEngineCreateOptionsBuilder::default()
.model_path(
Path::new(&args.model)
.join("ctranslate2")
.join(device.clone())
.display()
.to_string(),