feat: support ModelScope for model registry downloading (#477)

* feat: update cache info file after each file got downloaded

* refactor: extract Downloader for model downloading logic

* refactor: extract HuggingFaceRegistry

* refactor: extract serde_json to workspace dependency

* feat: add ModelScopeRegistry

* refactor: extract registry to its sub dir.

* feat: add scripts to mirror hf model to modelscope
release-0.2
Meng Zhang 2023-09-26 11:52:11 -07:00 committed by GitHub
parent f75a50de02
commit d42942c379
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 375 additions and 128 deletions

73
Cargo.lock generated
View File

@ -39,6 +39,17 @@ dependencies = [
"version_check",
]
[[package]]
name = "ahash"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.20"
@ -57,6 +68,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
@ -342,6 +359,24 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "cached"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cead8ece0da6b744b2ad8ef9c58a4cdc7ef2921e60a6ddfb9eaaa86839b5fc5"
dependencies = [
"ahash 0.8.3",
"async-trait",
"cached_proc_macro",
"cached_proc_macro_types",
"futures",
"hashbrown 0.14.0",
"instant",
"once_cell",
"thiserror",
"tokio",
]
[[package]]
name = "cached-path"
version = "0.6.1"
@ -364,6 +399,24 @@ dependencies = [
"zip",
]
[[package]]
name = "cached_proc_macro"
version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8245dd5f576a41c3b76247b54c15b0e43139ceeb4f732033e15be7c005176"
dependencies = [
"darling 0.14.4",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "cached_proc_macro_types"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a4f925191b4367301851c6d99b09890311d74b0d43f274c0b34c86d308a3663"
[[package]]
name = "cc"
version = "1.0.79"
@ -1177,7 +1230,7 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [
"ahash",
"ahash 0.7.6",
]
[[package]]
@ -1185,6 +1238,10 @@ name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
dependencies = [
"ahash 0.8.3",
"allocator-api2",
]
[[package]]
name = "heck"
@ -2702,9 +2759,9 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.105"
version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
dependencies = [
"itoa",
"ryu",
@ -3011,14 +3068,18 @@ name = "tabby-download"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"cached",
"futures-util",
"indicatif 0.17.3",
"reqwest",
"serde",
"serde_json",
"serdeconv",
"tabby-common",
"tokio-retry",
"tracing",
"urlencoding",
]
[[package]]
@ -3875,6 +3936,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf8-ranges"
version = "1.0.5"

View File

@ -22,6 +22,7 @@ homepage = "https://github.com/TabbyML/tabby"
[workspace.dependencies]
lazy_static = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
serdeconv = "0.4.1"
tokio = "1.28"
tokio-util = "0.7"

View File

@ -7,7 +7,7 @@ edition = "2021"
async-trait.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.105"
serde_json = { workspace = true }
tabby-inference = { version = "0.1.0", path = "../tabby-inference" }
[dev-dependencies]

View File

@ -13,3 +13,7 @@ serde = { workspace = true }
serdeconv = { workspace = true }
tracing = { workspace = true }
tokio-retry = "0.3.0"
urlencoding = "2.1.3"
serde_json = { workspace = true }
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
async-trait = { workspace = true }

View File

@ -1,6 +1,6 @@
use std::{collections::HashMap, fs, path::Path};
use anyhow::{anyhow, Result};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tabby_common::path::ModelDir;
@ -33,15 +33,6 @@ impl CacheInfo {
self.etags.get(path).map(|x| x.as_str())
}
pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> {
let key = res
.headers()
.get("etag")
.ok_or(anyhow!("etag key missing"))?
.to_str()?;
Ok(key)
}
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
self.etags.insert(path.to_string(), cache_key.to_string());
}

View File

@ -1,4 +1,5 @@
mod cache_info;
mod registry;
use std::{cmp, fs, io::Write, path::Path};
@ -6,111 +7,114 @@ use anyhow::{anyhow, Result};
use cache_info::CacheInfo;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use registry::{create_registry, Registry};
use tabby_common::path::ModelDir;
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::info;
impl CacheInfo {
async fn download(
&mut self,
model_id: &str,
path: &str,
prefer_local_file: bool,
is_optional: bool,
) -> Result<()> {
// Create url.
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
pub struct Downloader {
model_id: String,
prefer_local_file: bool,
registry: Box<dyn Registry>,
}
// Get cache key.
let local_cache_key = self.local_cache_key(path);
impl Downloader {
pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
Self {
model_id: model_id.to_owned(),
prefer_local_file,
registry: create_registry(),
}
}
// Create destination path.
let filepath = ModelDir::new(model_id).path_string(path);
pub async fn download_ctranslate2_files(&self) -> Result<()> {
let files = vec![
("tabby.json", true),
("tokenizer.json", true),
("ctranslate2/vocabulary.txt", false),
("ctranslate2/shared_vocabulary.txt", false),
("ctranslate2/vocabulary.json", false),
("ctranslate2/shared_vocabulary.json", false),
("ctranslate2/config.json", true),
("ctranslate2/model.bin", true),
];
// Cache hit.
let local_file_ready = if prefer_local_file {
if let Some(local_cache_key) = local_cache_key {
if local_cache_key == "404" {
true
} else {
fs::metadata(&filepath).is_ok()
}
} else {
false
}
} else {
false
};
self.download_files(&files).await
}
if !local_file_ready {
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
let etag = Retry::spawn(strategy, || {
download_file(&url, &filepath, local_cache_key, is_optional)
})
pub async fn download_ggml_files(&self) -> Result<()> {
let files = vec![
("tabby.json", true),
("tokenizer.json", true),
("ggml/q8_0.gguf", true),
];
self.download_files(&files).await
}
async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
// Local path, no need for downloading.
if fs::metadata(&self.model_id).is_ok() {
return Ok(());
}
let mut cache_info = CacheInfo::from(&self.model_id).await;
for (path, required) in files {
download_model_file(
self.registry.as_ref(),
&mut cache_info,
&self.model_id,
path,
self.prefer_local_file,
*required,
)
.await?;
self.set_local_cache_key(path, &etag).await;
}
Ok(())
}
}
pub async fn download_model(
async fn download_model_file(
registry: &dyn Registry,
cache_info: &mut CacheInfo,
model_id: &str,
download_ctranslate2_files: bool,
download_ggml_files: bool,
path: &str,
prefer_local_file: bool,
required: bool,
) -> Result<()> {
if fs::metadata(model_id).is_ok() {
// Local path, no need for downloading.
return Ok(());
}
// Create url.
let url = registry.build_url(model_id, path);
info!("Start downloading model `{}`", model_id);
// Get cache key.
let local_cache_key = cache_info.local_cache_key(path);
let mut cache_info = CacheInfo::from(model_id).await;
// Create destination path.
let filepath = ModelDir::new(model_id).path_string(path);
let mut optional_files = vec![];
if download_ctranslate2_files {
optional_files.push("ctranslate2/vocabulary.txt");
optional_files.push("ctranslate2/shared_vocabulary.txt");
optional_files.push("ctranslate2/vocabulary.json");
optional_files.push("ctranslate2/shared_vocabulary.json");
}
// Cache hit.
let local_file_ready = if prefer_local_file {
if let Some(local_cache_key) = local_cache_key {
if local_cache_key == "404" {
true
} else {
fs::metadata(&filepath).is_ok()
}
} else {
false
}
} else {
false
};
if download_ggml_files {
optional_files.push("ggml/q8_0.gguf");
}
if !local_file_ready {
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
let etag = Retry::spawn(strategy, || {
download_file(registry, &url, &filepath, local_cache_key, !required)
})
.await?;
for path in optional_files {
cache_info
.download(
model_id,
path,
prefer_local_file,
/* is_optional */ true,
)
.await?;
}
let mut required_files = vec!["tabby.json", "tokenizer.json"];
if download_ctranslate2_files {
required_files.push("ctranslate2/config.json");
required_files.push("ctranslate2/model.bin");
}
for path in required_files {
cache_info
.download(
model_id,
path,
prefer_local_file,
/* required= */ false,
)
.await?;
cache_info.set_local_cache_key(path, &etag).await;
}
cache_info.save(model_id)?;
@ -118,6 +122,7 @@ pub async fn download_model(
}
async fn download_file(
registry: &dyn Registry,
url: &str,
path: &str,
local_cache_key: Option<&str>,
@ -137,7 +142,7 @@ async fn download_file(
return Err(anyhow!(format!("Invalid url: {}", url)));
}
let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string();
let remote_cache_key = registry.build_cache_key(url).await?;
if let Some(local_cache_key) = local_cache_key {
if local_cache_key == remote_cache_key {
return Ok(remote_cache_key);

View File

@ -0,0 +1,24 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use crate::Registry;
#[derive(Default)]
pub struct HuggingFaceRegistry {}
#[async_trait]
impl Registry for HuggingFaceRegistry {
fn build_url(&self, model_id: &str, path: &str) -> String {
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
}
async fn build_cache_key(&self, url: &str) -> Result<String> {
let res = reqwest::get(url).await?;
let cache_key = res
.headers()
.get("etag")
.ok_or(anyhow!("etag key missing"))?
.to_str()?;
Ok(cache_key.to_owned())
}
}

View File

@ -0,0 +1,25 @@
mod huggingface;
mod modelscope;
use anyhow::Result;
use async_trait::async_trait;
use huggingface::HuggingFaceRegistry;
use self::modelscope::ModelScopeRegistry;
#[async_trait]
pub trait Registry {
fn build_url(&self, model_id: &str, path: &str) -> String;
async fn build_cache_key(&self, url: &str) -> Result<String>;
}
pub fn create_registry() -> Box<dyn Registry> {
let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned());
if registry == "huggingface" {
Box::<HuggingFaceRegistry>::default()
} else if registry == "modelscope" {
Box::<ModelScopeRegistry>::default()
} else {
panic!("Unsupported registry {}", registry);
}
}

View File

@ -0,0 +1,79 @@
use std::collections::HashMap;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cached::proc_macro::cached;
use reqwest::Url;
use serde::Deserialize;
use crate::Registry;
#[derive(Default)]
pub struct ModelScopeRegistry {}
#[async_trait]
impl Registry for ModelScopeRegistry {
fn build_url(&self, model_id: &str, path: &str) -> String {
format!(
"https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
model_id,
urlencoding::encode(path)
)
}
async fn build_cache_key(&self, url: &str) -> Result<String> {
let url = Url::parse(url)?;
let model_id = url
.path()
.strip_prefix("/api/v1/models/")
.ok_or(anyhow!("Invalid url"))?
.strip_suffix("/repo")
.ok_or(anyhow!("Invalid url"))?;
let query: HashMap<_, _> = url.query_pairs().into_owned().collect();
let path = query
.get("FilePath")
.ok_or(anyhow!("Failed to extract FilePath"))?;
let revision_map = fetch_revision_map(model_id.to_owned()).await?;
for x in revision_map.data.files {
if x.path == *path {
return Ok(x.sha256);
}
}
Err(anyhow!("Failed to find {} in revisions", path))
}
}
#[cached(size = 1, result = true)]
async fn fetch_revision_map(model_id: String) -> Result<ModelScopeRevision> {
let url = format!(
"https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true",
model_id
);
let resp = reqwest::get(url)
.await?
.json::<ModelScopeRevision>()
.await?;
Ok(resp)
}
#[derive(Deserialize, Clone)]
#[serde(rename_all = "PascalCase")]
struct ModelScopeRevision {
data: ModelScopeRevisionData,
}
#[derive(Deserialize, Clone)]
#[serde(rename_all = "PascalCase")]
struct ModelScopeRevisionData {
files: Vec<ModelScopeRevisionFile>,
}
#[derive(Deserialize, Clone)]
#[serde(rename_all = "PascalCase")]
struct ModelScopeRevisionFile {
path: String,
sha256: String,
}

View File

@ -17,7 +17,7 @@ utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
serde = { workspace = true }
serdeconv = { workspace = true }
serde_json = "1.0"
serde_json = { workspace = true }
tower-http = { version = "0.4.0", features = ["cors"] }
clap = { version = "4.3.0", features = ["derive"] }
lazy_static = { workspace = true }

View File

@ -1,4 +1,5 @@
use clap::Args;
use tabby_download::Downloader;
use tracing::info;
use crate::fatal;
@ -15,19 +16,18 @@ pub struct DownloadArgs {
}
pub async fn main(args: &DownloadArgs) {
tabby_download::download_model(
&args.model,
/* download_ctranslate2_files= */ true,
/* download_ggml_files= */ true,
args.prefer_local_file,
)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
let downloader = Downloader::new(&args.model, args.prefer_local_file);
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
downloader
.download_ctranslate2_files()
.await
.unwrap_or_else(handler);
downloader
.download_ggml_files()
.await
.unwrap_or_else(handler);
info!("model '{}' is ready", args.model);
}

View File

@ -12,6 +12,7 @@ use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args;
use tabby_common::{config::Config, usage};
use tabby_download::Downloader;
use tokio::time::sleep;
use tower_http::cors::CorsLayer;
use tracing::{info, warn};
@ -136,25 +137,16 @@ fn should_download_ggml_files(device: &Device) -> bool {
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
let downloader = Downloader::new(&args.model, /* prefer_local_file= */ true);
if args.device != Device::ExperimentalHttp {
let download_ctranslate2_files = !should_download_ggml_files(&args.device);
let download_ggml_files = should_download_ggml_files(&args.device);
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err,);
let download_result = if should_download_ggml_files(&args.device) {
downloader.download_ggml_files().await
} else {
downloader.download_ctranslate2_files().await
};
// Ensure model exists.
tabby_download::download_model(
&args.model,
download_ctranslate2_files,
download_ggml_files,
/* prefer_local_file= */ true,
)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
download_result.unwrap_or_else(handler);
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
}

View File

@ -0,0 +1,2 @@
hf_model
ms_model

View File

@ -0,0 +1,3 @@
# copy-to-modelscope
Scripts to copy huggingface model to modelscope

View File

@ -0,0 +1,54 @@
#!/bin/bash
set -e
MODEL_ID=$1
ACCESS_TOKEN=$2
usage() {
echo "Usage: $0 <model_id> <access_token>"
exit 1
}
if [ -z "${MODEL_ID}" ]; then
usage
fi
git clone https://huggingface.co/$MODEL_ID hf_model
git clone https://oauth2:${ACCESS_TOKEN}@www.modelscope.cn/$MODEL_ID.git ms_model
echo "Sync directory"
rsync -a --exclude '.git' hf_model/ ms_model/
echo "Create README.md"
cat <<EOF >ms_model/README.md
---
license: other
tasks:
- text-generation
---
# ${MODEL_ID}
This is an mirror of [${MODEL_ID}](https://huggingface.co/${MODEL_ID}).
EOF
echo "Create configuration.json"
cat <<EOF >ms_model/configuration.json
{
"framework": "pytorch",
"task": "text-generation",
"pipeline": {
"type": "text-generation-pipeline"
}
}
EOF
set -x
cd ms_model
git add .
git commit -m "sync with upstream"
git push origin
echo "Success!"
rm -rf hf_model
rm -rf ms_model