feat: simplify download management, model file should be able to indi… (#690)

* feat: simplify download management, model file should be able to individually introduced

* fix typo

* update local model support

* update spec back

* update spec

* update

* update
release-notes-05
Meng Zhang 2023-11-02 16:01:04 -07:00 committed by GitHub
parent 0ed4289958
commit 0e4a2d2a12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 210 additions and 241 deletions

21
Cargo.lock generated
View File

@ -1151,6 +1151,12 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]] [[package]]
name = "htmlescape" name = "htmlescape"
version = "0.3.1" version = "0.3.1"
@ -2646,6 +2652,19 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sha256"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7895c8ae88588ccead14ff438b939b0c569cd619116f14b4d13fdff7b8333386"
dependencies = [
"async-trait",
"bytes",
"hex",
"sha2",
"tokio",
]
[[package]] [[package]]
name = "sharded-slab" name = "sharded-slab"
version = "0.1.4" version = "0.1.4"
@ -2855,6 +2874,7 @@ dependencies = [
name = "tabby-common" name = "tabby-common"
version = "0.5.0-dev" version = "0.5.0-dev"
dependencies = [ dependencies = [
"anyhow",
"chrono", "chrono",
"filenamify", "filenamify",
"lazy_static", "lazy_static",
@ -2880,6 +2900,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serdeconv", "serdeconv",
"sha256",
"tabby-common", "tabby-common",
"tokio-retry", "tokio-retry",
"tracing", "tracing",

View File

@ -1,5 +1,7 @@
# Tabby Model Specification (Unstable) # Tabby Model Specification (Unstable)
> [!WARNING] **Since v0.5.0** This document is intended exclusively for local models. For remote models, we rely on the `tabby-registry` repository within each organization or user. You can refer to https://github.com/TabbyML/registry-tabby/blob/main/models.json for an example.
Tabby organizes the model within a directory. This document provides an explanation of the necessary contents for supporting model serving. An example model directory can be found at https://huggingface.co/TabbyML/StarCoder-1B Tabby organizes the model within a directory. This document provides an explanation of the necessary contents for supporting model serving. An example model directory can be found at https://huggingface.co/TabbyML/StarCoder-1B
The minimal Tabby model directory should include the following contents: The minimal Tabby model directory should include the following contents:

View File

@ -14,6 +14,7 @@ reqwest = { workspace = true, features = [ "json" ] }
tokio = { workspace = true, features = ["rt", "macros"] } tokio = { workspace = true, features = ["rt", "macros"] }
uuid = { version = "1.4.1", features = ["v4"] } uuid = { version = "1.4.1", features = ["v4"] }
tantivy.workspace = true tantivy.workspace = true
anyhow.workspace = true
[features] [features]
testutils = [] testutils = []

View File

@ -3,6 +3,7 @@ pub mod events;
pub mod index; pub mod index;
pub mod languages; pub mod languages;
pub mod path; pub mod path;
pub mod registry;
pub mod usage; pub mod usage;
use std::{ use std::{

View File

@ -51,38 +51,4 @@ pub fn events_dir() -> PathBuf {
tabby_root().join("events") tabby_root().join("events")
} }
pub struct ModelDir(PathBuf); mod registry {}
impl ModelDir {
pub fn new(model: &str) -> Self {
Self(models_dir().join(model))
}
pub fn from(path: &str) -> Self {
Self(PathBuf::from(path))
}
pub fn path(&self) -> &PathBuf {
&self.0
}
pub fn path_string(&self, name: &str) -> String {
self.0.join(name).display().to_string()
}
pub fn cache_info_file(&self) -> String {
self.path_string(".cache_info.json")
}
pub fn metadata_file(&self) -> String {
self.path_string("tabby.json")
}
pub fn ggml_q8_0_file(&self) -> String {
self.path_string("ggml/q8_0.gguf")
}
pub fn ggml_q8_0_v2_file(&self) -> String {
self.path_string("ggml/q8_0.v2.gguf")
}
}

View File

@ -0,0 +1,85 @@
use std::{fs, path::PathBuf};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::path::models_dir;
#[derive(Serialize, Deserialize)]
pub struct ModelInfo {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_template: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_template: Option<String>,
pub urls: Vec<String>,
pub sha256: String,
}
fn models_json_file(registry: &str) -> PathBuf {
models_dir().join(registry).join("models.json")
}
async fn load_remote_registry(registry: &str) -> Result<Vec<ModelInfo>> {
let value = reqwest::get(format!(
"https://raw.githubusercontent.com/{}/registry-tabby/main/models.json",
registry
))
.await?
.json()
.await?;
fs::create_dir_all(models_dir().join(registry))?;
serdeconv::to_json_file(&value, models_json_file(registry))?;
Ok(value)
}
fn load_local_registry(registry: &str) -> Result<Vec<ModelInfo>> {
Ok(serdeconv::from_json_file(models_json_file(registry))?)
}
#[derive(Default)]
pub struct ModelRegistry {
pub name: String,
pub models: Vec<ModelInfo>,
}
impl ModelRegistry {
pub async fn new(registry: &str) -> Self {
Self {
name: registry.to_owned(),
models: load_remote_registry(registry).await.unwrap_or_else(|err| {
load_local_registry(registry).unwrap_or_else(|_| {
panic!(
"Failed to fetch model organization <{}>: {:?}",
registry, err
)
})
}),
}
}
pub fn get_model_path(&self, name: &str) -> PathBuf {
models_dir()
.join(&self.name)
.join(name)
.join(GGML_MODEL_RELATIVE_PATH)
}
pub fn get_model_info(&self, name: &str) -> &ModelInfo {
self.models
.iter()
.find(|x| x.name == name)
.unwrap_or_else(|| panic!("Invalid model_id <{}/{}>", self.name, name))
}
}
pub fn parse_model_id(model_id: &str) -> (&str, &str) {
let parts: Vec<_> = model_id.split('/').collect();
if parts.len() != 2 {
panic!("Invalid model id {}", model_id);
}
(parts[0], parts[1])
}
pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";

View File

@ -17,3 +17,4 @@ urlencoding = "2.1.3"
serde_json = { workspace = true } serde_json = { workspace = true }
cached = { version = "0.46.0", features = ["async", "proc_macro"] } cached = { version = "0.46.0", features = ["async", "proc_macro"] }
async-trait = { workspace = true } async-trait = { workspace = true }
sha256 = "1.4.0"

View File

@ -1,46 +0,0 @@
use std::{collections::HashMap, fs, path::Path};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tabby_common::path::ModelDir;
#[derive(Serialize, Deserialize)]
pub struct CacheInfo {
etags: HashMap<String, String>,
}
impl CacheInfo {
pub async fn from(model_id: &str) -> CacheInfo {
if let Some(cache_info) = Self::from_local(model_id) {
cache_info
} else {
CacheInfo {
etags: HashMap::new(),
}
}
}
fn from_local(model_id: &str) -> Option<CacheInfo> {
let cache_info_file = ModelDir::new(model_id).cache_info_file();
if fs::metadata(&cache_info_file).is_ok() {
serdeconv::from_json_file(cache_info_file).ok()
} else {
None
}
}
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
self.etags.get(path).map(|x| x.as_str())
}
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
self.etags.insert(path.to_string(), cache_key.to_string());
}
pub fn save(&self, model_id: &str) -> Result<()> {
let cache_info_file = ModelDir::new(model_id).cache_info_file();
let cache_info_file_path = Path::new(&cache_info_file);
serdeconv::to_json_file(self, cache_info_file_path)?;
Ok(())
}
}

View File

@ -1,126 +1,65 @@
mod cache_info;
mod registry;
use std::{cmp, fs, io::Write, path::Path}; use std::{cmp, fs, io::Write, path::Path};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use cache_info::CacheInfo;
use futures_util::StreamExt; use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle}; use indicatif::{ProgressBar, ProgressStyle};
use registry::{create_registry, Registry}; use tabby_common::registry::{parse_model_id, ModelRegistry};
use tabby_common::path::ModelDir;
use tokio_retry::{ use tokio_retry::{
strategy::{jitter, ExponentialBackoff}, strategy::{jitter, ExponentialBackoff},
Retry, Retry,
}; };
use tracing::{info, warn};
pub struct Downloader { async fn download_model_impl(
model_id: String, registry: &ModelRegistry,
name: &str,
prefer_local_file: bool, prefer_local_file: bool,
registry: Box<dyn Registry>, ) -> Result<()> {
} let model_info = registry.get_model_info(name);
let model_path = registry.get_model_path(name);
if model_path.exists() {
if !prefer_local_file {
info!("Checking model integrity..");
let checksum = sha256::try_digest(&model_path).unwrap();
if checksum == model_info.sha256 {
return Ok(());
}
impl Downloader { warn!(
pub fn new(model_id: &str, prefer_local_file: bool) -> Self { "Checksum doesn't match for <{}/{}>, re-downloading...",
Self { registry.name, name
model_id: model_id.to_owned(), );
prefer_local_file, fs::remove_file(&model_path)?;
registry: create_registry(), } else {
}
}
pub async fn download_ggml_files(&self) -> Result<()> {
let files = vec![("tabby.json", true), ("ggml/q8_0.v2.gguf", true)];
self.download_files(&files).await
}
async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
// Local path, no need for downloading.
if fs::metadata(&self.model_id).is_ok() {
return Ok(()); return Ok(());
} }
let mut cache_info = CacheInfo::from(&self.model_id).await;
for (path, required) in files {
download_model_file(
self.registry.as_ref(),
&mut cache_info,
&self.model_id,
path,
self.prefer_local_file,
*required,
)
.await?;
}
Ok(())
} }
}
async fn download_model_file( let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
registry: &dyn Registry, let Some(model_url) = model_info.urls.iter().find(|x| x.contains(&registry)) else {
cache_info: &mut CacheInfo, return Err(anyhow!(
model_id: &str, "Invalid mirror <{}> for model urls: {:?}",
path: &str, registry,
prefer_local_file: bool, model_info.urls
required: bool, ));
) -> Result<()> { };
// Create url.
let url = registry.build_url(model_id, path);
// Create destination path.
let filepath = ModelDir::new(model_id).path_string(path);
// Get cache key.
let local_cache_key = cache_info.local_cache_key(path);
// Check local file ready.
let local_cache_key = local_cache_key
// local cache key is only valid if == 404 or local file exists.
// FIXME(meng): use sha256 to validate file is ready.
.filter(|&local_cache_key| local_cache_key == "404" || fs::metadata(&filepath).is_ok());
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2); let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
let download_job = Retry::spawn(strategy, || { let download_job = Retry::spawn(strategy, || download_file(model_url, model_path.as_path()));
download_file(registry, &url, &filepath, local_cache_key, !required) download_job.await?;
});
if let Ok(etag) = download_job.await {
cache_info.set_local_cache_key(path, &etag).await;
} else if prefer_local_file && local_cache_key.is_some() {
// Do nothing.
} else {
return Err(anyhow!("Failed to fetch url {}", url));
}
cache_info.save(model_id)?;
Ok(()) Ok(())
} }
async fn download_file( async fn download_file(url: &str, path: &Path) -> Result<()> {
registry: &dyn Registry, fs::create_dir_all(path.parent().unwrap())?;
url: &str,
path: &str,
local_cache_key: Option<&str>,
is_optional: bool,
) -> Result<String> {
fs::create_dir_all(Path::new(path).parent().unwrap())?;
// Reqwest setup // Reqwest setup
let res = reqwest::get(url).await?; let res = reqwest::get(url).await?;
if is_optional && res.status() == 404 {
// Cache 404 for optional file.
return Ok("404".to_owned());
}
if !res.status().is_success() { if !res.status().is_success() {
return Err(anyhow!(format!("Invalid url: {}", url))); return Err(anyhow!(format!("Invalid url: {}", url)));
} }
let remote_cache_key = registry.build_cache_key(url).await?;
if local_cache_key == Some(remote_cache_key.as_str()) {
return Ok(remote_cache_key);
}
let total_size = res let total_size = res
.content_length() .content_length()
.ok_or(anyhow!("No content length in headers"))?; .ok_or(anyhow!("No content length in headers"))?;
@ -130,7 +69,7 @@ async fn download_file(
pb.set_style(ProgressStyle::default_bar() pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")? .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-")); .progress_chars("#>-"));
pb.set_message(format!("Downloading {}", path)); pb.set_message(format!("Downloading {}", path.display()));
// download chunks // download chunks
let mut file = fs::File::create(path)?; let mut file = fs::File::create(path)?;
@ -145,6 +84,17 @@ async fn download_file(
pb.set_position(new); pb.set_position(new);
} }
pb.finish_with_message(format!("Downloaded {}", path)); pb.finish_with_message(format!("Downloaded {}", path.display()));
Ok(remote_cache_key) Ok(())
}
pub async fn download_model(model_id: &str, prefer_local_file: bool) {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
download_model_impl(&registry, name, prefer_local_file)
.await
.unwrap_or_else(handler)
} }

View File

@ -1,9 +1,7 @@
use clap::Args; use clap::Args;
use tabby_download::Downloader; use tabby_download::download_model;
use tracing::info; use tracing::info;
use crate::fatal;
#[derive(Args)] #[derive(Args)]
pub struct DownloadArgs { pub struct DownloadArgs {
/// model id to fetch. /// model id to fetch.
@ -16,12 +14,6 @@ pub struct DownloadArgs {
} }
pub async fn main(args: &DownloadArgs) { pub async fn main(args: &DownloadArgs) {
let downloader = Downloader::new(&args.model, args.prefer_local_file); download_model(&args.model, args.prefer_local_file).await;
downloader
.download_ggml_files()
.await
.unwrap_or_else(|err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err));
info!("model '{}' is ready", args.model); info!("model '{}' is ready", args.model);
} }

View File

@ -1,28 +1,39 @@
use std::path::Path; use std::{fs, path::PathBuf};
use serde::Deserialize; use serde::Deserialize;
use tabby_common::path::ModelDir; use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
use tabby_inference::TextGeneration; use tabby_inference::TextGeneration;
use crate::fatal; use crate::fatal;
pub fn create_engine( pub async fn create_engine(
model: &str, model_id: &str,
args: &crate::serve::ServeArgs, args: &crate::serve::ServeArgs,
) -> (Box<dyn TextGeneration>, EngineInfo) { ) -> (Box<dyn TextGeneration>, EngineInfo) {
if args.device != super::Device::ExperimentalHttp { if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(model); if fs::metadata(model_id).is_ok() {
let metadata = read_metadata(&model_dir); let path = PathBuf::from(model_id);
let engine = create_ggml_engine(&args.device, &model_dir); let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
( let engine =
engine, create_ggml_engine(&args.device, model_path.display().to_string().as_str());
EngineInfo { let engine_info = EngineInfo::read(path.join("tabby.json"));
prompt_template: metadata.prompt_template, (engine, engine_info)
chat_template: metadata.chat_template, } else {
}, let (registry, name) = parse_model_id(model_id);
) let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path);
(
engine,
EngineInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
},
)
}
} else { } else {
let (engine, prompt_template) = http_api_bindings::create(model); let (engine, prompt_template) = http_api_bindings::create(model_id);
( (
engine, engine,
EngineInfo { EngineInfo {
@ -33,38 +44,25 @@ pub fn create_engine(
} }
} }
#[derive(Deserialize)]
pub struct EngineInfo { pub struct EngineInfo {
pub prompt_template: Option<String>, pub prompt_template: Option<String>,
pub chat_template: Option<String>, pub chat_template: Option<String>,
} }
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> { impl EngineInfo {
fn read(filepath: PathBuf) -> EngineInfo {
serdeconv::from_json_file(&filepath)
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
}
}
fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_dir.ggml_q8_0_v2_file()) .model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu()) .use_gpu(device.ggml_use_gpu())
.build() .build()
.unwrap(); .unwrap();
Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options)) Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
} }
fn get_model_dir(model: &str) -> ModelDir {
if Path::new(model).exists() {
ModelDir::from(model)
} else {
ModelDir::new(model)
}
}
#[derive(Deserialize)]
struct Metadata {
#[allow(dead_code)]
auto_model: String,
prompt_template: Option<String>,
chat_template: Option<String>,
}
fn read_metadata(model_dir: &ModelDir) -> Metadata {
serdeconv::from_json_file(model_dir.metadata_file())
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
}

