Add download_model.py. Adjust ctranslate2 backend model structure (#153)
* adjust * update * update add-tracing
parent
734957d1de
commit
0000312460
|
|
@ -8,7 +8,15 @@ on:
|
|||
branches: ["main" ]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v3
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
release-docker:
|
||||
needs: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
|
|
@ -68,6 +76,7 @@ jobs:
|
|||
cache-to: ${{ steps.cache.outputs.cache-to }}
|
||||
|
||||
release-binary:
|
||||
needs: pre-commit
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v3
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
@dataclass
class Arguments:
    """Options consumed by the model download script.

    Every field is declared without a default value; the ``help`` entry in
    each field's metadata carries the human-readable description of that
    option.
    """

    # Hub repository the snapshot is fetched from.
    repo_id: str = field(
        metadata=dict(help="Huggingface model repository id, e.g TabbyML/NeoX-160M"),
    )

    # Selects which ``ctranslate2/<device>/`` subtree is requested.
    device: str = field(
        metadata=dict(help="Device type for inference: cpu / cuda"),
    )

    # Local directory handed to the downloader as its target.
    output_dir: str = field(
        metadata=dict(help="Output directory"),
    )
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line against the fields declared on ``Arguments``.

    Returns the object produced by ``HfArgumentParser.parse_args()``; the
    script entry point reads ``repo_id``, ``device`` and ``output_dir``
    from it.
    """
    return HfArgumentParser(Arguments).parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read the CLI options, then download the matching
    # files of the repository into the requested output directory.
    args = parse_args()

    print(f"Loading {args.repo_id}, this will take a while...")

    # Restrict the download to the ctranslate2 files for the chosen device
    # plus the tokenizer definition.
    wanted_files = [f"ctranslate2/{args.device}/*", "tokenizer.json"]
    snapshot_download(
        repo_id=args.repo_id,
        local_dir=args.output_dir,
        allow_patterns=wanted_files,
    )

    print(f"Loaded {args.repo_id} !")
|
||||
|
|
@ -88,6 +88,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
|||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||
.model_path(
|
||||
Path::new(&args.model)
|
||||
.join("ctranslate2")
|
||||
.join(device.clone())
|
||||
.display()
|
||||
.to_string(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue