diff --git a/Cargo.lock b/Cargo.lock index f88644c..4018898 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97" -dependencies = [ - "gimli", -] - [[package]] name = "adler" version = "1.0.2" @@ -110,6 +101,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + [[package]] name = "async-trait" version = "0.1.68" @@ -176,21 +173,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "backtrace" -version = "0.3.67" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide 0.6.2", - "object", - "rustc-demangle", -] - [[package]] name = "base64" version = "0.13.1" @@ -706,16 +688,6 @@ dependencies = [ "libc", ] -[[package]] -name = "error-chain" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" -dependencies = [ - "backtrace", - "version_check", -] - [[package]] name = "esaxx-rs" version = "0.1.8" @@ -753,7 +725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" dependencies = [ "crc32fast", - "miniz_oxide 0.7.1", + "miniz_oxide", ] [[package]] @@ -878,12 +850,6 @@ dependencies = [ "wasi 0.11.0+wasi-snapshot-preview1", ] -[[package]] -name = "gimli" -version = "0.27.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" - [[package]] name = "glob" version = "0.3.1" @@ -1308,15 +1274,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "miniz_oxide" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" -dependencies = [ - "adler", -] - [[package]] name = "miniz_oxide" version = "0.7.1" @@ -1418,15 +1375,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" -[[package]] -name = "object" -version = "0.30.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.17.1" @@ -1873,12 +1821,6 @@ dependencies = [ "walkdir", ] -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - [[package]] name = "rustix" version = "0.37.19" @@ -2175,11 +2117,11 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" name = "tabby" version = "0.1.0" dependencies = [ + "anyhow", "axum", "clap", "ctranslate2-bindings", "env_logger", - "error-chain", "futures-util", "hyper", "indicatif 0.17.3", diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 3d2dd74..26d25e3 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -23,11 +23,11 @@ lazy_static = { workspace = true } rust-embed = "6.6.1" mime_guess = "2.0.4" strum = { version = "0.24", features = ["derive"] } -reqwest = { version = "0.11.18", features = ["stream"] } 
-error-chain = "0.12.4"
+reqwest = { version = "0.11.18", features = ["stream", "json"] }
 indicatif = "0.17.3"
 futures-util = "0.3.28"
 tabby-common = { path = "../tabby-common" }
+anyhow = "1.0.71"
 
 [dependencies.uuid]
 version = "1.3.3"
diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs
deleted file mode 100644
index c458a44..0000000
--- a/crates/tabby/src/download.rs
+++ /dev/null
@@ -1,92 +0,0 @@
-use std::cmp;
-use std::fs;
-use std::io::Write;
-use std::path::Path;
-
-use clap::Args;
-use error_chain::error_chain;
-use futures_util::StreamExt;
-use indicatif::{ProgressBar, ProgressStyle};
-use tabby_common::path::ModelDir;
-
-#[derive(Args)]
-pub struct DownloadArgs {
-    /// model id to fetch.
-    #[clap(long)]
-    model: String,
-}
-
-error_chain! {
-    foreign_links {
-        Io(std::io::Error);
-        HttpRequest(reqwest::Error);
-        TemplateError(indicatif::style::TemplateError);
-    }
-}
-
-pub async fn main(args: &DownloadArgs) -> Result<()> {
-    download_model(&args.model).await.unwrap();
-    Ok(())
-}
-
-async fn download_model(model_id: &str) -> Result<()> {
-    download_metadata(model_id).await?;
-    download_model_file(model_id, "tokenizer.json").await?;
-    download_model_file(model_id, &format!("ctranslate2/config.json")).await?;
-    download_model_file(model_id, &format!("ctranslate2/vocabulary.txt")).await?;
-    download_model_file(model_id, &format!("ctranslate2/shared_vocabulary.txt")).await?;
-    download_model_file(model_id, &format!("ctranslate2/model.bin")).await?;
-    Ok(())
-}
-
-async fn download_metadata(model_id: &str) -> Result<()> {
-    let url = format!("https://huggingface.co/api/models/{}", model_id);
-    let filepath = ModelDir::new(model_id).metadata_file();
-    download_file(&url, &filepath).await
-}
-
-async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
-    // Create url.
-    let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
-
-    // Create destination path.
-    let filepath = ModelDir::new(model_id).path_string(fname);
-    download_file(&url, &filepath).await
-}
-
-async fn download_file(url: &str, path: &str) -> Result<()> {
-    fs::create_dir_all(Path::new(path).parent().unwrap())?;
-
-    // Reqwest setup
-    let res = reqwest::get(url)
-        .await
-        .or(Err(format!("Failed to GET from '{}'", url)))?;
-
-    let total_size = res
-        .content_length()
-        .ok_or(format!("Failed to get content length from '{}'", url))?;
-
-    // Indicatif setup
-    let pb = ProgressBar::new(total_size);
-    pb.set_style(ProgressStyle::default_bar()
-        .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
-        .progress_chars("#>-"));
-    pb.set_message(format!("Downloading {}", path));
-
-    // download chunks
-    let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
-    let mut downloaded: u64 = 0;
-    let mut stream = res.bytes_stream();
-
-    while let Some(item) = stream.next().await {
-        let chunk = item.or(Err(format!("Error while downloading file")))?;
-        file.write_all(&chunk)
-            .or(Err(format!("Error while writing to file")))?;
-        let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
-        downloaded = new;
-        pb.set_position(new);
-    }
-
-    pb.finish_with_message(format!("Downloaded {}", path));
-    return Ok(());
-}
diff --git a/crates/tabby/src/download/metadata.rs b/crates/tabby/src/download/metadata.rs
new file mode 100644
index 0000000..9f21d93
--- /dev/null
+++ b/crates/tabby/src/download/metadata.rs
@@ -0,0 +1,124 @@
+use anyhow::{anyhow, Result};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::fs;
+use std::path::Path;
+use tabby_common::path::ModelDir;
+
+/// Subset of the `transformersInfo` object in the Hugging Face model API response.
+#[derive(Deserialize)]
+struct HFTransformersInfo {
+    auto_model: String,
+}
+
+/// Subset of the JSON document served by `https://huggingface.co/api/models/{model_id}`.
+#[derive(Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct HFMetadata {
+    transformers_info: HFTransformersInfo,
+}
+
+impl HFMetadata {
+    /// Fetches and deserializes the Hugging Face API description of `model_id`.
+    ///
+    /// Fails on network errors, on a non-success HTTP status, or if the body does not
+    /// deserialize into [`HFMetadata`].
+    async fn from(model_id: &str) -> Result<HFMetadata> {
+        let api_url = format!("https://huggingface.co/api/models/{}", model_id);
+        let metadata = reqwest::get(api_url)
+            .await?
+            .error_for_status()?
+            .json::<HFMetadata>()
+            .await?;
+        Ok(metadata)
+    }
+}
+
+/// Per-model bookkeeping that is serialized to the model's metadata file.
+#[derive(Serialize, Deserialize)]
+pub struct Metadata {
+    auto_model: String,
+    /// Maps a file's download URL to the `etag` header value recorded when it was fetched.
+    etags: HashMap<String, String>,
+}
+
+impl Metadata {
+    /// Returns the metadata stored on disk for `model_id` if it can be read; otherwise
+    /// builds a fresh one (with no etags) from the Hugging Face API.
+    pub async fn from(model_id: &str) -> Result<Metadata> {
+        if let Some(metadata) = Self::from_local(model_id) {
+            return Ok(metadata);
+        }
+
+        let hf_metadata = HFMetadata::from(model_id).await?;
+        Ok(Metadata {
+            auto_model: hf_metadata.transformers_info.auto_model,
+            etags: HashMap::new(),
+        })
+    }
+
+    /// Reads the model's metadata file; `None` if it cannot be stat'ed or fails to parse.
+    fn from_local(model_id: &str) -> Option<Metadata> {
+        let metadata_file = ModelDir::new(model_id).metadata_file();
+        if fs::metadata(&metadata_file).is_ok() {
+            serdeconv::from_json_file(metadata_file).ok()
+        } else {
+            None
+        }
+    }
+
+    /// Returns `true` if an etag has been recorded for `url`.
+    pub fn has_etag(&self, url: &str) -> bool {
+        self.etags.contains_key(url)
+    }
+
+    /// Checks whether the etag recorded for `url` equals the `etag` header the server
+    /// currently returns for it.
+    ///
+    /// `path` is only used to make the error message readable. Fails if no etag is
+    /// recorded for `url`, if the request fails, or if the response has no usable `etag`
+    /// header.
+    pub async fn match_etag(&self, url: &str, path: &str) -> Result<bool> {
+        let etag = self
+            .etags
+            .get(url)
+            .ok_or_else(|| anyhow!("No etag recorded for '{}' ({})", path, url))?;
+        let response = reqwest::get(url).await?;
+        let etag_from_header = response
+            .headers()
+            .get("etag")
+            .ok_or_else(|| anyhow!("URL doesn't have etag header: '{}'", url))?
+            .to_str()?;
+
+        Ok(etag.as_str() == etag_from_header)
+    }
+
+    /// Records `etag` as the etag of the file downloaded from `url`.
+    ///
+    /// Performs no I/O; it stays `async` only so existing `.await` call sites keep working.
+    pub async fn update_etag(&mut self, url: &str, etag: &str) {
+        self.etags.insert(url.to_owned(), etag.to_owned());
+    }
+
+    /// Serializes the metadata as JSON into the model's metadata file.
+    pub fn save(&self, model_id: &str) -> Result<()> {
+        let metadata_file = ModelDir::new(model_id).metadata_file();
+        let metadata_file_path = Path::new(&metadata_file);
+        serdeconv::to_json_file(self, metadata_file_path)?;
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_hf() {
+        let hf_metadata = HFMetadata::from("TabbyML/J-350M").await.unwrap();
+        assert_eq!(
+            hf_metadata.transformers_info.auto_model,
+            "AutoModelForCausalLM"
+        );
+    }
+}
diff --git a/crates/tabby/src/download/mod.rs b/crates/tabby/src/download/mod.rs
new file mode 100644
index 0000000..59c9039
--- /dev/null
+++ b/crates/tabby/src/download/mod.rs
@@ -0,0 +1,134 @@
+mod metadata;
+
+use anyhow::{anyhow, Context, Result};
+use std::cmp;
+use std::fs;
+use std::io::Write;
+use std::path::Path;
+
+use clap::Args;
+use futures_util::StreamExt;
+use indicatif::{ProgressBar, ProgressStyle};
+use tabby_common::path::ModelDir;
+
+/// Files fetched for every model, relative to the root of its Hugging Face repository.
+const MODEL_FILES: &[&str] = &[
+    "tokenizer.json",
+    "ctranslate2/config.json",
+    "ctranslate2/vocabulary.txt",
+    "ctranslate2/shared_vocabulary.txt",
+    "ctranslate2/model.bin",
+];
+
+#[derive(Args)]
+pub struct DownloadArgs {
+    /// model id to fetch.
+    #[clap(long)]
+    model: String,
+
+    /// If true, skip checking for remote model file.
+    // NOTE(review): a bare `bool` flag that already defaults to `true` presumably cannot be
+    // switched off from the command line -- confirm against clap's handling of `bool` args.
+    #[clap(long, default_value_t = true)]
+    prefer_local_file: bool,
+}
+
+/// Entry point of the download command: fetches every file of `args.model`.
+pub async fn main(args: &DownloadArgs) -> Result<()> {
+    download_model(&args.model, args.prefer_local_file).await
+}
+
+impl metadata::Metadata {
+    /// Makes sure `path` (relative to the model repository root) exists in the local model
+    /// directory, downloading it unless a usable local copy is found.
+    ///
+    /// A local copy is usable when the file exists, an etag was recorded for its URL and,
+    /// unless `prefer_local_file` is set, that etag still matches the one the server reports.
+    async fn download(
+        &mut self,
+        model_id: &str,
+        path: &str,
+        prefer_local_file: bool,
+    ) -> Result<()> {
+        // Create url.
+        let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
+
+        // Create destination path.
+        let filepath = ModelDir::new(model_id).path_string(path);
+
+        // Cache hit. `&&`/`||` short-circuit, so the server is only asked for its etag when
+        // the file and a recorded etag exist and `prefer_local_file` is off.
+        let cache_hit = fs::metadata(&filepath).is_ok()
+            && self.has_etag(&url)
+            && (prefer_local_file || self.match_etag(&url, path).await?);
+
+        if !cache_hit {
+            let etag = download_file(&url, &filepath).await?;
+            self.update_etag(&url, &etag).await;
+        }
+
+        Ok(())
+    }
+}
+
+/// Downloads every entry of [`MODEL_FILES`] for `model_id`, then saves the updated metadata.
+async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
+    let mut metadata = metadata::Metadata::from(model_id).await?;
+
+    for &path in MODEL_FILES {
+        metadata.download(model_id, path, prefer_local_file).await?;
+    }
+    metadata.save(model_id)?;
+    Ok(())
+}
+
+/// Streams the body of `url` into the file at `path` while showing a progress bar.
+///
+/// Returns the value of the response's `etag` header. Fails if the request fails or returns
+/// a non-success status, if the `etag` or content-length is missing, or on any I/O error.
+async fn download_file(url: &str, path: &str) -> Result<String> {
+    let parent = Path::new(path)
+        .parent()
+        .ok_or_else(|| anyhow!("Failed to get parent directory of '{}'", path))?;
+    fs::create_dir_all(parent)?;
+
+    // Reqwest setup. Reject error statuses so an HTTP error page is never saved as a model file.
+    let res = reqwest::get(url)
+        .await
+        .and_then(|res| res.error_for_status())
+        .with_context(|| format!("Failed to GET from '{}'", url))?;
+
+    let etag = res
+        .headers()
+        .get("etag")
+        .ok_or_else(|| anyhow!("Failed to get etag from '{}'", url))?
+        .to_str()?
+        .to_string();
+
+    let total_size = res
+        .content_length()
+        .ok_or_else(|| anyhow!("Failed to get content length from '{}'", url))?;
+
+    // Indicatif setup
+    let pb = ProgressBar::new(total_size);
+    pb.set_style(ProgressStyle::default_bar()
+        .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
+        .progress_chars("#>-"));
+    pb.set_message(format!("Downloading {}", path));
+
+    // download chunks
+    let mut file =
+        fs::File::create(path).with_context(|| format!("Failed to create file '{}'", path))?;
+    let mut downloaded: u64 = 0;
+    let mut stream = res.bytes_stream();
+
+    while let Some(item) = stream.next().await {
+        let chunk = item.context("Error while downloading file")?;
+        file.write_all(&chunk).context("Error while writing to file")?;
+        downloaded = cmp::min(downloaded + (chunk.len() as u64), total_size);
+        pb.set_position(downloaded);
+    }
+
+    pb.finish_with_message(format!("Downloaded {}", path));
+    Ok(etag)
+}