feat: improve download command - support local cache checking behavior (#178)
* move download.rs * add metadata * support prefer local args * fix format * replace errorchain with anyhowsupport-coreml
parent
5aa2370e19
commit
e8dbd36663
|
|
@ -2,15 +2,6 @@
|
|||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
|
||||
dependencies = [
|
||||
"gimli",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
|
|
@ -110,6 +101,12 @@ dependencies = [
|
|||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.71"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.68"
|
||||
|
|
@ -176,21 +173,6 @@ dependencies = [
|
|||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.67"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
|
||||
dependencies = [
|
||||
"addr2line",
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"miniz_oxide 0.6.2",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
|
|
@ -706,16 +688,6 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "error-chain"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "esaxx-rs"
|
||||
version = "0.1.8"
|
||||
|
|
@ -753,7 +725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide 0.7.1",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -878,12 +850,6 @@ dependencies = [
|
|||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
|
|
@ -1308,15 +1274,6 @@ version = "0.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
|
||||
dependencies = [
|
||||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.7.1"
|
||||
|
|
@ -1418,15 +1375,6 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.30.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.17.1"
|
||||
|
|
@ -1873,12 +1821,6 @@ dependencies = [
|
|||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.37.19"
|
||||
|
|
@ -2175,11 +2117,11 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
|||
name = "tabby"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"clap",
|
||||
"ctranslate2-bindings",
|
||||
"env_logger",
|
||||
"error-chain",
|
||||
"futures-util",
|
||||
"hyper",
|
||||
"indicatif 0.17.3",
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ lazy_static = { workspace = true }
|
|||
rust-embed = "6.6.1"
|
||||
mime_guess = "2.0.4"
|
||||
strum = { version = "0.24", features = ["derive"] }
|
||||
reqwest = { version = "0.11.18", features = ["stream"] }
|
||||
error-chain = "0.12.4"
|
||||
reqwest = { version = "0.11.18", features = ["stream", "json"] }
|
||||
indicatif = "0.17.3"
|
||||
futures-util = "0.3.28"
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
anyhow = "1.0.71"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -1,92 +0,0 @@
|
|||
use std::cmp;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
use clap::Args;
|
||||
use error_chain::error_chain;
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct DownloadArgs {
|
||||
/// model id to fetch.
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
}
|
||||
|
||||
error_chain! {
|
||||
foreign_links {
|
||||
Io(std::io::Error);
|
||||
HttpRequest(reqwest::Error);
|
||||
TemplateError(indicatif::style::TemplateError);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn main(args: &DownloadArgs) -> Result<()> {
|
||||
download_model(&args.model).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download_model(model_id: &str) -> Result<()> {
|
||||
download_metadata(model_id).await?;
|
||||
download_model_file(model_id, "tokenizer.json").await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/config.json")).await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/vocabulary.txt")).await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/shared_vocabulary.txt")).await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/model.bin")).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download_metadata(model_id: &str) -> Result<()> {
|
||||
let url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||
let filepath = ModelDir::new(model_id).metadata_file();
|
||||
download_file(&url, &filepath).await
|
||||
}
|
||||
|
||||
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
||||
// Create url.
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
|
||||
|
||||
// Create destination path.
|
||||
let filepath = ModelDir::new(model_id).path_string(fname);
|
||||
download_file(&url, &filepath).await
|
||||
}
|
||||
|
||||
async fn download_file(url: &str, path: &str) -> Result<()> {
|
||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||
|
||||
// Reqwest setup
|
||||
let res = reqwest::get(url)
|
||||
.await
|
||||
.or(Err(format!("Failed to GET from '{}'", url)))?;
|
||||
|
||||
let total_size = res
|
||||
.content_length()
|
||||
.ok_or(format!("Failed to get content length from '{}'", url))?;
|
||||
|
||||
// Indicatif setup
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", path));
|
||||
|
||||
// download chunks
|
||||
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut stream = res.bytes_stream();
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
let chunk = item.or(Err(format!("Error while downloading file")))?;
|
||||
file.write_all(&chunk)
|
||||
.or(Err(format!("Error while writing to file")))?;
|
||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
||||
downloaded = new;
|
||||
pb.set_position(new);
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!("Downloaded {}", path));
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct HFTransformersInfo {
|
||||
auto_model: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct HFMetadata {
|
||||
transformers_info: HFTransformersInfo,
|
||||
}
|
||||
|
||||
impl HFMetadata {
|
||||
async fn from(model_id: &str) -> Result<HFMetadata> {
|
||||
let api_url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||
let metadata = reqwest::get(api_url).await?.json::<HFMetadata>().await?;
|
||||
Ok(metadata)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Metadata {
|
||||
auto_model: String,
|
||||
etags: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Metadata {
|
||||
pub async fn from(model_id: &str) -> Result<Metadata> {
|
||||
if let Some(metadata) = Self::from_local(model_id) {
|
||||
Ok(metadata)
|
||||
} else {
|
||||
let hf_metadata = HFMetadata::from(model_id).await?;
|
||||
let metadata = Metadata {
|
||||
auto_model: hf_metadata.transformers_info.auto_model,
|
||||
etags: HashMap::new(),
|
||||
};
|
||||
Ok(metadata)
|
||||
}
|
||||
}
|
||||
|
||||
fn from_local(model_id: &str) -> Option<Metadata> {
|
||||
let metadata_file = ModelDir::new(model_id).metadata_file();
|
||||
if fs::metadata(&metadata_file).is_ok() {
|
||||
let metadata = serdeconv::from_json_file(metadata_file);
|
||||
metadata.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_etag(&self, url: &str) -> bool {
|
||||
self.etags.get(url).is_some()
|
||||
}
|
||||
|
||||
pub async fn match_etag(&self, url: &str, path: &str) -> Result<bool> {
|
||||
let etag = self
|
||||
.etags
|
||||
.get(url)
|
||||
.ok_or(anyhow!("Path doesn't exist: {}", path))?;
|
||||
let etag_from_header = reqwest::get(url)
|
||||
.await?
|
||||
.headers()
|
||||
.get("etag")
|
||||
.ok_or(anyhow!("URL doesn't have etag header: '{}'", url))?
|
||||
.to_str()?
|
||||
.to_owned();
|
||||
|
||||
Ok(etag == &etag_from_header)
|
||||
}
|
||||
|
||||
pub async fn update_etag(&mut self, url: &str, path: &str) {
|
||||
self.etags.insert(url.to_owned(), path.to_owned());
|
||||
}
|
||||
|
||||
pub fn save(&self, model_id: &str) -> Result<()> {
|
||||
let metadata_file = ModelDir::new(model_id).metadata_file();
|
||||
let metadata_file_path = Path::new(&metadata_file);
|
||||
serdeconv::to_json_file(self, metadata_file_path)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hf() {
|
||||
let hf_metadata = HFMetadata::from("TabbyML/J-350M").await.unwrap();
|
||||
assert_eq!(
|
||||
hf_metadata.transformers_info.auto_model,
|
||||
"AutoModelForCausalLM"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
mod metadata;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::cmp;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
|
||||
use clap::Args;
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct DownloadArgs {
|
||||
/// model id to fetch.
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
|
||||
/// If true, skip checking for remote model file.
|
||||
#[clap(long, default_value_t = true)]
|
||||
prefer_local_file: bool,
|
||||
}
|
||||
|
||||
pub async fn main(args: &DownloadArgs) -> Result<()> {
|
||||
download_model(&args.model, args.prefer_local_file).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl metadata::Metadata {
|
||||
async fn download(
|
||||
&mut self,
|
||||
model_id: &str,
|
||||
path: &str,
|
||||
prefer_local_file: bool,
|
||||
) -> Result<()> {
|
||||
// Create url.
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||
|
||||
// Create destination path.
|
||||
let filepath = ModelDir::new(model_id).path_string(path);
|
||||
|
||||
// Cache hit.
|
||||
let mut cache_hit = false;
|
||||
if fs::metadata(&filepath).is_ok() && self.has_etag(&url) {
|
||||
if prefer_local_file || self.match_etag(&url, path).await? {
|
||||
cache_hit = true
|
||||
}
|
||||
}
|
||||
|
||||
if !cache_hit {
|
||||
let etag = download_file(&url, &filepath).await?;
|
||||
self.update_etag(&url, &etag).await
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
|
||||
let mut metadata = metadata::Metadata::from(model_id).await?;
|
||||
|
||||
metadata
|
||||
.download(model_id, "tokenizer.json", prefer_local_file)
|
||||
.await?;
|
||||
metadata
|
||||
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
||||
.await?;
|
||||
metadata
|
||||
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
||||
.await?;
|
||||
metadata
|
||||
.download(
|
||||
model_id,
|
||||
"ctranslate2/shared_vocabulary.txt",
|
||||
prefer_local_file,
|
||||
)
|
||||
.await?;
|
||||
metadata
|
||||
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
||||
.await?;
|
||||
metadata.save(model_id)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download_file(url: &str, path: &str) -> Result<String> {
|
||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||
|
||||
// Reqwest setup
|
||||
let res = reqwest::get(url)
|
||||
.await
|
||||
.or(Err(anyhow!("Failed to GET from '{}'", url)))?;
|
||||
|
||||
let etag = res
|
||||
.headers()
|
||||
.get("etag")
|
||||
.ok_or(anyhow!("Failed to get etag from '{}", url))?
|
||||
.to_str()?
|
||||
.to_string();
|
||||
|
||||
let total_size = res
|
||||
.content_length()
|
||||
.ok_or(anyhow!("Failed to get content length from '{}'", url))?;
|
||||
|
||||
// Indicatif setup
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", path));
|
||||
|
||||
// download chunks
|
||||
let mut file = fs::File::create(&path).or(Err(anyhow!("Failed to create file '{}'", &path)))?;
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut stream = res.bytes_stream();
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
let chunk = item.or(Err(anyhow!("Error while downloading file")))?;
|
||||
file.write_all(&chunk)
|
||||
.or(Err(anyhow!("Error while writing to file")))?;
|
||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
||||
downloaded = new;
|
||||
pb.set_position(new);
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!("Downloaded {}", path));
|
||||
return Ok(etag);
|
||||
}
|
||||
Loading…
Reference in New Issue