tabby/crates/tabby-download/src/lib.rs

179 lines
4.8 KiB
Rust
Raw Normal View History

mod cache_info;
mod registry;
use std::{cmp, fs, io::Write, path::Path};
use anyhow::{anyhow, Result};
use cache_info::CacheInfo;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use registry::{create_registry, Registry};
use tabby_common::path::ModelDir;
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
/// Downloads the files of a single model through a [`Registry`].
pub struct Downloader {
    /// Identifier of the model to fetch. If it names an existing filesystem
    /// path, the model is treated as local and nothing is downloaded.
    model_id: String,
    /// When `true`, a file already recorded in the cache info is reused
    /// without contacting the registry.
    prefer_local_file: bool,
    /// Registry used to build download URLs and remote cache keys.
    registry: Box<dyn Registry>,
}
impl Downloader {
    /// Creates a downloader for `model_id` backed by the registry returned
    /// from `create_registry()`.
    ///
    /// `prefer_local_file` controls whether files already recorded in the
    /// local cache info are reused without re-validating them remotely.
    pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
        Self {
            model_id: model_id.to_owned(),
            prefer_local_file,
            registry: create_registry(),
        }
    }

    /// Downloads the files needed by the ctranslate2 backend.
    ///
    /// # Errors
    ///
    /// Returns an error as soon as any file fails to download.
    pub async fn download_ctranslate2_files(&self) -> Result<()> {
        // Each entry is `(relative path, required)`. Optional entries may be
        // absent from the registry (HTTP 404) without failing the download.
        let files = [
            ("tabby.json", true),
            ("tokenizer.json", true),
            ("ctranslate2/vocabulary.txt", false),
            ("ctranslate2/shared_vocabulary.txt", false),
            ("ctranslate2/vocabulary.json", false),
            ("ctranslate2/shared_vocabulary.json", false),
            ("ctranslate2/config.json", true),
            ("ctranslate2/model.bin", true),
        ];
        self.download_files(&files).await
    }

    /// Downloads the files needed by the ggml backend.
    ///
    /// # Errors
    ///
    /// Returns an error as soon as any file fails to download.
    pub async fn download_ggml_files(&self) -> Result<()> {
        // Each entry is `(relative path, required)`.
        let files = [
            ("tabby.json", true),
            ("tokenizer.json", true),
            ("ggml/q8_0.gguf", true),
        ];
        self.download_files(&files).await
    }

    /// Downloads every `(path, required)` entry in `files`, in order,
    /// stopping at the first failure.
    async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
        // Local path, no need for downloading.
        if fs::metadata(&self.model_id).is_ok() {
            return Ok(());
        }

        let mut cache_info = CacheInfo::from(&self.model_id).await;
        for &(path, required) in files {
            download_model_file(
                self.registry.as_ref(),
                &mut cache_info,
                &self.model_id,
                path,
                self.prefer_local_file,
                required,
            )
            .await?;
        }
        Ok(())
    }
}
/// Makes sure one model file is available locally and records its cache key.
///
/// The file is downloaded (with a small number of retries) unless
/// `prefer_local_file` is set and the cache info already has an entry for
/// `path` that is either the `"404"` marker or backed by an existing file.
/// The cache info is saved for `model_id` in every case.
///
/// # Errors
///
/// Returns an error if all download attempts fail or the cache info cannot be
/// saved.
async fn download_model_file(
    registry: &dyn Registry,
    cache_info: &mut CacheInfo,
    model_id: &str,
    path: &str,
    prefer_local_file: bool,
    required: bool,
) -> Result<()> {
    // Remote location, previously recorded cache key, and local destination.
    let url = registry.build_url(model_id, path);
    let local_cache_key = cache_info.local_cache_key(path);
    let filepath = ModelDir::new(model_id).path_string(path);

    // A cache hit is only considered when local files are preferred.
    let local_file_ready = prefer_local_file
        && match local_cache_key {
            // A cached "404" means the file was already found to be absent remotely.
            Some("404") => true,
            Some(_) => fs::metadata(&filepath).is_ok(),
            None => false,
        };

    if !local_file_ready {
        let is_optional = !required;
        let backoff = ExponentialBackoff::from_millis(100).map(jitter).take(2);
        let etag = Retry::spawn(backoff, || {
            download_file(registry, &url, &filepath, local_cache_key, is_optional)
        })
        .await?;
        cache_info.set_local_cache_key(path, &etag).await;
    }

    cache_info.save(model_id)?;
    Ok(())
}
async fn download_file(
registry: &dyn Registry,
url: &str,
path: &str,
local_cache_key: Option<&str>,
is_optional: bool,
) -> Result<String> {
fs::create_dir_all(Path::new(path).parent().unwrap())?;
// Reqwest setup
let res = reqwest::get(url).await?;
if is_optional && res.status() == 404 {
// Cache 404 for optional file.
return Ok("404".to_owned());
}
if !res.status().is_success() {
return Err(anyhow!(format!("Invalid url: {}", url)));
}
let remote_cache_key = registry.build_cache_key(url).await?;
if let Some(local_cache_key) = local_cache_key {
if local_cache_key == remote_cache_key {
return Ok(remote_cache_key);
}
}
let total_size = res
.content_length()
.ok_or(anyhow!("No content length in headers"))?;
// Indicatif setup
let pb = ProgressBar::new(total_size);
pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message(format!("Downloading {}", path));
// download chunks
let mut file = fs::File::create(path)?;
let mut downloaded: u64 = 0;
let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item?;
file.write_all(&chunk)?;
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(format!("Downloaded {}", path));
Ok(remote_cache_key)
}