View File

@ -7,6 +7,7 @@ mod search;
mod ui; mod ui;
use std::{ use std::{
fs,
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
sync::Arc, sync::Arc,
time::Duration, time::Duration,
@ -16,7 +17,7 @@ use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args; use clap::Args;
use tabby_common::{config::Config, usage}; use tabby_common::{config::Config, usage};
use tabby_download::Downloader; use tabby_download::download_model;
use tokio::time::sleep; use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::{info, warn}; use tracing::{info, warn};
@ -129,9 +130,13 @@ pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args); valid_args(args);
if args.device != Device::ExperimentalHttp { if args.device != Device::ExperimentalHttp {
download_model(&args.model).await; if fs::metadata(&args.model).is_ok() {
if let Some(chat_model) = &args.chat_model { info!("Loading model from local path {}", &args.model);
download_model(chat_model).await; } else {
download_model(&args.model, true).await;
if let Some(chat_model) = &args.chat_model {
download_model(chat_model, true).await;
}
} }
} else { } else {
warn!("HTTP device is unstable and does not comply with semver expectations.") warn!("HTTP device is unstable and does not comply with semver expectations.")
@ -144,7 +149,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
let app = Router::new() let app = Router::new()
.route("/", routing::get(ui::handler)) .route("/", routing::get(ui::handler))
.merge(api_router(args, config)) .merge(api_router(args, config).await)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.fallback(ui::handler); .fallback(ui::handler);
@ -165,7 +170,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
} }
fn api_router(args: &ServeArgs, config: &Config) -> Router { async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let index_server = Arc::new(IndexServer::new()); let index_server = Arc::new(IndexServer::new());
let completion_state = { let completion_state = {
let ( let (
@ -173,7 +178,7 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
EngineInfo { EngineInfo {
prompt_template, .. prompt_template, ..
}, },
) = create_engine(&args.model, args); ) = create_engine(&args.model, args).await;
let engine = Arc::new(engine); let engine = Arc::new(engine);
let state = completions::CompletionState::new( let state = completions::CompletionState::new(
engine.clone(), engine.clone(),
@ -184,7 +189,7 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
}; };
let chat_state = if let Some(chat_model) = &args.chat_model { let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args); let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
let Some(chat_template) = chat_template else { let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };
@ -262,13 +267,6 @@ fn start_heartbeat(args: &ServeArgs) {
}); });
} }
async fn download_model(model: &str) {
let downloader = Downloader::new(model, /* prefer_local_file= */ true);
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
let download_result = downloader.download_ggml_files().await;
download_result.unwrap_or_else(handler);
}
trait OpenApiOverride { trait OpenApiOverride {
fn override_doc(&mut self, args: &ServeArgs); fn override_doc(&mut self, args: &ServeArgs);
} }