feat: support downloading resume (#700)
parent
36d13d2837
commit
b4fe249636
File diff suppressed because it is too large
Load Diff
|
|
@ -5,16 +5,8 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tabby-common = { path = "../tabby-common" }
|
tabby-common = { path = "../tabby-common" }
|
||||||
indicatif = "0.17.3"
|
|
||||||
futures-util = "0.3.28"
|
|
||||||
reqwest = { workspace = true, features = [ "stream", "json" ] }
|
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
serde = { workspace = true }
|
|
||||||
serdeconv = { workspace = true }
|
|
||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tokio-retry = "0.3.0"
|
tokio-retry = "0.3.0"
|
||||||
urlencoding = "2.1.3"
|
|
||||||
serde_json = { workspace = true }
|
|
||||||
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
|
|
||||||
async-trait = { workspace = true }
|
|
||||||
sha256 = "1.4.0"
|
sha256 = "1.4.0"
|
||||||
|
aim = "1.8.5"
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
use std::{cmp, fs, io::Write, path::Path};
|
use std::{fs, path::Path};
|
||||||
|
|
||||||
|
use aim::bar::WrappedBar;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use futures_util::StreamExt;
|
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
|
||||||
use tabby_common::registry::{parse_model_id, ModelRegistry};
|
use tabby_common::registry::{parse_model_id, ModelRegistry};
|
||||||
use tokio_retry::{
|
use tokio_retry::{
|
||||||
strategy::{jitter, ExponentialBackoff},
|
strategy::{jitter, ExponentialBackoff},
|
||||||
|
|
@ -51,40 +50,17 @@ async fn download_model_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file(url: &str, path: &Path) -> Result<()> {
|
async fn download_file(url: &str, path: &Path) -> Result<()> {
|
||||||
fs::create_dir_all(path.parent().unwrap())?;
|
let dir = path.parent().unwrap();
|
||||||
|
fs::create_dir_all(dir)?;
|
||||||
|
|
||||||
// Reqwest setup
|
let filename = path.to_str().unwrap();
|
||||||
let res = reqwest::get(url).await?;
|
let intermediate_filename = filename.to_owned() + ".tmp";
|
||||||
|
|
||||||
if !res.status().is_success() {
|
let mut bar = WrappedBar::new(0, url, false);
|
||||||
return Err(anyhow!(format!("Invalid url: {}", url)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let total_size = res
|
aim::https::HTTPSHandler::get(url, &intermediate_filename, &mut bar, "").await?;
|
||||||
.content_length()
|
|
||||||
.ok_or(anyhow!("No content length in headers"))?;
|
|
||||||
|
|
||||||
// Indicatif setup
|
fs::rename(intermediate_filename, filename)?;
|
||||||
let pb = ProgressBar::new(total_size);
|
|
||||||
pb.set_style(ProgressStyle::default_bar()
|
|
||||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
|
||||||
.progress_chars("#>-"));
|
|
||||||
pb.set_message(format!("Downloading {}", path.display()));
|
|
||||||
|
|
||||||
// download chunks
|
|
||||||
let mut file = fs::File::create(path)?;
|
|
||||||
let mut downloaded: u64 = 0;
|
|
||||||
let mut stream = res.bytes_stream();
|
|
||||||
|
|
||||||
while let Some(item) = stream.next().await {
|
|
||||||
let chunk = item?;
|
|
||||||
file.write_all(&chunk)?;
|
|
||||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
|
||||||
downloaded = new;
|
|
||||||
pb.set_position(new);
|
|
||||||
}
|
|
||||||
|
|
||||||
pb.finish_with_message(format!("Downloaded {}", path.display()));
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,24 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
use crate::Registry;
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct HuggingFaceRegistry {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Registry for HuggingFaceRegistry {
|
|
||||||
fn build_url(&self, model_id: &str, path: &str) -> String {
|
|
||||||
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
|
||||||
let res = reqwest::get(url).await?;
|
|
||||||
let cache_key = res
|
|
||||||
.headers()
|
|
||||||
.get("etag")
|
|
||||||
.ok_or(anyhow!("etag key missing"))?
|
|
||||||
.to_str()?;
|
|
||||||
Ok(cache_key.to_owned())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
mod huggingface;
|
|
||||||
mod modelscope;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use huggingface::HuggingFaceRegistry;
|
|
||||||
|
|
||||||
use self::modelscope::ModelScopeRegistry;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait Registry {
|
|
||||||
fn build_url(&self, model_id: &str, path: &str) -> String;
|
|
||||||
async fn build_cache_key(&self, url: &str) -> Result<String>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_registry() -> Box<dyn Registry> {
|
|
||||||
let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned());
|
|
||||||
if registry == "huggingface" {
|
|
||||||
Box::<HuggingFaceRegistry>::default()
|
|
||||||
} else if registry == "modelscope" {
|
|
||||||
Box::<ModelScopeRegistry>::default()
|
|
||||||
} else {
|
|
||||||
panic!("Unsupported registry {}", registry);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,79 +0,0 @@
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use cached::proc_macro::cached;
|
|
||||||
use reqwest::Url;
|
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
use crate::Registry;
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct ModelScopeRegistry {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Registry for ModelScopeRegistry {
|
|
||||||
fn build_url(&self, model_id: &str, path: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
|
|
||||||
model_id,
|
|
||||||
urlencoding::encode(path)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
|
||||||
let url = Url::parse(url)?;
|
|
||||||
let model_id = url
|
|
||||||
.path()
|
|
||||||
.strip_prefix("/api/v1/models/")
|
|
||||||
.ok_or(anyhow!("Invalid url"))?
|
|
||||||
.strip_suffix("/repo")
|
|
||||||
.ok_or(anyhow!("Invalid url"))?;
|
|
||||||
|
|
||||||
let query: HashMap<_, _> = url.query_pairs().into_owned().collect();
|
|
||||||
let path = query
|
|
||||||
.get("FilePath")
|
|
||||||
.ok_or(anyhow!("Failed to extract FilePath"))?;
|
|
||||||
|
|
||||||
let revision_map = fetch_revision_map(model_id.to_owned()).await?;
|
|
||||||
for x in revision_map.data.files {
|
|
||||||
if x.path == *path {
|
|
||||||
return Ok(x.sha256);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Err(anyhow!("Failed to find {} in revisions", path))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cached(size = 1, result = true)]
|
|
||||||
async fn fetch_revision_map(model_id: String) -> Result<ModelScopeRevision> {
|
|
||||||
let url = format!(
|
|
||||||
"https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true",
|
|
||||||
model_id
|
|
||||||
);
|
|
||||||
let resp = reqwest::get(url)
|
|
||||||
.await?
|
|
||||||
.json::<ModelScopeRevision>()
|
|
||||||
.await?;
|
|
||||||
Ok(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Clone)]
|
|
||||||
#[serde(rename_all = "PascalCase")]
|
|
||||||
struct ModelScopeRevision {
|
|
||||||
data: ModelScopeRevisionData,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Clone)]
|
|
||||||
#[serde(rename_all = "PascalCase")]
|
|
||||||
struct ModelScopeRevisionData {
|
|
||||||
files: Vec<ModelScopeRevisionFile>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Clone)]
|
|
||||||
#[serde(rename_all = "PascalCase")]
|
|
||||||
struct ModelScopeRevisionFile {
|
|
||||||
path: String,
|
|
||||||
sha256: String,
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue