77 lines
2.3 KiB
Rust
77 lines
2.3 KiB
Rust
use std::{fs, path::Path};
|
|
|
|
use aim::bar::WrappedBar;
|
|
use anyhow::{anyhow, Result};
|
|
use tabby_common::registry::{parse_model_id, ModelRegistry};
|
|
use tokio_retry::{
|
|
strategy::{jitter, ExponentialBackoff},
|
|
Retry,
|
|
};
|
|
use tracing::{info, warn};
|
|
|
|
async fn download_model_impl(
|
|
registry: &ModelRegistry,
|
|
name: &str,
|
|
prefer_local_file: bool,
|
|
) -> Result<()> {
|
|
let model_info = registry.get_model_info(name);
|
|
let model_path = registry.get_model_path(name);
|
|
if model_path.exists() {
|
|
if !prefer_local_file {
|
|
info!("Checking model integrity..");
|
|
let checksum = sha256::try_digest(&model_path).unwrap();
|
|
if checksum == model_info.sha256 {
|
|
return Ok(());
|
|
}
|
|
|
|
warn!(
|
|
"Checksum doesn't match for <{}/{}>, re-downloading...",
|
|
registry.name, name
|
|
);
|
|
fs::remove_file(&model_path)?;
|
|
} else {
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
|
|
let Some(model_url) = model_info.urls.iter().find(|x| x.contains(®istry)) else {
|
|
return Err(anyhow!(
|
|
"Invalid mirror <{}> for model urls: {:?}",
|
|
registry,
|
|
model_info.urls
|
|
));
|
|
};
|
|
|
|
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
|
|
let download_job = Retry::spawn(strategy, || download_file(model_url, model_path.as_path()));
|
|
download_job.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn download_file(url: &str, path: &Path) -> Result<()> {
|
|
let dir = path.parent().unwrap();
|
|
fs::create_dir_all(dir)?;
|
|
|
|
let filename = path.to_str().unwrap();
|
|
let intermediate_filename = filename.to_owned() + ".tmp";
|
|
|
|
let mut bar = WrappedBar::new(0, url, false);
|
|
|
|
aim::https::HTTPSHandler::get(url, &intermediate_filename, &mut bar, "").await?;
|
|
|
|
fs::rename(intermediate_filename, filename)?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn download_model(model_id: &str, prefer_local_file: bool) {
|
|
let (registry, name) = parse_model_id(model_id);
|
|
|
|
let registry = ModelRegistry::new(registry).await;
|
|
|
|
let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
|
|
download_model_impl(®istry, name, prefer_local_file)
|
|
.await
|
|
.unwrap_or_else(handler)
|
|
}
|