From 9abf1a7521b687e9a8fc7b9be6674c80d00832e5 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Wed, 14 Jun 2023 21:13:52 -0700 Subject: [PATCH] fix: optional file should be put in cache key to avoid internet access when prefer_local_files = true (#241) * feat: when file is 404, cache as NotFound * explicitly mark optional file * refactor --- crates/tabby-download/src/lib.rs | 87 +++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index 7085459..291f476 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -14,6 +14,7 @@ impl CacheInfo { model_id: &str, path: &str, prefer_local_file: bool, + is_optional: bool, ) -> Result<()> { // Create url. let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); @@ -25,13 +26,22 @@ impl CacheInfo { let filepath = ModelDir::new(model_id).path_string(path); // Cache hit. - let mut local_file_ready = false; - if !prefer_local_file && local_cache_key.is_some() && fs::metadata(&filepath).is_ok() { - local_file_ready = true; - } + let local_file_ready = if prefer_local_file { + if let Some(local_cache_key) = local_cache_key { + if local_cache_key == "404" { + true + } else { + fs::metadata(&filepath).is_ok() + } + } else { + false + } + } else { + false + }; if !local_file_ready { - let etag = download_file(&url, &filepath, local_cache_key).await?; + let etag = download_file(&url, &filepath, local_cache_key, is_optional).await?; self.set_local_cache_key(path, &etag).await; } Ok(()) @@ -46,39 +56,58 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<( let mut cache_info = CacheInfo::from(model_id).await; - cache_info - .download(model_id, "tabby.json", prefer_local_file) - .await?; - cache_info - .download(model_id, "tokenizer.json", prefer_local_file) - .await?; - cache_info - .download(model_id, "ctranslate2/config.json", prefer_local_file) - .await?; - let _ = cache_info - .download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file) - .await; - let _ = cache_info - .download( - model_id, - "ctranslate2/shared_vocabulary.txt", - prefer_local_file, - ) - .await; - cache_info - .download(model_id, "ctranslate2/model.bin", prefer_local_file) - .await?; - cache_info.save(model_id)?; + let optional_files = vec![ + "ctranslate2/vocabulary.txt", + "ctranslate2/shared_vocabulary.txt", + ]; + for path in optional_files { + cache_info + .download( + model_id, + path, + prefer_local_file, + /* is_optional */ true, + ) + .await?; + } + let required_files = vec![ + "tabby.json", + "tokenizer.json", + "ctranslate2/config.json", + "ctranslate2/model.bin", + ]; + for path in required_files { + cache_info + .download( + model_id, + path, + prefer_local_file, + /* required= */ false, + ) + .await?; + } + + cache_info.save(model_id)?; Ok(()) } -async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> Result { +async fn download_file( + url: &str, + path: &str, + local_cache_key: Option<&str>, + is_optional: bool, +) -> Result { fs::create_dir_all(Path::new(path).parent().unwrap())?; // Reqwest setup let res = reqwest::get(url).await?; + if is_optional && res.status() == 404 { + // Cache 404 for optional file. + return Ok("404".to_owned()); + } + if !res.status().is_success() { return Err(anyhow!(format!("Invalid url: {}", url))); }