fix: optional file should be put in cache key to avoid internet access when prefer_local_files = true (#241)

* feat: when file is 404, cache as NotFound

* explicitly mark optional file

* refactor
improve-workflow
Meng Zhang 2023-06-14 21:13:52 -07:00 committed by GitHub
parent cc760bc436
commit 9abf1a7521
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 58 additions and 29 deletions

View File

@ -14,6 +14,7 @@ impl CacheInfo {
model_id: &str,
path: &str,
prefer_local_file: bool,
is_optional: bool,
) -> Result<()> {
// Create url.
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
@ -25,13 +26,22 @@ impl CacheInfo {
let filepath = ModelDir::new(model_id).path_string(path);
// Cache hit.
let mut local_file_ready = false;
if !prefer_local_file && local_cache_key.is_some() && fs::metadata(&filepath).is_ok() {
local_file_ready = true;
}
let local_file_ready = if prefer_local_file {
if let Some(local_cache_key) = local_cache_key {
if local_cache_key == "404" {
true
} else {
fs::metadata(&filepath).is_ok()
}
} else {
false
}
} else {
false
};
if !local_file_ready {
let etag = download_file(&url, &filepath, local_cache_key).await?;
let etag = download_file(&url, &filepath, local_cache_key, is_optional).await?;
self.set_local_cache_key(path, &etag).await;
}
Ok(())
@ -46,39 +56,58 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<(
let mut cache_info = CacheInfo::from(model_id).await;
cache_info
.download(model_id, "tabby.json", prefer_local_file)
.await?;
cache_info
.download(model_id, "tokenizer.json", prefer_local_file)
.await?;
cache_info
.download(model_id, "ctranslate2/config.json", prefer_local_file)
.await?;
let _ = cache_info
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
.await;
let _ = cache_info
.download(
model_id,
"ctranslate2/shared_vocabulary.txt",
prefer_local_file,
)
.await;
cache_info
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
.await?;
cache_info.save(model_id)?;
let optional_files = vec![
"ctranslate2/vocabulary.txt",
"ctranslate2/shared_vocabulary.txt",
];
for path in optional_files {
cache_info
.download(
model_id,
path,
prefer_local_file,
/* is_optional */ true,
)
.await?;
}
let required_files = vec![
"tabby.json",
"tokenizer.json",
"ctranslate2/config.json",
"ctranslate2/model.bin",
];
for path in required_files {
cache_info
.download(
model_id,
path,
prefer_local_file,
/* required= */ false,
)
.await?;
}
cache_info.save(model_id)?;
Ok(())
}
async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> Result<String> {
async fn download_file(
url: &str,
path: &str,
local_cache_key: Option<&str>,
is_optional: bool,
) -> Result<String> {
fs::create_dir_all(Path::new(path).parent().unwrap())?;
// Reqwest setup
let res = reqwest::get(url).await?;
if is_optional && res.status() == 404 {
// Cache 404 for optional file.
return Ok("404".to_owned());
}
if !res.status().is_success() {
return Err(anyhow!(format!("Invalid url: {}", url)));
}