feat: support ModelScope for model registry downloading (#477)
* feat: update cache info file after each file got downloaded
* refactor: extract Downloader for model downloading logic
* refactor: extract HuggingFaceRegistry
* refactor: extract serde_json to workspace dependency
* feat: add ModelScopeRegistry
* refactor: extract registry to its sub dir.
* feat: add scripts to mirror hf model to modelscope

release-0.2
parent
f75a50de02
commit
d42942c379
|
|
@ -39,6 +39,17 @@ dependencies = [
|
|||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.20"
|
||||
|
|
@ -57,6 +68,12 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
|
|
@ -342,6 +359,24 @@ dependencies = [
|
|||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached"
|
||||
version = "0.46.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8cead8ece0da6b744b2ad8ef9c58a4cdc7ef2921e60a6ddfb9eaaa86839b5fc5"
|
||||
dependencies = [
|
||||
"ahash 0.8.3",
|
||||
"async-trait",
|
||||
"cached_proc_macro",
|
||||
"cached_proc_macro_types",
|
||||
"futures",
|
||||
"hashbrown 0.14.0",
|
||||
"instant",
|
||||
"once_cell",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached-path"
|
||||
version = "0.6.1"
|
||||
|
|
@ -364,6 +399,24 @@ dependencies = [
|
|||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached_proc_macro"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8245dd5f576a41c3b76247b54c15b0e43139ceeb4f732033e15be7c005176"
|
||||
dependencies = [
|
||||
"darling 0.14.4",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cached_proc_macro_types"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.79"
|
||||
|
|
@ -1177,7 +1230,7 @@ version = "0.12.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.7.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1185,6 +1238,10 @@ name = "hashbrown"
|
|||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
|
||||
dependencies = [
|
||||
"ahash 0.8.3",
|
||||
"allocator-api2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
|
|
@ -2702,9 +2759,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.105"
|
||||
version = "1.0.107"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
|
||||
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
|
|
@ -3011,14 +3068,18 @@ name = "tabby-download"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"cached",
|
||||
"futures-util",
|
||||
"indicatif 0.17.3",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serdeconv",
|
||||
"tabby-common",
|
||||
"tokio-retry",
|
||||
"tracing",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3875,6 +3936,12 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urlencoding"
|
||||
version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||
|
||||
[[package]]
|
||||
name = "utf8-ranges"
|
||||
version = "1.0.5"
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ homepage = "https://github.com/TabbyML/tabby"
|
|||
[workspace.dependencies]
|
||||
lazy_static = "1.4.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serdeconv = "0.4.1"
|
||||
tokio = "1.28"
|
||||
tokio-util = "0.7"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ edition = "2021"
|
|||
async-trait.workspace = true
|
||||
reqwest = { workspace = true, features = ["json"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = "1.0.105"
|
||||
serde_json = { workspace = true }
|
||||
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
|
||||
|
||||
[dev-dependencies]
|
||||
|
|
|
|||
|
|
@ -13,3 +13,7 @@ serde = { workspace = true }
|
|||
serdeconv = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tokio-retry = "0.3.0"
|
||||
urlencoding = "2.1.3"
|
||||
serde_json = { workspace = true }
|
||||
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
|
||||
async-trait = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
|
|
@ -33,15 +33,6 @@ impl CacheInfo {
|
|||
self.etags.get(path).map(|x| x.as_str())
|
||||
}
|
||||
|
||||
pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> {
|
||||
let key = res
|
||||
.headers()
|
||||
.get("etag")
|
||||
.ok_or(anyhow!("etag key missing"))?
|
||||
.to_str()?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
mod cache_info;
|
||||
mod registry;
|
||||
|
||||
use std::{cmp, fs, io::Write, path::Path};
|
||||
|
||||
|
|
@ -6,111 +7,114 @@ use anyhow::{anyhow, Result};
|
|||
use cache_info::CacheInfo;
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use registry::{create_registry, Registry};
|
||||
use tabby_common::path::ModelDir;
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
impl CacheInfo {
|
||||
async fn download(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
path: &str,
|
||||
prefer_local_file: bool,
|
||||
is_optional: bool,
|
||||
) -> Result<()> {
|
||||
// Create url.
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||
pub struct Downloader {
|
||||
model_id: String,
|
||||
prefer_local_file: bool,
|
||||
registry: Box<dyn Registry>,
|
||||
}
|
||||
|
||||
// Get cache key.
|
||||
let local_cache_key = self.local_cache_key(path);
|
||||
impl Downloader {
|
||||
pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
|
||||
Self {
|
||||
model_id: model_id.to_owned(),
|
||||
prefer_local_file,
|
||||
registry: create_registry(),
|
||||
}
|
||||
}
|
||||
|
||||
// Create destination path.
|
||||
let filepath = ModelDir::new(model_id).path_string(path);
|
||||
pub async fn download_ctranslate2_files(&self) -> Result<()> {
|
||||
let files = vec![
|
||||
("tabby.json", true),
|
||||
("tokenizer.json", true),
|
||||
("ctranslate2/vocabulary.txt", false),
|
||||
("ctranslate2/shared_vocabulary.txt", false),
|
||||
("ctranslate2/vocabulary.json", false),
|
||||
("ctranslate2/shared_vocabulary.json", false),
|
||||
("ctranslate2/config.json", true),
|
||||
("ctranslate2/model.bin", true),
|
||||
];
|
||||
|
||||
// Cache hit.
|
||||
let local_file_ready = if prefer_local_file {
|
||||
if let Some(local_cache_key) = local_cache_key {
|
||||
if local_cache_key == "404" {
|
||||
true
|
||||
} else {
|
||||
fs::metadata(&filepath).is_ok()
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
self.download_files(&files).await
|
||||
}
|
||||
|
||||
if !local_file_ready {
|
||||
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
let etag = Retry::spawn(strategy, || {
|
||||
download_file(&url, &filepath, local_cache_key, is_optional)
|
||||
})
|
||||
pub async fn download_ggml_files(&self) -> Result<()> {
|
||||
let files = vec![
|
||||
("tabby.json", true),
|
||||
("tokenizer.json", true),
|
||||
("ggml/q8_0.gguf", true),
|
||||
];
|
||||
self.download_files(&files).await
|
||||
}
|
||||
|
||||
async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
|
||||
// Local path, no need for downloading.
|
||||
if fs::metadata(&self.model_id).is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut cache_info = CacheInfo::from(&self.model_id).await;
|
||||
for (path, required) in files {
|
||||
download_model_file(
|
||||
self.registry.as_ref(),
|
||||
&mut cache_info,
|
||||
&self.model_id,
|
||||
path,
|
||||
self.prefer_local_file,
|
||||
*required,
|
||||
)
|
||||
.await?;
|
||||
self.set_local_cache_key(path, &etag).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn download_model(
|
||||
async fn download_model_file(
|
||||
registry: &dyn Registry,
|
||||
cache_info: &mut CacheInfo,
|
||||
model_id: &str,
|
||||
download_ctranslate2_files: bool,
|
||||
download_ggml_files: bool,
|
||||
path: &str,
|
||||
prefer_local_file: bool,
|
||||
required: bool,
|
||||
) -> Result<()> {
|
||||
if fs::metadata(model_id).is_ok() {
|
||||
// Local path, no need for downloading.
|
||||
return Ok(());
|
||||
}
|
||||
// Create url.
|
||||
let url = registry.build_url(model_id, path);
|
||||
|
||||
info!("Start downloading model `{}`", model_id);
|
||||
// Get cache key.
|
||||
let local_cache_key = cache_info.local_cache_key(path);
|
||||
|
||||
let mut cache_info = CacheInfo::from(model_id).await;
|
||||
// Create destination path.
|
||||
let filepath = ModelDir::new(model_id).path_string(path);
|
||||
|
||||
let mut optional_files = vec![];
|
||||
if download_ctranslate2_files {
|
||||
optional_files.push("ctranslate2/vocabulary.txt");
|
||||
optional_files.push("ctranslate2/shared_vocabulary.txt");
|
||||
optional_files.push("ctranslate2/vocabulary.json");
|
||||
optional_files.push("ctranslate2/shared_vocabulary.json");
|
||||
}
|
||||
// Cache hit.
|
||||
let local_file_ready = if prefer_local_file {
|
||||
if let Some(local_cache_key) = local_cache_key {
|
||||
if local_cache_key == "404" {
|
||||
true
|
||||
} else {
|
||||
fs::metadata(&filepath).is_ok()
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if download_ggml_files {
|
||||
optional_files.push("ggml/q8_0.gguf");
|
||||
}
|
||||
if !local_file_ready {
|
||||
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
|
||||
let etag = Retry::spawn(strategy, || {
|
||||
download_file(registry, &url, &filepath, local_cache_key, !required)
|
||||
})
|
||||
.await?;
|
||||
|
||||
for path in optional_files {
|
||||
cache_info
|
||||
.download(
|
||||
model_id,
|
||||
path,
|
||||
prefer_local_file,
|
||||
/* is_optional */ true,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mut required_files = vec!["tabby.json", "tokenizer.json"];
|
||||
|
||||
if download_ctranslate2_files {
|
||||
required_files.push("ctranslate2/config.json");
|
||||
required_files.push("ctranslate2/model.bin");
|
||||
}
|
||||
|
||||
for path in required_files {
|
||||
cache_info
|
||||
.download(
|
||||
model_id,
|
||||
path,
|
||||
prefer_local_file,
|
||||
/* required= */ false,
|
||||
)
|
||||
.await?;
|
||||
cache_info.set_local_cache_key(path, &etag).await;
|
||||
}
|
||||
|
||||
cache_info.save(model_id)?;
|
||||
|
|
@ -118,6 +122,7 @@ pub async fn download_model(
|
|||
}
|
||||
|
||||
async fn download_file(
|
||||
registry: &dyn Registry,
|
||||
url: &str,
|
||||
path: &str,
|
||||
local_cache_key: Option<&str>,
|
||||
|
|
@ -137,7 +142,7 @@ async fn download_file(
|
|||
return Err(anyhow!(format!("Invalid url: {}", url)));
|
||||
}
|
||||
|
||||
let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string();
|
||||
let remote_cache_key = registry.build_cache_key(url).await?;
|
||||
if let Some(local_cache_key) = local_cache_key {
|
||||
if local_cache_key == remote_cache_key {
|
||||
return Ok(remote_cache_key);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::Registry;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct HuggingFaceRegistry {}
|
||||
|
||||
#[async_trait]
|
||||
impl Registry for HuggingFaceRegistry {
|
||||
fn build_url(&self, model_id: &str, path: &str) -> String {
|
||||
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
|
||||
}
|
||||
|
||||
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
||||
let res = reqwest::get(url).await?;
|
||||
let cache_key = res
|
||||
.headers()
|
||||
.get("etag")
|
||||
.ok_or(anyhow!("etag key missing"))?
|
||||
.to_str()?;
|
||||
Ok(cache_key.to_owned())
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
mod huggingface;
|
||||
mod modelscope;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use huggingface::HuggingFaceRegistry;
|
||||
|
||||
use self::modelscope::ModelScopeRegistry;
|
||||
|
||||
/// Abstraction over a remote model registry that files can be downloaded from.
#[async_trait]
|
||||
pub trait Registry {
|
||||
/// Returns the URL from which `path` of model `model_id` can be fetched.
fn build_url(&self, model_id: &str, path: &str) -> String;
|
||||
/// Returns the remote cache key for a `url` produced by `build_url`.
async fn build_cache_key(&self, url: &str) -> Result<String>;
|
||||
}
|
||||
|
||||
pub fn create_registry() -> Box<dyn Registry> {
|
||||
let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned());
|
||||
if registry == "huggingface" {
|
||||
Box::<HuggingFaceRegistry>::default()
|
||||
} else if registry == "modelscope" {
|
||||
Box::<ModelScopeRegistry>::default()
|
||||
} else {
|
||||
panic!("Unsupported registry {}", registry);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use cached::proc_macro::cached;
|
||||
use reqwest::Url;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::Registry;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ModelScopeRegistry {}
|
||||
|
||||
#[async_trait]
|
||||
impl Registry for ModelScopeRegistry {
|
||||
fn build_url(&self, model_id: &str, path: &str) -> String {
|
||||
format!(
|
||||
"https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
|
||||
model_id,
|
||||
urlencoding::encode(path)
|
||||
)
|
||||
}
|
||||
|
||||
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
||||
let url = Url::parse(url)?;
|
||||
let model_id = url
|
||||
.path()
|
||||
.strip_prefix("/api/v1/models/")
|
||||
.ok_or(anyhow!("Invalid url"))?
|
||||
.strip_suffix("/repo")
|
||||
.ok_or(anyhow!("Invalid url"))?;
|
||||
|
||||
let query: HashMap<_, _> = url.query_pairs().into_owned().collect();
|
||||
let path = query
|
||||
.get("FilePath")
|
||||
.ok_or(anyhow!("Failed to extract FilePath"))?;
|
||||
|
||||
let revision_map = fetch_revision_map(model_id.to_owned()).await?;
|
||||
for x in revision_map.data.files {
|
||||
if x.path == *path {
|
||||
return Ok(x.sha256);
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!("Failed to find {} in revisions", path))
|
||||
}
|
||||
}
|
||||
|
||||
#[cached(size = 1, result = true)]
|
||||
async fn fetch_revision_map(model_id: String) -> Result<ModelScopeRevision> {
|
||||
let url = format!(
|
||||
"https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true",
|
||||
model_id
|
||||
);
|
||||
let resp = reqwest::get(url)
|
||||
.await?
|
||||
.json::<ModelScopeRevision>()
|
||||
.await?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ModelScopeRevision {
|
||||
data: ModelScopeRevisionData,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ModelScopeRevisionData {
|
||||
files: Vec<ModelScopeRevisionFile>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ModelScopeRevisionFile {
|
||||
path: String,
|
||||
sha256: String,
|
||||
}
|
||||
|
|
@ -17,7 +17,7 @@ utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
|||
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
||||
serde = { workspace = true }
|
||||
serdeconv = { workspace = true }
|
||||
serde_json = "1.0"
|
||||
serde_json = { workspace = true }
|
||||
tower-http = { version = "0.4.0", features = ["cors"] }
|
||||
clap = { version = "4.3.0", features = ["derive"] }
|
||||
lazy_static = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use clap::Args;
|
||||
use tabby_download::Downloader;
|
||||
use tracing::info;
|
||||
|
||||
use crate::fatal;
|
||||
|
|
@ -15,19 +16,18 @@ pub struct DownloadArgs {
|
|||
}
|
||||
|
||||
pub async fn main(args: &DownloadArgs) {
|
||||
tabby_download::download_model(
|
||||
&args.model,
|
||||
/* download_ctranslate2_files= */ true,
|
||||
/* download_ggml_files= */ true,
|
||||
args.prefer_local_file,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
fatal!(
|
||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||
err,
|
||||
args.model
|
||||
)
|
||||
});
|
||||
let downloader = Downloader::new(&args.model, args.prefer_local_file);
|
||||
|
||||
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
|
||||
|
||||
downloader
|
||||
.download_ctranslate2_files()
|
||||
.await
|
||||
.unwrap_or_else(handler);
|
||||
downloader
|
||||
.download_ggml_files()
|
||||
.await
|
||||
.unwrap_or_else(handler);
|
||||
|
||||
info!("model '{}' is ready", args.model);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ use axum::{routing, Router, Server};
|
|||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use clap::Args;
|
||||
use tabby_common::{config::Config, usage};
|
||||
use tabby_download::Downloader;
|
||||
use tokio::time::sleep;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tracing::{info, warn};
|
||||
|
|
@ -136,25 +137,16 @@ fn should_download_ggml_files(device: &Device) -> bool {
|
|||
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||
valid_args(args);
|
||||
|
||||
let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true);
|
||||
if args.device != Device::ExperimentalHttp {
|
||||
let download_ctranslate2_files = !should_download_ggml_files(&args.device);
|
||||
let download_ggml_files = should_download_ggml_files(&args.device);
|
||||
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
|
||||
let download_result = if should_download_ggml_files(&args.device) {
|
||||
downloader.download_ggml_files().await
|
||||
} else {
|
||||
downloader.download_ctranslate2_files().await
|
||||
};
|
||||
|
||||
// Ensure model exists.
|
||||
tabby_download::download_model(
|
||||
&args.model,
|
||||
download_ctranslate2_files,
|
||||
download_ggml_files,
|
||||
/* prefer_local_file= */ true,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
fatal!(
|
||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
||||
err,
|
||||
args.model
|
||||
)
|
||||
});
|
||||
download_result.unwrap_or_else(handler);
|
||||
} else {
|
||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
hf_model
|
||||
ms_model
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# copy-to-modelscope
|
||||
|
||||
Scripts to copy a Hugging Face model to ModelScope.
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
# Mirror a Hugging Face model repository to ModelScope.
#
# Usage: <script> <model_id> <access_token>
#   model_id      repository id, used on both huggingface.co and modelscope.cn
#   access_token  token used as the oauth2 password when cloning from /
#                 pushing to www.modelscope.cn
#
# Both repositories are cloned into ./hf_model and ./ms_model, the Hugging Face
# content is synced over the ModelScope checkout, README.md and
# configuration.json are (re)generated, and the result is pushed.
set -e

MODEL_ID=$1
ACCESS_TOKEN=$2

usage() {
  echo "Usage: $0 <model_id> <access_token>"
  exit 1
}

# Both arguments are required: without a token the ModelScope clone URL would
# carry empty credentials.
if [ -z "${MODEL_ID}" ] || [ -z "${ACCESS_TOKEN}" ]; then
  usage
fi

git clone "https://huggingface.co/${MODEL_ID}" hf_model
git clone "https://oauth2:${ACCESS_TOKEN}@www.modelscope.cn/${MODEL_ID}.git" ms_model

echo "Sync directory"
# Copy everything except git metadata so ms_model keeps its own history.
rsync -a --exclude '.git' hf_model/ ms_model/

echo "Create README.md"
cat <<EOF >ms_model/README.md
---
license: other
tasks:
- text-generation
---

# ${MODEL_ID}

This is a mirror of [${MODEL_ID}](https://huggingface.co/${MODEL_ID}).
EOF

echo "Create configuration.json"
cat <<EOF >ms_model/configuration.json
{
    "framework": "pytorch",
    "task": "text-generation",
    "pipeline": {
        "type": "text-generation-pipeline"
    }
}
EOF

# Trace the git commands that publish the mirror.
set -x
cd ms_model
git add .
git commit -m "sync with upstream"
git push origin

echo "Success!"
rm -rf hf_model
rm -rf ms_model
|
||||
Loading…
Reference in New Issue