fix: optional file should be put in cache key to avoid internet access when prefer_local_files = true (#241)
* feat: when file is 404, cache as NotFound
* explicitly mark optional file
* refactor
improve-workflow
parent
cc760bc436
commit
9abf1a7521
|
|
@ -14,6 +14,7 @@ impl CacheInfo {
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
path: &str,
|
path: &str,
|
||||||
prefer_local_file: bool,
|
prefer_local_file: bool,
|
||||||
|
is_optional: bool,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// Create url.
|
// Create url.
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||||
|
|
@ -25,13 +26,22 @@ impl CacheInfo {
|
||||||
let filepath = ModelDir::new(model_id).path_string(path);
|
let filepath = ModelDir::new(model_id).path_string(path);
|
||||||
|
|
||||||
// Cache hit.
|
// Cache hit.
|
||||||
let mut local_file_ready = false;
|
let local_file_ready = if prefer_local_file {
|
||||||
if !prefer_local_file && local_cache_key.is_some() && fs::metadata(&filepath).is_ok() {
|
if let Some(local_cache_key) = local_cache_key {
|
||||||
local_file_ready = true;
|
if local_cache_key == "404" {
|
||||||
}
|
true
|
||||||
|
} else {
|
||||||
|
fs::metadata(&filepath).is_ok()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
if !local_file_ready {
|
if !local_file_ready {
|
||||||
let etag = download_file(&url, &filepath, local_cache_key).await?;
|
let etag = download_file(&url, &filepath, local_cache_key, is_optional).await?;
|
||||||
self.set_local_cache_key(path, &etag).await;
|
self.set_local_cache_key(path, &etag).await;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -46,39 +56,58 @@ pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<(
|
||||||
|
|
||||||
let mut cache_info = CacheInfo::from(model_id).await;
|
let mut cache_info = CacheInfo::from(model_id).await;
|
||||||
|
|
||||||
cache_info
|
let optional_files = vec![
|
||||||
.download(model_id, "tabby.json", prefer_local_file)
|
"ctranslate2/vocabulary.txt",
|
||||||
.await?;
|
"ctranslate2/shared_vocabulary.txt",
|
||||||
cache_info
|
];
|
||||||
.download(model_id, "tokenizer.json", prefer_local_file)
|
for path in optional_files {
|
||||||
.await?;
|
cache_info
|
||||||
cache_info
|
.download(
|
||||||
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
model_id,
|
||||||
.await?;
|
path,
|
||||||
let _ = cache_info
|
prefer_local_file,
|
||||||
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
/* is_optional */ true,
|
||||||
.await;
|
)
|
||||||
let _ = cache_info
|
.await?;
|
||||||
.download(
|
}
|
||||||
model_id,
|
|
||||||
"ctranslate2/shared_vocabulary.txt",
|
|
||||||
prefer_local_file,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
cache_info
|
|
||||||
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
|
||||||
.await?;
|
|
||||||
cache_info.save(model_id)?;
|
|
||||||
|
|
||||||
|
let required_files = vec![
|
||||||
|
"tabby.json",
|
||||||
|
"tokenizer.json",
|
||||||
|
"ctranslate2/config.json",
|
||||||
|
"ctranslate2/model.bin",
|
||||||
|
];
|
||||||
|
for path in required_files {
|
||||||
|
cache_info
|
||||||
|
.download(
|
||||||
|
model_id,
|
||||||
|
path,
|
||||||
|
prefer_local_file,
|
||||||
|
/* is_optional */ false,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
cache_info.save(model_id)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> Result<String> {
|
async fn download_file(
|
||||||
|
url: &str,
|
||||||
|
path: &str,
|
||||||
|
local_cache_key: Option<&str>,
|
||||||
|
is_optional: bool,
|
||||||
|
) -> Result<String> {
|
||||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||||
|
|
||||||
// Reqwest setup
|
// Reqwest setup
|
||||||
let res = reqwest::get(url).await?;
|
let res = reqwest::get(url).await?;
|
||||||
|
|
||||||
|
if is_optional && res.status() == 404 {
|
||||||
|
// Cache 404 for optional file.
|
||||||
|
return Ok("404".to_owned());
|
||||||
|
}
|
||||||
|
|
||||||
if !res.status().is_success() {
|
if !res.status().is_success() {
|
||||||
return Err(anyhow!(format!("Invalid url: {}", url)));
|
return Err(anyhow!(format!("Invalid url: {}", url)));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue