feat: support resuming downloads (#700)

r0.5
Meng Zhang 2023-11-04 19:38:06 -07:00
parent 36d13d2837
commit b4fe249636
6 changed files with 1607 additions and 307 deletions

1734
Cargo.lock generated

File diff suppressed because it is too large. [Load Diff]

View File

@ -5,16 +5,8 @@ edition = "2021"
[dependencies]
tabby-common = { path = "../tabby-common" }
indicatif = "0.17.3"
futures-util = "0.3.28"
reqwest = { workspace = true, features = [ "stream", "json" ] }
anyhow = { workspace = true }
serde = { workspace = true }
serdeconv = { workspace = true }
tracing = { workspace = true }
tokio-retry = "0.3.0"
urlencoding = "2.1.3"
serde_json = { workspace = true }
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
async-trait = { workspace = true }
sha256 = "1.4.0"
aim = "1.8.5"

View File

@ -1,8 +1,7 @@
use std::{cmp, fs, io::Write, path::Path};
use std::{fs, path::Path};
use aim::bar::WrappedBar;
use anyhow::{anyhow, Result};
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use tabby_common::registry::{parse_model_id, ModelRegistry};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
@ -51,40 +50,17 @@ async fn download_model_impl(
}
// NOTE(review): this span is a web diff dump — lines removed by the commit and
// lines added by it are interleaved with no +/- markers, so the text below is
// not one valid function. Judging by the Cargo.toml hunk above (aim added;
// indicatif/futures-util dropped), the old body streamed the response via
// reqwest with an indicatif progress bar, while the new body delegates to aim
// and downloads into a ".tmp" file that is renamed on success.
async fn download_file(url: &str, path: &Path) -> Result<()> {
// old: create the parent directory inline.
fs::create_dir_all(path.parent().unwrap())?;
// new: same effect, via a named binding.
let dir = path.parent().unwrap();
fs::create_dir_all(dir)?;
// Reqwest setup
// old: issue the GET request up front.
let res = reqwest::get(url).await?;
// new: derive the final filename plus a ".tmp" intermediate — presumably so a
// partial download can be resumed and the final path only ever holds a
// complete file; see the rename at the bottom.
let filename = path.to_str().unwrap();
let intermediate_filename = filename.to_owned() + ".tmp";
// old: reject non-2xx responses.
if !res.status().is_success() {
return Err(anyhow!(format!("Invalid url: {}", url)));
}
// new: aim's progress bar (initial total 0; third arg false).
let mut bar = WrappedBar::new(0, url, false);
// old: total size taken from the Content-Length header.
let total_size = res
.content_length()
.ok_or(anyhow!("No content length in headers"))?;
// new: aim performs the download into the .tmp file.
aim::https::HTTPSHandler::get(url, &intermediate_filename, &mut bar, "").await?;
// Indicatif setup
// old: everything from here to the rename is the removed indicatif/streaming path.
let pb = ProgressBar::new(total_size);
pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message(format!("Downloading {}", path.display()));
// download chunks
let mut file = fs::File::create(path)?;
let mut downloaded: u64 = 0;
let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item?;
file.write_all(&chunk)?;
// old: clamp the reported position to total_size.
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(format!("Downloaded {}", path.display()));
// new: publish the completed .tmp file under its final name.
fs::rename(intermediate_filename, filename)?;
Ok(())
}

View File

@ -1,24 +0,0 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use crate::Registry;
/// Model registry backed by huggingface.co.
#[derive(Default)]
pub struct HuggingFaceRegistry {}

#[async_trait]
impl Registry for HuggingFaceRegistry {
    /// Resolve `path` inside `model_id` to a direct download URL on the
    /// repository's `main` branch.
    fn build_url(&self, model_id: &str, path: &str) -> String {
        format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
    }

    /// Use the HTTP `ETag` response header of `url` as its cache key.
    ///
    /// Issues a HEAD request rather than a GET so that no response body is
    /// fetched just to read one header — the resolved model files can be
    /// very large.
    async fn build_cache_key(&self, url: &str) -> Result<String> {
        let res = reqwest::Client::new().head(url).send().await?;
        let cache_key = res
            .headers()
            .get("etag")
            // ok_or_else: don't construct the error value on the success path.
            .ok_or_else(|| anyhow!("etag key missing"))?
            .to_str()?;
        Ok(cache_key.to_owned())
    }
}

View File

@ -1,25 +0,0 @@
mod huggingface;
mod modelscope;
use anyhow::Result;
use async_trait::async_trait;
use huggingface::HuggingFaceRegistry;
use self::modelscope::ModelScopeRegistry;
/// A model registry: knows how to turn a model id + file path into a
/// downloadable URL, and how to derive a cache key for such a URL.
#[async_trait]
pub trait Registry {
/// Build the download URL for file `path` of model `model_id`.
fn build_url(&self, model_id: &str, path: &str) -> String;
/// Derive a cache key (e.g. an ETag or content hash) for `url`; may
/// perform network requests.
async fn build_cache_key(&self, url: &str) -> Result<String>;
}
/// Construct the registry selected by the `TABBY_REGISTRY` environment
/// variable, defaulting to `huggingface` when the variable is unset.
///
/// # Panics
/// Panics when `TABBY_REGISTRY` names an unsupported registry.
pub fn create_registry() -> Box<dyn Registry> {
    // unwrap_or_else: avoid allocating the default String when the env var is set.
    let registry = std::env::var("TABBY_REGISTRY").unwrap_or_else(|_| "huggingface".to_owned());
    match registry.as_str() {
        "huggingface" => Box::<HuggingFaceRegistry>::default(),
        "modelscope" => Box::<ModelScopeRegistry>::default(),
        _ => panic!("Unsupported registry {}", registry),
    }
}

View File

@ -1,79 +0,0 @@
use std::collections::HashMap;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cached::proc_macro::cached;
use reqwest::Url;
use serde::Deserialize;
use crate::Registry;
/// Model registry backed by modelscope.cn.
#[derive(Default)]
pub struct ModelScopeRegistry {}

#[async_trait]
impl Registry for ModelScopeRegistry {
    /// Build the ModelScope file-download URL for `path` inside `model_id`,
    /// percent-encoding the path into the `FilePath` query parameter.
    fn build_url(&self, model_id: &str, path: &str) -> String {
        format!(
            "https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
            model_id,
            urlencoding::encode(path)
        )
    }

    /// Derive a cache key for `url`: the file's sha256 as reported by the
    /// (cached) ModelScope revision listing.
    async fn build_cache_key(&self, url: &str) -> Result<String> {
        let url = Url::parse(url)?;
        // Recover the model id from a "/api/v1/models/<id>/repo" path.
        let model_id = url
            .path()
            .strip_prefix("/api/v1/models/")
            .ok_or_else(|| anyhow!("Invalid url"))?
            .strip_suffix("/repo")
            .ok_or_else(|| anyhow!("Invalid url"))?;
        let query: HashMap<_, _> = url.query_pairs().into_owned().collect();
        let path = query
            .get("FilePath")
            .ok_or_else(|| anyhow!("Failed to extract FilePath"))?;
        let revision_map = fetch_revision_map(model_id.to_owned()).await?;
        // Look the file up in the revision listing and return its sha256.
        revision_map
            .data
            .files
            .into_iter()
            .find(|x| x.path == *path)
            .map(|x| x.sha256)
            .ok_or_else(|| anyhow!("Failed to find {} in revisions", path))
    }
}
/// Fetch the recursive file listing for `model_id` from the ModelScope API.
/// Cached (size 1, result-only) so repeated cache-key lookups against the
/// same model reuse a single network round trip.
#[cached(size = 1, result = true)]
async fn fetch_revision_map(model_id: String) -> Result<ModelScopeRevision> {
    let endpoint = format!(
        "https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true",
        model_id
    );
    let response = reqwest::get(endpoint).await?;
    Ok(response.json::<ModelScopeRevision>().await?)
}
// JSON shapes for the ModelScope revision-listing endpoint. The
// `rename_all = "PascalCase"` attributes map these snake_case fields onto the
// API's PascalCase keys (e.g. `path` -> `Path`, `data` -> `Data`).
#[derive(Deserialize, Clone)]
#[serde(rename_all = "PascalCase")]
struct ModelScopeRevision {
// Top-level `Data` payload of the response.
data: ModelScopeRevisionData,
}
#[derive(Deserialize, Clone)]
#[serde(rename_all = "PascalCase")]
struct ModelScopeRevisionData {
// Recursive listing of files in the repository.
files: Vec<ModelScopeRevisionFile>,
}
#[derive(Deserialize, Clone)]
#[serde(rename_all = "PascalCase")]
struct ModelScopeRevisionFile {
// Repository-relative file path.
path: String,
// Content hash; used as the download cache key.
sha256: String,
}