feat: support downloading resume (#700)
parent
4d6389a9ab
commit
33ef27ba30
File diff suppressed because it is too large
Load Diff
|
|
@ -5,16 +5,8 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
indicatif = "0.17.3"
|
||||
futures-util = "0.3.28"
|
||||
reqwest = { workspace = true, features = [ "stream", "json" ] }
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serdeconv = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tokio-retry = "0.3.0"
|
||||
urlencoding = "2.1.3"
|
||||
serde_json = { workspace = true }
|
||||
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
|
||||
async-trait = { workspace = true }
|
||||
sha256 = "1.4.0"
|
||||
aim = "1.8.5"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
use std::{cmp, fs, io::Write, path::Path};
|
||||
use std::{fs, path::Path};
|
||||
|
||||
use aim::bar::WrappedBar;
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use tabby_common::registry::{parse_model_id, ModelRegistry};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
|
|
@ -51,40 +50,17 @@ async fn download_model_impl(
|
|||
}
|
||||
|
||||
async fn download_file(url: &str, path: &Path) -> Result<()> {
|
||||
fs::create_dir_all(path.parent().unwrap())?;
|
||||
let dir = path.parent().unwrap();
|
||||
fs::create_dir_all(dir)?;
|
||||
|
||||
// Reqwest setup
|
||||
let res = reqwest::get(url).await?;
|
||||
let filename = path.to_str().unwrap();
|
||||
let intermediate_filename = filename.to_owned() + ".tmp";
|
||||
|
||||
if !res.status().is_success() {
|
||||
return Err(anyhow!(format!("Invalid url: {}", url)));
|
||||
}
|
||||
let mut bar = WrappedBar::new(0, url, false);
|
||||
|
||||
let total_size = res
|
||||
.content_length()
|
||||
.ok_or(anyhow!("No content length in headers"))?;
|
||||
aim::https::HTTPSHandler::get(url, &intermediate_filename, &mut bar, "").await?;
|
||||
|
||||
// Indicatif setup
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", path.display()));
|
||||
|
||||
// download chunks
|
||||
let mut file = fs::File::create(path)?;
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut stream = res.bytes_stream();
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
let chunk = item?;
|
||||
file.write_all(&chunk)?;
|
||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
||||
downloaded = new;
|
||||
pb.set_position(new);
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!("Downloaded {}", path.display()));
|
||||
fs::rename(intermediate_filename, filename)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::Registry;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct HuggingFaceRegistry {}
|
||||
|
||||
#[async_trait]
|
||||
impl Registry for HuggingFaceRegistry {
|
||||
fn build_url(&self, model_id: &str, path: &str) -> String {
|
||||
format!("https://huggingface.co/{}/resolve/main/{}", model_id, path)
|
||||
}
|
||||
|
||||
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
||||
let res = reqwest::get(url).await?;
|
||||
let cache_key = res
|
||||
.headers()
|
||||
.get("etag")
|
||||
.ok_or(anyhow!("etag key missing"))?
|
||||
.to_str()?;
|
||||
Ok(cache_key.to_owned())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
mod huggingface;
|
||||
mod modelscope;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use huggingface::HuggingFaceRegistry;
|
||||
|
||||
use self::modelscope::ModelScopeRegistry;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Registry {
|
||||
fn build_url(&self, model_id: &str, path: &str) -> String;
|
||||
async fn build_cache_key(&self, url: &str) -> Result<String>;
|
||||
}
|
||||
|
||||
pub fn create_registry() -> Box<dyn Registry> {
|
||||
let registry = std::env::var("TABBY_REGISTRY").unwrap_or("huggingface".to_owned());
|
||||
if registry == "huggingface" {
|
||||
Box::<HuggingFaceRegistry>::default()
|
||||
} else if registry == "modelscope" {
|
||||
Box::<ModelScopeRegistry>::default()
|
||||
} else {
|
||||
panic!("Unsupported registry {}", registry);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use cached::proc_macro::cached;
|
||||
use reqwest::Url;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::Registry;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ModelScopeRegistry {}
|
||||
|
||||
#[async_trait]
|
||||
impl Registry for ModelScopeRegistry {
|
||||
fn build_url(&self, model_id: &str, path: &str) -> String {
|
||||
format!(
|
||||
"https://modelscope.cn/api/v1/models/{}/repo?FilePath={}",
|
||||
model_id,
|
||||
urlencoding::encode(path)
|
||||
)
|
||||
}
|
||||
|
||||
async fn build_cache_key(&self, url: &str) -> Result<String> {
|
||||
let url = Url::parse(url)?;
|
||||
let model_id = url
|
||||
.path()
|
||||
.strip_prefix("/api/v1/models/")
|
||||
.ok_or(anyhow!("Invalid url"))?
|
||||
.strip_suffix("/repo")
|
||||
.ok_or(anyhow!("Invalid url"))?;
|
||||
|
||||
let query: HashMap<_, _> = url.query_pairs().into_owned().collect();
|
||||
let path = query
|
||||
.get("FilePath")
|
||||
.ok_or(anyhow!("Failed to extract FilePath"))?;
|
||||
|
||||
let revision_map = fetch_revision_map(model_id.to_owned()).await?;
|
||||
for x in revision_map.data.files {
|
||||
if x.path == *path {
|
||||
return Ok(x.sha256);
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!("Failed to find {} in revisions", path))
|
||||
}
|
||||
}
|
||||
|
||||
#[cached(size = 1, result = true)]
|
||||
async fn fetch_revision_map(model_id: String) -> Result<ModelScopeRevision> {
|
||||
let url = format!(
|
||||
"https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true",
|
||||
model_id
|
||||
);
|
||||
let resp = reqwest::get(url)
|
||||
.await?
|
||||
.json::<ModelScopeRevision>()
|
||||
.await?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ModelScopeRevision {
|
||||
data: ModelScopeRevisionData,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ModelScopeRevisionData {
|
||||
files: Vec<ModelScopeRevisionFile>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ModelScopeRevisionFile {
|
||||
path: String,
|
||||
sha256: String,
|
||||
}
|
||||
Loading…
Reference in New Issue