feat: support ModelScope for model registry downloading (#477)
* feat: update cache info file after each file is downloaded * refactor: extract Downloader for model downloading logic * refactor: extract HuggingFaceRegistry * refactor: extract serde_json to workspace dependency * feat: add ModelScopeRegistry * refactor: extract registry to its sub dir. * feat: add scripts to mirror hf model to modelscope (branch: release-0.2)
parent
f75a50de02
commit
d42942c379
|
|
@ -39,6 +39,17 @@ dependencies = [
|
||||||
"version_check",
|
"version_check",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ahash"
|
||||||
|
version = "0.8.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"once_cell",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aho-corasick"
|
name = "aho-corasick"
|
||||||
version = "0.7.20"
|
version = "0.7.20"
|
||||||
|
|
@ -57,6 +68,12 @@ dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "allocator-api2"
|
||||||
|
version = "0.2.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "android-tzdata"
|
name = "android-tzdata"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
|
@ -342,6 +359,24 @@ dependencies = [
|
||||||
"pkg-config",
|
"pkg-config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cached"
|
||||||
|
version = "0.46.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8cead8ece0da6b744b2ad8ef9c58a4cdc7ef2921e60a6ddfb9eaaa86839b5fc5"
|
||||||
|
dependencies = [
|
||||||
|
"ahash 0.8.3",
|
||||||
|
"async-trait",
|
||||||
|
"cached_proc_macro",
|
||||||
|
"cached_proc_macro_types",
|
||||||
|
"futures",
|
||||||
|
"hashbrown 0.14.0",
|
||||||
|
"instant",
|
||||||
|
"once_cell",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cached-path"
|
name = "cached-path"
|
||||||
version = "0.6.1"
|
version = "0.6.1"
|
||||||
|
|
@ -364,6 +399,24 @@ dependencies = [
|
||||||
"zip",
|
"zip",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cached_proc_macro"
|
||||||
|
version = "0.18.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7da8245dd5f576a41c3b76247b54c15b0e43139ceeb4f732033e15be7c005176"
|
||||||
|
dependencies = [
|
||||||
|
"darling 0.14.4",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cached_proc_macro_types"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.0.79"
|
version = "1.0.79"
|
||||||
|
|
@ -1177,7 +1230,7 @@ version = "0.12.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ahash",
|
"ahash 0.7.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1185,6 +1238,10 @@ name = "hashbrown"
|
||||||
version = "0.14.0"
|
version = "0.14.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
|
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
|
||||||
|
dependencies = [
|
||||||
|
"ahash 0.8.3",
|
||||||
|
"allocator-api2",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "heck"
|
name = "heck"
|
||||||
|
|
@ -2702,9 +2759,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_json"
|
name = "serde_json"
|
||||||
version = "1.0.105"
|
version = "1.0.107"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
|
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"itoa",
|
"itoa",
|
||||||
"ryu",
|
"ryu",
|
||||||
|
|
@ -3011,14 +3068,18 @@ name = "tabby-download"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
|
"cached",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"indicatif 0.17.3",
|
"indicatif 0.17.3",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
|
"serde_json",
|
||||||
"serdeconv",
|
"serdeconv",
|
||||||
"tabby-common",
|
"tabby-common",
|
||||||
"tokio-retry",
|
"tokio-retry",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"urlencoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3875,6 +3936,12 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "urlencoding"
|
||||||
|
version = "2.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "utf8-ranges"
|
name = "utf8-ranges"
|
||||||
version = "1.0.5"
|
version = "1.0.5"
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ homepage = "https://github.com/TabbyML/tabby"
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
serdeconv = "0.4.1"
|
serdeconv = "0.4.1"
|
||||||
tokio = "1.28"
|
tokio = "1.28"
|
||||||
tokio-util = "0.7"
|
tokio-util = "0.7"
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ edition = "2021"
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
reqwest = { workspace = true, features = ["json"] }
|
reqwest = { workspace = true, features = ["json"] }
|
||||||
serde = { workspace = true, features = ["derive"] }
|
serde = { workspace = true, features = ["derive"] }
|
||||||
serde_json = "1.0.105"
|
serde_json = { workspace = true }
|
||||||
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
|
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
|
||||||
|
|
@ -13,3 +13,7 @@ serde = { workspace = true }
|
||||||
serdeconv = { workspace = true }
|
serdeconv = { workspace = true }
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tokio-retry = "0.3.0"
|
tokio-retry = "0.3.0"
|
||||||
|
urlencoding = "2.1.3"
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
|
||||||
|
async-trait = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use std::{collections::HashMap, fs, path::Path};
|
use std::{collections::HashMap, fs, path::Path};
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
|
|
||||||
|
|
@ -33,15 +33,6 @@ impl CacheInfo {
|
||||||
self.etags.get(path).map(|x| x.as_str())
|
self.etags.get(path).map(|x| x.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> {
|
|
||||||
let key = res
|
|
||||||
.headers()
|
|
||||||
.get("etag")
|
|
||||||
.ok_or(anyhow!("etag key missing"))?
|
|
||||||
.to_str()?;
|
|
||||||
Ok(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||||
self.etags.insert(path.to_string(), cache_key.to_string());
|
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
mod cache_info;
|
mod cache_info;
|
||||||
|
mod registry;
|
||||||
|
|
||||||
use std::{cmp, fs, io::Write, path::Path};
|
use std::{cmp, fs, io::Write, path::Path};
|
||||||
|
|
||||||
|
|
@ -6,111 +7,114 @@ use anyhow::{anyhow, Result};
|
||||||
use cache_info::CacheInfo;
|
use cache_info::CacheInfo;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressStyle};
|
||||||
|
use registry::{create_registry, Registry};
|
||||||
use tabby_common::path::ModelDir;
|
use tabby_common::path::ModelDir;
|
||||||
use tokio_retry::{
|
use tokio_retry::{
|
||||||
strategy::{jitter, ExponentialBackoff},
|
strategy::{jitter, ExponentialBackoff},
|
||||||
Retry,
|
Retry,
|
||||||
};
|
};
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
impl CacheInfo {
|
pub struct Downloader {
|
||||||
async fn download(
|
model_id: String,
|
||||||
&mut self,
|
prefer_local_file: bool,
|
||||||
model_id: &str,
|
registry: Box<dyn Registry>,
|
||||||
path: &str,
|
}
|
||||||
prefer_local_file: bool,
|
|
||||||
is_optional: bool,
|
|
||||||
) -> Result<()> {
|
|
||||||
// Create url.
|
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
|
||||||
|
|
||||||
// Get cache key.
|
impl Downloader {
|
||||||
let local_cache_key = self.local_cache_key(path);
|
pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
model_id: model_id.to_owned(),
|
||||||
|
prefer_local_file,
|
||||||
|
registry: create_registry(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create destination path.
|
pub async fn download_ctranslate2_files(&self) -> Result<()> {
|
||||||
let filepath = ModelDir::new(model_id).path_string(path);
|
let files = vec![
|
||||||
|
("tabby.json", true),
|
||||||
|
("tokenizer.json", true),
|
||||||
|
("ctranslate2/vocabulary.txt", false),
|
||||||
|
("ctranslate2/shared_vocabulary.txt", false),
|
||||||
|
("ctranslate2/vocabulary.json", false),
|
||||||
|
("ctranslate2/shared_vocabulary.json", false),
|
||||||
|
("ctranslate2/config.json", true),
|
||||||
|
("ctranslate2/model.bin", true),
|
||||||
|
];
|
||||||
|
|
||||||
// Cache hit.
|
self.download_files(&files).await
|
||||||
let local_file_ready = if prefer_local_file {
|
}
|
||||||
if let Some(local_cache_key) = local_cache_key {
|
|
||||||
if local_cache_key == "404" {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
fs::metadata(&filepath).is_ok()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
if !local_file_ready {
|
pub async fn download_ggml_files(&self) -> Result<()> {
|
||||||
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
let files = vec![
|
||||||
let etag = Retry::spawn(strategy, || {
|
("tabby.json", true),
|
||||||
download_file(&url, &filepath, local_cache_key, is_optional)
|
("tokenizer.json", true),
|
||||||
})
|
("ggml/q8_0.gguf", true),
|
||||||
|
];
|
||||||
|
self.download_files(&files).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
|
||||||
|
// Local path, no need for downloading.
|
||||||
|
if fs::metadata(&self.model_id).is_ok() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut cache_info = CacheInfo::from(&self.model_id).await;
|
||||||
|
for (path, required) in files {
|
||||||
|
download_model_file(
|
||||||
|
self.registry.as_ref(),
|
||||||
|
&mut cache_info,
|
||||||
|
&self.model_id,
|
||||||
|
path,
|
||||||
|
self.prefer_local_file,
|
||||||
|
*required,
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
self.set_local_cache_key(path, &etag).await;
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn download_model(
|
async fn download_model_file(
|
||||||
|
registry: &dyn Registry,
|
||||||
|
cache_info: &mut CacheInfo,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
download_ctranslate2_files: bool,
|
path: &str,
|
||||||
download_ggml_files: bool,
|
|
||||||
prefer_local_file: bool,
|
prefer_local_file: bool,
|
||||||
|
required: bool,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if fs::metadata(model_id).is_ok() {
|
// Create url.
|
||||||
// Local path, no need for downloading.
|
let url = registry.build_url(model_id, path);
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("Start downloading model `{}`", model_id);
|
// Get cache key.
|
||||||
|
let local_cache_key = cache_info.local_cache_key(path);
|
||||||
|
|
||||||
let mut cache_info = CacheInfo::from(model_id).await;
|
// Create destination path.
|
||||||
|
let filepath = ModelDir::new(model_id).path_string(path);
|
||||||
|
|
||||||
let mut optional_files = vec![];
|
// Cache hit.
|
||||||
if download_ctranslate2_files {
|
let local_file_ready = if prefer_local_file {
|
||||||
optional_files.push("ctranslate2/vocabulary.txt");
|
if let Some(local_cache_key) = local_cache_key {
|
||||||
optional_files.push("ctranslate2/shared_vocabulary.txt");
|
if local_cache_key == "404" {
|
||||||
optional_files.push("ctranslate2/vocabulary.json");
|
true
|
||||||
optional_files.push("ctranslate2/shared_vocabulary.json");
|
} else {
|
||||||
}
|
fs::metadata(&filepath).is_ok()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
if download_ggml_files {
|
if !local_file_ready {
|
||||||
optional_files.push("ggml/q8_0.gguf");
|
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
|
||||||
}
|
let etag = Retry::spawn(strategy, || {
|
||||||
|
download_file(registry, &url, &filepath, local_cache_key, !required)
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
for path in optional_files {
|
cache_info.set_local_cache_key(path, &etag).await;
|
||||||
cache_info
|
|
||||||
.download(
|
|
||||||
model_id,
|
|
||||||
path,
|
|
||||||
prefer_local_file,
|
|
||||||
/* is_optional */ true,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut required_files = vec!["tabby.json", "tokenizer.json"];
|
|
||||||
|
|
||||||
if download_ctranslate2_files {
|
|
||||||
required_files.push("ctranslate2/config.json");
|
|
||||||
required_files.push("ctranslate2/model.bin");
|
|
||||||
}
|
|
||||||
|
|
||||||
for path in required_files {
|
|
||||||
cache_info
|
|
||||||
.download(
|
|
||||||
model_id,
|
|
||||||
path,
|
|
||||||
prefer_local_file,
|
|
||||||
/* required= */ false,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cache_info.save(model_id)?;
|
cache_info.save(model_id)?;
|
||||||
|
|
@ -118,6 +122,7 @@ pub async fn download_model(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file(
|
async fn download_file(
|
||||||
|
registry: &dyn Registry,
|
||||||
url: &str,
|
url: &str,
|
||||||
path: &str,
|
path: &str,
|
||||||
local_cache_key: Option<&str>,
|
local_cache_key: Option<&str>,
|
||||||
|
|
@ -137,7 +142,7 @@ async fn download_file(
|
||||||
return Err(anyhow!(format!("Invalid url: {}", url)));
|
return Err(anyhow!(format!("Invalid url: {}", url)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string();
|
let remote_cache_key = registry.build_cache_key(url).await?;
|
||||||
if let Some(local_cache_key) = local_cache_key {
|
if let Some(local_cache_key) = local_cache_key {
|
||||||
if local_cache_key == remote_cache_key {
|
if local_cache_key == remote_cache_key {
|
||||||
return Ok(remote_cache_key);
|
return Ok(remote_cache_key);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
use crate::Registry;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct HuggingFaceRegistry {}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Registry for HuggingFaceRegistry {
|
||||||
|
fn build_url(&self, model_id: &str, path: &str) -> String {
|
||||||
|
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
||||||
|
let res = reqwest::get(url).await?;
|
||||||
|
let cache_key = res
|
||||||
|
.headers()
|
||||||
|
.get("etag")
|
||||||
|
.ok_or(anyhow!("etag key missing"))?
|
||||||
|
.to_str()?;
|
||||||
|
Ok(cache_key.to_owned())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
mod huggingface;
|
||||||
|
mod modelscope;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use huggingface::HuggingFaceRegistry;
|
||||||
|
|
||||||
|
use self::modelscope::ModelScopeRegistry;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Registry {
|
||||||
|
fn build_url(&self, model_id: &str, path: &str) -> String;
|
||||||
|
async fn build_cache_key(&self, url: &str) -> Result<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_registry() -> Box<dyn Registry> {
|
||||||
|
let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned());
|
||||||
|
if registry == "huggingface" {
|
||||||
|
Box::<HuggingFaceRegistry>::default()
|
||||||
|
} else if registry == "modelscope" {
|
||||||
|
Box::<ModelScopeRegistry>::default()
|
||||||
|
} else {
|
||||||
|
panic!("Unsupported registry {}", registry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use cached::proc_macro::cached;
|
||||||
|
use reqwest::Url;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use crate::Registry;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ModelScopeRegistry {}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Registry for ModelScopeRegistry {
|
||||||
|
fn build_url(&self, model_id: &str, path: &str) -> String {
|
||||||
|
format!(
|
||||||
|
"https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
|
||||||
|
model_id,
|
||||||
|
urlencoding::encode(path)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
||||||
|
let url = Url::parse(url)?;
|
||||||
|
let model_id = url
|
||||||
|
.path()
|
||||||
|
.strip_prefix("/api/v1/models/")
|
||||||
|
.ok_or(anyhow!("Invalid url"))?
|
||||||
|
.strip_suffix("/repo")
|
||||||
|
.ok_or(anyhow!("Invalid url"))?;
|
||||||
|
|
||||||
|
let query: HashMap<_, _> = url.query_pairs().into_owned().collect();
|
||||||
|
let path = query
|
||||||
|
.get("FilePath")
|
||||||
|
.ok_or(anyhow!("Failed to extract FilePath"))?;
|
||||||
|
|
||||||
|
let revision_map = fetch_revision_map(model_id.to_owned()).await?;
|
||||||
|
for x in revision_map.data.files {
|
||||||
|
if x.path == *path {
|
||||||
|
return Ok(x.sha256);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow!("Failed to find {} in revisions", path))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cached(size = 1, result = true)]
|
||||||
|
async fn fetch_revision_map(model_id: String) -> Result<ModelScopeRevision> {
|
||||||
|
let url = format!(
|
||||||
|
"https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true",
|
||||||
|
model_id
|
||||||
|
);
|
||||||
|
let resp = reqwest::get(url)
|
||||||
|
.await?
|
||||||
|
.json::<ModelScopeRevision>()
|
||||||
|
.await?;
|
||||||
|
Ok(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone)]
|
||||||
|
#[serde(rename_all = "PascalCase")]
|
||||||
|
struct ModelScopeRevision {
|
||||||
|
data: ModelScopeRevisionData,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone)]
|
||||||
|
#[serde(rename_all = "PascalCase")]
|
||||||
|
struct ModelScopeRevisionData {
|
||||||
|
files: Vec<ModelScopeRevisionFile>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Clone)]
|
||||||
|
#[serde(rename_all = "PascalCase")]
|
||||||
|
struct ModelScopeRevisionFile {
|
||||||
|
path: String,
|
||||||
|
sha256: String,
|
||||||
|
}
|
||||||
|
|
@ -17,7 +17,7 @@ utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
||||||
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serdeconv = { workspace = true }
|
serdeconv = { workspace = true }
|
||||||
serde_json = "1.0"
|
serde_json = { workspace = true }
|
||||||
tower-http = { version = "0.4.0", features = ["cors"] }
|
tower-http = { version = "0.4.0", features = ["cors"] }
|
||||||
clap = { version = "4.3.0", features = ["derive"] }
|
clap = { version = "4.3.0", features = ["derive"] }
|
||||||
lazy_static = { workspace = true }
|
lazy_static = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
|
use tabby_download::Downloader;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
use crate::fatal;
|
use crate::fatal;
|
||||||
|
|
@ -15,19 +16,18 @@ pub struct DownloadArgs {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &DownloadArgs) {
|
pub async fn main(args: &DownloadArgs) {
|
||||||
tabby_download::download_model(
|
let downloader = Downloader::new(&args.model, args.prefer_local_file);
|
||||||
&args.model,
|
|
||||||
/* download_ctranslate2_files= */ true,
|
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
|
||||||
/* download_ggml_files= */ true,
|
|
||||||
args.prefer_local_file,
|
downloader
|
||||||
)
|
.download_ctranslate2_files()
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|err| {
|
.unwrap_or_else(handler);
|
||||||
fatal!(
|
downloader
|
||||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
.download_ggml_files()
|
||||||
err,
|
.await
|
||||||
args.model
|
.unwrap_or_else(handler);
|
||||||
)
|
|
||||||
});
|
|
||||||
info!("model '{}' is ready", args.model);
|
info!("model '{}' is ready", args.model);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ use axum::{routing, Router, Server};
|
||||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
use tabby_common::{config::Config, usage};
|
use tabby_common::{config::Config, usage};
|
||||||
|
use tabby_download::Downloader;
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
@ -136,25 +137,16 @@ fn should_download_ggml_files(device: &Device) -> bool {
|
||||||
pub async fn main(config: &Config, args: &ServeArgs) {
|
pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
valid_args(args);
|
valid_args(args);
|
||||||
|
|
||||||
|
let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true);
|
||||||
if args.device != Device::ExperimentalHttp {
|
if args.device != Device::ExperimentalHttp {
|
||||||
let download_ctranslate2_files = !should_download_ggml_files(&args.device);
|
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
|
||||||
let download_ggml_files = should_download_ggml_files(&args.device);
|
let download_result = if should_download_ggml_files(&args.device) {
|
||||||
|
downloader.download_ggml_files().await
|
||||||
|
} else {
|
||||||
|
downloader.download_ctranslate2_files().await
|
||||||
|
};
|
||||||
|
|
||||||
// Ensure model exists.
|
download_result.unwrap_or_else(handler);
|
||||||
tabby_download::download_model(
|
|
||||||
&args.model,
|
|
||||||
download_ctranslate2_files,
|
|
||||||
download_ggml_files,
|
|
||||||
/* prefer_local_file= */ true,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap_or_else(|err| {
|
|
||||||
fatal!(
|
|
||||||
"Failed to fetch model due to '{}', is '{}' a valid model id?",
|
|
||||||
err,
|
|
||||||
args.model
|
|
||||||
)
|
|
||||||
});
|
|
||||||
} else {
|
} else {
|
||||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
hf_model
|
||||||
|
ms_model
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
# copy-to-modelscope
|
||||||
|
|
||||||
|
Scripts to copy a Hugging Face model to ModelScope.
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
#!/bin/bash
#
# Mirror a Hugging Face model repository to ModelScope.
#
# Usage: <script> <model_id> <access_token>
#   model_id      repository id, used in both the huggingface.co and the
#                 modelscope.cn clone URLs
#   access_token  token used as the oauth2 password for the ModelScope remote
#
# The two repositories are cloned into ./hf_model and ./ms_model (relative to
# the directory the script is started from); both are removed on exit.
set -euo pipefail

MODEL_ID=${1:-}
ACCESS_TOKEN=${2:-}

usage() {
  echo "Usage: $0 <model_id> <access_token>" >&2
  exit 1
}

# Both arguments are required: without the token the ModelScope clone and
# push cannot authenticate.
if [ -z "${MODEL_ID}" ] || [ -z "${ACCESS_TOKEN}" ]; then
  usage
fi

# Absolute paths, so cleanup still works after the `cd` further down.
WORK_DIR=$(pwd)
readonly HF_DIR="${WORK_DIR}/hf_model"
readonly MS_DIR="${WORK_DIR}/ms_model"

# Refuse to touch directories this run did not create; the cleanup trap is
# only installed once we know both paths are ours.
for dir in "${HF_DIR}" "${MS_DIR}"; do
  if [ -e "${dir}" ]; then
    echo "Error: ${dir} already exists, remove it first" >&2
    exit 1
  fi
done

cleanup() {
  rm -rf -- "${HF_DIR}" "${MS_DIR}"
}
trap cleanup EXIT

git clone "https://huggingface.co/${MODEL_ID}" "${HF_DIR}"
git clone "https://oauth2:${ACCESS_TOKEN}@www.modelscope.cn/${MODEL_ID}.git" "${MS_DIR}"

echo "Sync directory"
rsync -a --exclude '.git' "${HF_DIR}/" "${MS_DIR}/"

echo "Create README.md"
cat <<EOF >"${MS_DIR}/README.md"
---
license: other
tasks:
- text-generation
---

# ${MODEL_ID}

This is a mirror of [${MODEL_ID}](https://huggingface.co/${MODEL_ID}).
EOF

echo "Create configuration.json"
cat <<'EOF' >"${MS_DIR}/configuration.json"
{
    "framework": "pytorch",
    "task": "text-generation",
    "pipeline": {
        "type": "text-generation-pipeline"
    }
}
EOF

# Tracing starts only here, after the clone commands that carry the token.
set -x
cd "${MS_DIR}"
git add .
# `git commit` fails when nothing is staged; treat an already up-to-date
# mirror as success instead of aborting under `set -e`.
if git diff --cached --quiet; then
  set +x
  echo "Nothing to sync, mirror is already up to date"
else
  git commit -m "sync with upstream"
  git push origin
  set +x
fi

echo "Success!"
|
||||||
Loading…
Reference in New Issue