feat: improve download command - support local cache checking behavior (#178)

* move download.rs

* add metadata

* support prefer local args

* fix format

* replace errorchain with anyhow
support-coreml
Meng Zhang 2023-05-31 23:42:04 -07:00 committed by GitHub
parent 5aa2370e19
commit e8dbd36663
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 239 additions and 160 deletions

74
Cargo.lock generated
View File

@ -2,15 +2,6 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
version = "1.0.2"
@ -110,6 +101,12 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "anyhow"
version = "1.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
[[package]]
name = "async-trait"
version = "0.1.68"
@ -176,21 +173,6 @@ dependencies = [
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide 0.6.2",
"object",
"rustc-demangle",
]
[[package]]
name = "base64"
version = "0.13.1"
@ -706,16 +688,6 @@ dependencies = [
"libc",
]
[[package]]
name = "error-chain"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc"
dependencies = [
"backtrace",
"version_check",
]
[[package]]
name = "esaxx-rs"
version = "0.1.8"
@ -753,7 +725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743"
dependencies = [
"crc32fast",
"miniz_oxide 0.7.1",
"miniz_oxide",
]
[[package]]
@ -878,12 +850,6 @@ dependencies = [
"wasi 0.11.0+wasi-snapshot-preview1",
]
[[package]]
name = "gimli"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
[[package]]
name = "glob"
version = "0.3.1"
@ -1308,15 +1274,6 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
dependencies = [
"adler",
]
[[package]]
name = "miniz_oxide"
version = "0.7.1"
@ -1418,15 +1375,6 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
version = "0.30.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.17.1"
@ -1873,12 +1821,6 @@ dependencies = [
"walkdir",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustix"
version = "0.37.19"
@ -2175,11 +2117,11 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
name = "tabby"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"clap",
"ctranslate2-bindings",
"env_logger",
"error-chain",
"futures-util",
"hyper",
"indicatif 0.17.3",

View File

@ -23,11 +23,11 @@ lazy_static = { workspace = true }
rust-embed = "6.6.1"
mime_guess = "2.0.4"
strum = { version = "0.24", features = ["derive"] }
reqwest = { version = "0.11.18", features = ["stream"] }
error-chain = "0.12.4"
reqwest = { version = "0.11.18", features = ["stream", "json"] }
indicatif = "0.17.3"
futures-util = "0.3.28"
tabby-common = { path = "../tabby-common" }
anyhow = "1.0.71"
[dependencies.uuid]
version = "1.3.3"

View File

@ -1,92 +0,0 @@
use std::cmp;
use std::fs;
use std::io::Write;
use std::path::Path;
use clap::Args;
use error_chain::error_chain;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use tabby_common::path::ModelDir;
// Command-line arguments of the download command.
// The `///` comments on fields are consumed by clap as help text, so
// maintainer notes are written as plain `//` comments instead.
#[derive(Args)]
pub struct DownloadArgs {
/// model id to fetch.
// Used to build the huggingface.co download URLs and the local ModelDir.
#[clap(long)]
model: String,
}
// Declares this module's error types through the `error_chain!` macro.
// Each `foreign_links` entry names an external error type that the `?`
// operator is expected to convert into the module's error type.
error_chain! {
foreign_links {
Io(std::io::Error);
HttpRequest(reqwest::Error);
TemplateError(indicatif::style::TemplateError);
}
}
/// Entry point of the download command: fetches every file of `args.model`.
///
/// # Errors
/// Propagates the first error reported by `download_model` to the caller
/// instead of panicking, so the caller decides how to report it.
pub async fn main(args: &DownloadArgs) -> Result<()> {
    download_model(&args.model).await
}
/// Downloads the metadata of `model_id` and then each model artifact, one
/// after another, stopping at the first failure.
async fn download_model(model_id: &str) -> Result<()> {
    // Artifacts are fetched in exactly this order.
    const MODEL_FILES: [&str; 5] = [
        "tokenizer.json",
        "ctranslate2/config.json",
        "ctranslate2/vocabulary.txt",
        "ctranslate2/shared_vocabulary.txt",
        "ctranslate2/model.bin",
    ];

    download_metadata(model_id).await?;
    for fname in &MODEL_FILES {
        download_model_file(model_id, fname).await?;
    }
    Ok(())
}
/// Saves the Hugging Face API description of `model_id` to the metadata file
/// path reported by `ModelDir`.
async fn download_metadata(model_id: &str) -> Result<()> {
    let destination = ModelDir::new(model_id).metadata_file();
    let api_url = format!("https://huggingface.co/api/models/{}", model_id);
    download_file(&api_url, &destination).await
}
/// Downloads the file `fname` of `model_id` from the `main` revision on
/// huggingface.co into the matching location under the model directory.
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
    // Where the file ends up locally.
    let destination = ModelDir::new(model_id).path_string(fname);
    // Where the file is fetched from.
    let source_url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
    download_file(&source_url, &destination).await
}
/// Streams the body of `url` into the file at `path`, showing a progress bar.
///
/// Missing parent directories of `path` are created first.
///
/// # Errors
/// Fails when `path` has no parent directory, the directory or file cannot
/// be created, the request fails, the response carries no content length,
/// the progress-bar template is rejected, or a chunk cannot be read or
/// written.
async fn download_file(url: &str, path: &str) -> Result<()> {
    // Report a path without a parent as an error instead of panicking.
    let parent = Path::new(path)
        .parent()
        .ok_or_else(|| format!("Invalid destination path '{}'", path))?;
    fs::create_dir_all(parent)?;

    // Reqwest setup
    let res = reqwest::get(url)
        .await
        .map_err(|_| format!("Failed to GET from '{}'", url))?;
    let total_size = res
        .content_length()
        .ok_or_else(|| format!("Failed to get content length from '{}'", url))?;

    // Indicatif setup
    let pb = ProgressBar::new(total_size);
    pb.set_style(ProgressStyle::default_bar()
        .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
        .progress_chars("#>-"));
    pb.set_message(format!("Downloading {}", path));

    // download chunks
    let mut file =
        fs::File::create(path).map_err(|_| format!("Failed to create file '{}'", path))?;
    let mut downloaded: u64 = 0;
    let mut stream = res.bytes_stream();
    while let Some(item) = stream.next().await {
        let chunk = item.map_err(|_| "Error while downloading file")?;
        file.write_all(&chunk)
            .map_err(|_| "Error while writing to file")?;
        // Clamp so the bar never runs past the announced total.
        downloaded = cmp::min(downloaded + (chunk.len() as u64), total_size);
        pb.set_position(downloaded);
    }

    pb.finish_with_message(format!("Downloaded {}", path));
    Ok(())
}

View File

@ -0,0 +1,101 @@
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use tabby_common::path::ModelDir;
/// Type of `HFMetadata::transformers_info`; only the `auto_model` string is
/// deserialized from it.
#[derive(Deserialize)]
struct HFTransformersInfo {
/// Copied into `Metadata::auto_model` when metadata is built from the API.
auto_model: String,
}
/// The part of the JSON answer from
/// `https://huggingface.co/api/models/{model_id}` that this module reads.
///
/// `rename_all = "camelCase"` makes serde look up `transformers_info` under
/// the JSON key `transformersInfo`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct HFMetadata {
transformers_info: HFTransformersInfo,
}
impl HFMetadata {
    /// Fetches the model description of `model_id` from the Hugging Face API.
    ///
    /// # Errors
    /// Fails when the request cannot be completed, the server answers with a
    /// non-success status, or the body does not deserialize into `HFMetadata`.
    async fn from(model_id: &str) -> Result<HFMetadata> {
        let api_url = format!("https://huggingface.co/api/models/{}", model_id);
        let metadata = reqwest::get(api_url)
            .await?
            // Report 4xx/5xx answers (e.g. an unknown model id) as HTTP
            // errors rather than as a misleading JSON decoding failure.
            .error_for_status()?
            .json::<HFMetadata>()
            .await?;
        Ok(metadata)
    }
}
/// Per-model metadata that is read from and written to the model directory's
/// metadata file as JSON.
#[derive(Serialize, Deserialize)]
pub struct Metadata {
/// Taken from `transformers_info.auto_model` of the Hugging Face API answer.
auto_model: String,
/// Value recorded per download URL through `update_etag`.
etags: HashMap<String, String>,
}
impl Metadata {
pub async fn from(model_id: &str) -> Result<Metadata> {
if let Some(metadata) = Self::from_local(model_id) {
Ok(metadata)
} else {
let hf_metadata = HFMetadata::from(model_id).await?;
let metadata = Metadata {
auto_model: hf_metadata.transformers_info.auto_model,
etags: HashMap::new(),
};
Ok(metadata)
}
}
fn from_local(model_id: &str) -> Option<Metadata> {
let metadata_file = ModelDir::new(model_id).metadata_file();
if fs::metadata(&metadata_file).is_ok() {
let metadata = serdeconv::from_json_file(metadata_file);
metadata.ok()
} else {
None
}
}
pub fn has_etag(&self, url: &str) -> bool {
self.etags.get(url).is_some()
}
pub async fn match_etag(&self, url: &str, path: &str) -> Result<bool> {
let etag = self
.etags
.get(url)
.ok_or(anyhow!("Path doesn't exist: {}", path))?;
let etag_from_header = reqwest::get(url)
.await?
.headers()
.get("etag")
.ok_or(anyhow!("URL doesn't have etag header: '{}'", url))?
.to_str()?
.to_owned();
Ok(etag == &etag_from_header)
}
pub async fn update_etag(&mut self, url: &str, path: &str) {
self.etags.insert(url.to_owned(), path.to_owned());
}
pub fn save(&self, model_id: &str) -> Result<()> {
let metadata_file = ModelDir::new(model_id).metadata_file();
let metadata_file_path = Path::new(&metadata_file);
serdeconv::to_json_file(self, metadata_file_path)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Queries the live Hugging Face API, so it needs network access.
    #[tokio::test]
    async fn test_hf() {
        let fetched = HFMetadata::from("TabbyML/J-350M").await.unwrap();
        let auto_model = fetched.transformers_info.auto_model;
        assert_eq!(auto_model, "AutoModelForCausalLM");
    }
}

View File

@ -0,0 +1,128 @@
mod metadata;
use anyhow::{anyhow, Result};
use std::cmp;
use std::fs;
use std::io::Write;
use std::path::Path;
use clap::Args;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use tabby_common::path::ModelDir;
// Command-line arguments of the download command.
// The `///` comments on fields are consumed by clap as help text, so
// maintainer notes are written as plain `//` comments instead.
#[derive(Args)]
pub struct DownloadArgs {
    /// model id to fetch.
    #[clap(long)]
    model: String,

    /// If true, skip checking for remote model file.
    // A plain `bool` flag that defaults to `true` can only ever be set to
    // `true` again, which made the remote check unreachable from the CLI.
    // The value is therefore parsed explicitly: `--prefer-local-file false`
    // turns it off, while a bare `--prefer-local-file` still means `true`.
    #[clap(
        long,
        action = clap::ArgAction::Set,
        default_value_t = true,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    prefer_local_file: bool,
}
pub async fn main(args: &DownloadArgs) -> Result<()> {
download_model(&args.model, args.prefer_local_file).await?;
Ok(())
}
impl metadata::Metadata {
    /// Makes sure the model file `path` of `model_id` is present locally.
    ///
    /// The download is skipped when the local file exists, an etag is
    /// recorded for its URL, and either `prefer_local_file` is set or the
    /// recorded etag still matches the remote one. Otherwise the file is
    /// downloaded and its etag recorded.
    async fn download(
        &mut self,
        model_id: &str,
        path: &str,
        prefer_local_file: bool,
    ) -> Result<()> {
        let remote_url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
        let local_path = ModelDir::new(model_id).path_string(path);

        // The remote etag is only queried when a cached copy exists and the
        // caller did not ask to trust local files.
        let has_cached_copy = fs::metadata(&local_path).is_ok() && self.has_etag(&remote_url);
        if has_cached_copy && (prefer_local_file || self.match_etag(&remote_url, path).await?) {
            return Ok(());
        }

        let fresh_etag = download_file(&remote_url, &local_path).await?;
        self.update_etag(&remote_url, &fresh_etag).await;
        Ok(())
    }
}
/// Loads the metadata of `model_id`, makes sure each model artifact is
/// available locally (in order), and finally saves the metadata.
///
/// Stops at the first failure; the metadata is not saved in that case.
async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
    // Artifacts are processed in exactly this order.
    const MODEL_FILES: [&str; 5] = [
        "tokenizer.json",
        "ctranslate2/config.json",
        "ctranslate2/vocabulary.txt",
        "ctranslate2/shared_vocabulary.txt",
        "ctranslate2/model.bin",
    ];

    let mut metadata = metadata::Metadata::from(model_id).await?;
    for file in &MODEL_FILES {
        metadata.download(model_id, file, prefer_local_file).await?;
    }
    metadata.save(model_id)
}
async fn download_file(url: &str, path: &str) -> Result<String> {
fs::create_dir_all(Path::new(path).parent().unwrap())?;
// Reqwest setup
let res = reqwest::get(url)
.await
.or(Err(anyhow!("Failed to GET from '{}'", url)))?;
let etag = res
.headers()
.get("etag")
.ok_or(anyhow!("Failed to get etag from '{}", url))?
.to_str()?
.to_string();
let total_size = res
.content_length()
.ok_or(anyhow!("Failed to get content length from '{}'", url))?;
// Indicatif setup
let pb = ProgressBar::new(total_size);
pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message(format!("Downloading {}", path));
// download chunks
let mut file = fs::File::create(&path).or(Err(anyhow!("Failed to create file '{}'", &path)))?;
let mut downloaded: u64 = 0;
let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item.or(Err(anyhow!("Error while downloading file")))?;
file.write_all(&chunk)
.or(Err(anyhow!("Error while writing to file")))?;
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(format!("Downloaded {}", path));
return Ok(etag);
}