Add download_model.py. Adjust ctranslate2 backend model structure (#153)
* adjust * update * updateadd-tracing
parent
734957d1de
commit
0000312460
|
|
@ -8,7 +8,15 @@ on:
|
||||||
branches: ["main" ]
|
branches: ["main" ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
pre-commit:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- uses: actions/setup-python@v3
|
||||||
|
- uses: pre-commit/action@v3.0.0
|
||||||
|
|
||||||
release-docker:
|
release-docker:
|
||||||
|
needs: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
|
@ -68,6 +76,7 @@ jobs:
|
||||||
cache-to: ${{ steps.cache.outputs.cache-to }}
|
cache-to: ${{ steps.cache.outputs.cache-to }}
|
||||||
|
|
||||||
release-binary:
|
release-binary:
|
||||||
|
needs: pre-commit
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
name: pre-commit
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
push:
|
|
||||||
branches: [main]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
pre-commit:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
- uses: actions/setup-python@v3
|
|
||||||
- uses: pre-commit/action@v3.0.0
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Arguments:
|
||||||
|
repo_id: str = field(
|
||||||
|
metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"}
|
||||||
|
)
|
||||||
|
device: str = field(metadata={"help": "Device type for inference: cpu / cuda"})
|
||||||
|
output_dir: str = field(metadata={"help": "Output directory"})
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = HfArgumentParser(Arguments)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
print(f"Loading {args.repo_id}, this will take a while...")
|
||||||
|
snapshot_download(
|
||||||
|
local_dir=args.output_dir,
|
||||||
|
repo_id=args.repo_id,
|
||||||
|
allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"],
|
||||||
|
)
|
||||||
|
print(f"Loaded {args.repo_id} !")
|
||||||
|
|
@ -88,6 +88,7 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
||||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||||
.model_path(
|
.model_path(
|
||||||
Path::new(&args.model)
|
Path::new(&args.model)
|
||||||
|
.join("ctranslate2")
|
||||||
.join(device.clone())
|
.join(device.clone())
|
||||||
.display()
|
.display()
|
||||||
.to_string(),
|
.to_string(),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue