feat: simplify download management, model file should be able to indi… (#690)
* feat: simplify download management, model file should be able to individually introduced * fix typo * update local model support * update spec back * update spec * update * updaterelease-notes-05
parent
0ed4289958
commit
0e4a2d2a12
|
|
@ -1151,6 +1151,12 @@ version = "0.3.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
|
||||
|
||||
[[package]]
|
||||
name = "hex"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "htmlescape"
|
||||
version = "0.3.1"
|
||||
|
|
@ -2646,6 +2652,19 @@ dependencies = [
|
|||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha256"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7895c8ae88588ccead14ff438b939b0c569cd619116f14b4d13fdff7b8333386"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"hex",
|
||||
"sha2",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.4"
|
||||
|
|
@ -2855,6 +2874,7 @@ dependencies = [
|
|||
name = "tabby-common"
|
||||
version = "0.5.0-dev"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"chrono",
|
||||
"filenamify",
|
||||
"lazy_static",
|
||||
|
|
@ -2880,6 +2900,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"serdeconv",
|
||||
"sha256",
|
||||
"tabby-common",
|
||||
"tokio-retry",
|
||||
"tracing",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
# Tabby Model Specification (Unstable)
|
||||
|
||||
> [!WARNING] **Since v0.5.0** This document is intended exclusively for local models. For remote models, we rely on the `tabby-registry` repository within each organization or user. You can refer to https://github.com/TabbyML/registry-tabby/blob/main/models.json for an example.
|
||||
|
||||
Tabby organizes the model within a directory. This document provides an explanation of the necessary contents for supporting model serving. An example model directory can be found at https://huggingface.co/TabbyML/StarCoder-1B
|
||||
|
||||
The minimal Tabby model directory should include the following contents:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ reqwest = { workspace = true, features = [ "json" ] }
|
|||
tokio = { workspace = true, features = ["rt", "macros"] }
|
||||
uuid = { version = "1.4.1", features = ["v4"] }
|
||||
tantivy.workspace = true
|
||||
anyhow.workspace = true
|
||||
|
||||
[features]
|
||||
testutils = []
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ pub mod events;
|
|||
pub mod index;
|
||||
pub mod languages;
|
||||
pub mod path;
|
||||
pub mod registry;
|
||||
pub mod usage;
|
||||
|
||||
use std::{
|
||||
|
|
|
|||
|
|
@ -51,38 +51,4 @@ pub fn events_dir() -> PathBuf {
|
|||
tabby_root().join("events")
|
||||
}
|
||||
|
||||
pub struct ModelDir(PathBuf);
|
||||
|
||||
impl ModelDir {
|
||||
pub fn new(model: &str) -> Self {
|
||||
Self(models_dir().join(model))
|
||||
}
|
||||
|
||||
pub fn from(path: &str) -> Self {
|
||||
Self(PathBuf::from(path))
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &PathBuf {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub fn path_string(&self, name: &str) -> String {
|
||||
self.0.join(name).display().to_string()
|
||||
}
|
||||
|
||||
pub fn cache_info_file(&self) -> String {
|
||||
self.path_string(".cache_info.json")
|
||||
}
|
||||
|
||||
pub fn metadata_file(&self) -> String {
|
||||
self.path_string("tabby.json")
|
||||
}
|
||||
|
||||
pub fn ggml_q8_0_file(&self) -> String {
|
||||
self.path_string("ggml/q8_0.gguf")
|
||||
}
|
||||
|
||||
pub fn ggml_q8_0_v2_file(&self) -> String {
|
||||
self.path_string("ggml/q8_0.v2.gguf")
|
||||
}
|
||||
}
|
||||
mod registry {}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,85 @@
|
|||
use std::{fs, path::PathBuf};
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::path::models_dir;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_template: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_template: Option<String>,
|
||||
pub urls: Vec<String>,
|
||||
pub sha256: String,
|
||||
}
|
||||
|
||||
fn models_json_file(registry: &str) -> PathBuf {
|
||||
models_dir().join(registry).join("models.json")
|
||||
}
|
||||
|
||||
async fn load_remote_registry(registry: &str) -> Result<Vec<ModelInfo>> {
|
||||
let value = reqwest::get(format!(
|
||||
"https://raw.githubusercontent.com/{}/registry-tabby/main/models.json",
|
||||
registry
|
||||
))
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
fs::create_dir_all(models_dir().join(registry))?;
|
||||
serdeconv::to_json_file(&value, models_json_file(registry))?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn load_local_registry(registry: &str) -> Result<Vec<ModelInfo>> {
|
||||
Ok(serdeconv::from_json_file(models_json_file(registry))?)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ModelRegistry {
|
||||
pub name: String,
|
||||
pub models: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
impl ModelRegistry {
|
||||
pub async fn new(registry: &str) -> Self {
|
||||
Self {
|
||||
name: registry.to_owned(),
|
||||
models: load_remote_registry(registry).await.unwrap_or_else(|err| {
|
||||
load_local_registry(registry).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Failed to fetch model organization <{}>: {:?}",
|
||||
registry, err
|
||||
)
|
||||
})
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_model_path(&self, name: &str) -> PathBuf {
|
||||
models_dir()
|
||||
.join(&self.name)
|
||||
.join(name)
|
||||
.join(GGML_MODEL_RELATIVE_PATH)
|
||||
}
|
||||
|
||||
pub fn get_model_info(&self, name: &str) -> &ModelInfo {
|
||||
self.models
|
||||
.iter()
|
||||
.find(|x| x.name == name)
|
||||
.unwrap_or_else(|| panic!("Invalid model_id <{}/{}>", self.name, name))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_model_id(model_id: &str) -> (&str, &str) {
|
||||
let parts: Vec<_> = model_id.split('/').collect();
|
||||
if parts.len() != 2 {
|
||||
panic!("Invalid model id {}", model_id);
|
||||
}
|
||||
|
||||
(parts[0], parts[1])
|
||||
}
|
||||
|
||||
pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf";
|
||||
|
|
@ -17,3 +17,4 @@ urlencoding = "2.1.3"
|
|||
serde_json = { workspace = true }
|
||||
cached = { version = "0.46.0", features = ["async", "proc_macro"] }
|
||||
async-trait = { workspace = true }
|
||||
sha256 = "1.4.0"
|
||||
|
|
|
|||
|
|
@ -1,46 +0,0 @@
|
|||
use std::{collections::HashMap, fs, path::Path};
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CacheInfo {
|
||||
etags: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl CacheInfo {
|
||||
pub async fn from(model_id: &str) -> CacheInfo {
|
||||
if let Some(cache_info) = Self::from_local(model_id) {
|
||||
cache_info
|
||||
} else {
|
||||
CacheInfo {
|
||||
etags: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_local(model_id: &str) -> Option<CacheInfo> {
|
||||
let cache_info_file = ModelDir::new(model_id).cache_info_file();
|
||||
if fs::metadata(&cache_info_file).is_ok() {
|
||||
serdeconv::from_json_file(cache_info_file).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
|
||||
self.etags.get(path).map(|x| x.as_str())
|
||||
}
|
||||
|
||||
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||
}
|
||||
|
||||
pub fn save(&self, model_id: &str) -> Result<()> {
|
||||
let cache_info_file = ModelDir::new(model_id).cache_info_file();
|
||||
let cache_info_file_path = Path::new(&cache_info_file);
|
||||
serdeconv::to_json_file(self, cache_info_file_path)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,126 +1,65 @@
|
|||
mod cache_info;
|
||||
mod registry;
|
||||
|
||||
use std::{cmp, fs, io::Write, path::Path};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use cache_info::CacheInfo;
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use registry::{create_registry, Registry};
|
||||
use tabby_common::path::ModelDir;
|
||||
use tabby_common::registry::{parse_model_id, ModelRegistry};
|
||||
use tokio_retry::{
|
||||
strategy::{jitter, ExponentialBackoff},
|
||||
Retry,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub struct Downloader {
|
||||
model_id: String,
|
||||
async fn download_model_impl(
|
||||
registry: &ModelRegistry,
|
||||
name: &str,
|
||||
prefer_local_file: bool,
|
||||
registry: Box<dyn Registry>,
|
||||
}
|
||||
) -> Result<()> {
|
||||
let model_info = registry.get_model_info(name);
|
||||
let model_path = registry.get_model_path(name);
|
||||
if model_path.exists() {
|
||||
if !prefer_local_file {
|
||||
info!("Checking model integrity..");
|
||||
let checksum = sha256::try_digest(&model_path).unwrap();
|
||||
if checksum == model_info.sha256 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
impl Downloader {
|
||||
pub fn new(model_id: &str, prefer_local_file: bool) -> Self {
|
||||
Self {
|
||||
model_id: model_id.to_owned(),
|
||||
prefer_local_file,
|
||||
registry: create_registry(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn download_ggml_files(&self) -> Result<()> {
|
||||
let files = vec![("tabby.json", true), ("ggml/q8_0.v2.gguf", true)];
|
||||
self.download_files(&files).await
|
||||
}
|
||||
|
||||
async fn download_files(&self, files: &[(&str, bool)]) -> Result<()> {
|
||||
// Local path, no need for downloading.
|
||||
if fs::metadata(&self.model_id).is_ok() {
|
||||
warn!(
|
||||
"Checksum doesn't match for <{}/{}>, re-downloading...",
|
||||
registry.name, name
|
||||
);
|
||||
fs::remove_file(&model_path)?;
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut cache_info = CacheInfo::from(&self.model_id).await;
|
||||
for (path, required) in files {
|
||||
download_model_file(
|
||||
self.registry.as_ref(),
|
||||
&mut cache_info,
|
||||
&self.model_id,
|
||||
path,
|
||||
self.prefer_local_file,
|
||||
*required,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_model_file(
|
||||
registry: &dyn Registry,
|
||||
cache_info: &mut CacheInfo,
|
||||
model_id: &str,
|
||||
path: &str,
|
||||
prefer_local_file: bool,
|
||||
required: bool,
|
||||
) -> Result<()> {
|
||||
// Create url.
|
||||
let url = registry.build_url(model_id, path);
|
||||
|
||||
// Create destination path.
|
||||
let filepath = ModelDir::new(model_id).path_string(path);
|
||||
|
||||
// Get cache key.
|
||||
let local_cache_key = cache_info.local_cache_key(path);
|
||||
|
||||
// Check local file ready.
|
||||
let local_cache_key = local_cache_key
|
||||
// local cache key is only valid if == 404 or local file exists.
|
||||
// FIXME(meng): use sha256 to validate file is ready.
|
||||
.filter(|&local_cache_key| local_cache_key == "404" || fs::metadata(&filepath).is_ok());
|
||||
let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
|
||||
let Some(model_url) = model_info.urls.iter().find(|x| x.contains(®istry)) else {
|
||||
return Err(anyhow!(
|
||||
"Invalid mirror <{}> for model urls: {:?}",
|
||||
registry,
|
||||
model_info.urls
|
||||
));
|
||||
};
|
||||
|
||||
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
|
||||
let download_job = Retry::spawn(strategy, || {
|
||||
download_file(registry, &url, &filepath, local_cache_key, !required)
|
||||
});
|
||||
if let Ok(etag) = download_job.await {
|
||||
cache_info.set_local_cache_key(path, &etag).await;
|
||||
} else if prefer_local_file && local_cache_key.is_some() {
|
||||
// Do nothing.
|
||||
} else {
|
||||
return Err(anyhow!("Failed to fetch url {}", url));
|
||||
}
|
||||
|
||||
cache_info.save(model_id)?;
|
||||
let download_job = Retry::spawn(strategy, || download_file(model_url, model_path.as_path()));
|
||||
download_job.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download_file(
|
||||
registry: &dyn Registry,
|
||||
url: &str,
|
||||
path: &str,
|
||||
local_cache_key: Option<&str>,
|
||||
is_optional: bool,
|
||||
) -> Result<String> {
|
||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||
async fn download_file(url: &str, path: &Path) -> Result<()> {
|
||||
fs::create_dir_all(path.parent().unwrap())?;
|
||||
|
||||
// Reqwest setup
|
||||
let res = reqwest::get(url).await?;
|
||||
|
||||
if is_optional && res.status() == 404 {
|
||||
// Cache 404 for optional file.
|
||||
return Ok("404".to_owned());
|
||||
}
|
||||
|
||||
if !res.status().is_success() {
|
||||
return Err(anyhow!(format!("Invalid url: {}", url)));
|
||||
}
|
||||
|
||||
let remote_cache_key = registry.build_cache_key(url).await?;
|
||||
if local_cache_key == Some(remote_cache_key.as_str()) {
|
||||
return Ok(remote_cache_key);
|
||||
}
|
||||
|
||||
let total_size = res
|
||||
.content_length()
|
||||
.ok_or(anyhow!("No content length in headers"))?;
|
||||
|
|
@ -130,7 +69,7 @@ async fn download_file(
|
|||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", path));
|
||||
pb.set_message(format!("Downloading {}", path.display()));
|
||||
|
||||
// download chunks
|
||||
let mut file = fs::File::create(path)?;
|
||||
|
|
@ -145,6 +84,17 @@ async fn download_file(
|
|||
pb.set_position(new);
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!("Downloaded {}", path));
|
||||
Ok(remote_cache_key)
|
||||
pb.finish_with_message(format!("Downloaded {}", path.display()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn download_model(model_id: &str, prefer_local_file: bool) {
|
||||
let (registry, name) = parse_model_id(model_id);
|
||||
|
||||
let registry = ModelRegistry::new(registry).await;
|
||||
|
||||
let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
|
||||
download_model_impl(®istry, name, prefer_local_file)
|
||||
.await
|
||||
.unwrap_or_else(handler)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
use clap::Args;
|
||||
use tabby_download::Downloader;
|
||||
use tabby_download::download_model;
|
||||
use tracing::info;
|
||||
|
||||
use crate::fatal;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct DownloadArgs {
|
||||
/// model id to fetch.
|
||||
|
|
@ -16,12 +14,6 @@ pub struct DownloadArgs {
|
|||
}
|
||||
|
||||
pub async fn main(args: &DownloadArgs) {
|
||||
let downloader = Downloader::new(&args.model, args.prefer_local_file);
|
||||
|
||||
downloader
|
||||
.download_ggml_files()
|
||||
.await
|
||||
.unwrap_or_else(|err| fatal!("Failed to fetch model '{}' due to '{}'", args.model, err));
|
||||
|
||||
download_model(&args.model, args.prefer_local_file).await;
|
||||
info!("model '{}' is ready", args.model);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,28 +1,39 @@
|
|||
use std::path::Path;
|
||||
use std::{fs, path::PathBuf};
|
||||
|
||||
use serde::Deserialize;
|
||||
use tabby_common::path::ModelDir;
|
||||
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
||||
use tabby_inference::TextGeneration;
|
||||
|
||||
use crate::fatal;
|
||||
|
||||
pub fn create_engine(
|
||||
model: &str,
|
||||
pub async fn create_engine(
|
||||
model_id: &str,
|
||||
args: &crate::serve::ServeArgs,
|
||||
) -> (Box<dyn TextGeneration>, EngineInfo) {
|
||||
if args.device != super::Device::ExperimentalHttp {
|
||||
let model_dir = get_model_dir(model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
let engine = create_ggml_engine(&args.device, &model_dir);
|
||||
(
|
||||
engine,
|
||||
EngineInfo {
|
||||
prompt_template: metadata.prompt_template,
|
||||
chat_template: metadata.chat_template,
|
||||
},
|
||||
)
|
||||
if fs::metadata(model_id).is_ok() {
|
||||
let path = PathBuf::from(model_id);
|
||||
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
||||
let engine =
|
||||
create_ggml_engine(&args.device, model_path.display().to_string().as_str());
|
||||
let engine_info = EngineInfo::read(path.join("tabby.json"));
|
||||
(engine, engine_info)
|
||||
} else {
|
||||
let (registry, name) = parse_model_id(model_id);
|
||||
let registry = ModelRegistry::new(registry).await;
|
||||
let model_path = registry.get_model_path(name).display().to_string();
|
||||
let model_info = registry.get_model_info(name);
|
||||
let engine = create_ggml_engine(&args.device, &model_path);
|
||||
(
|
||||
engine,
|
||||
EngineInfo {
|
||||
prompt_template: model_info.prompt_template.clone(),
|
||||
chat_template: model_info.chat_template.clone(),
|
||||
},
|
||||
)
|
||||
}
|
||||
} else {
|
||||
let (engine, prompt_template) = http_api_bindings::create(model);
|
||||
let (engine, prompt_template) = http_api_bindings::create(model_id);
|
||||
(
|
||||
engine,
|
||||
EngineInfo {
|
||||
|
|
@ -33,38 +44,25 @@ pub fn create_engine(
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EngineInfo {
|
||||
pub prompt_template: Option<String>,
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
|
||||
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
|
||||
impl EngineInfo {
|
||||
fn read(filepath: PathBuf) -> EngineInfo {
|
||||
serdeconv::from_json_file(&filepath)
|
||||
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", filepath.display()))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
|
||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||
.model_path(model_dir.ggml_q8_0_v2_file())
|
||||
.model_path(model_path.to_owned())
|
||||
.use_gpu(device.ggml_use_gpu())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
Box::new(llama_cpp_bindings::LlamaTextGeneration::create(options))
|
||||
}
|
||||
|
||||
fn get_model_dir(model: &str) -> ModelDir {
|
||||
if Path::new(model).exists() {
|
||||
ModelDir::from(model)
|
||||
} else {
|
||||
ModelDir::new(model)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Metadata {
|
||||
#[allow(dead_code)]
|
||||
auto_model: String,
|
||||
prompt_template: Option<String>,
|
||||
chat_template: Option<String>,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||
serdeconv::from_json_file(model_dir.metadata_file())
|
||||
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ mod search;
|
|||
mod ui;
|
||||
|
||||
use std::{
|
||||
fs,
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
|
|
@ -16,7 +17,7 @@ use axum::{routing, Router, Server};
|
|||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use clap::Args;
|
||||
use tabby_common::{config::Config, usage};
|
||||
use tabby_download::Downloader;
|
||||
use tabby_download::download_model;
|
||||
use tokio::time::sleep;
|
||||
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
|
||||
use tracing::{info, warn};
|
||||
|
|
@ -129,9 +130,13 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
valid_args(args);
|
||||
|
||||
if args.device != Device::ExperimentalHttp {
|
||||
download_model(&args.model).await;
|
||||
if let Some(chat_model) = &args.chat_model {
|
||||
download_model(chat_model).await;
|
||||
if fs::metadata(&args.model).is_ok() {
|
||||
info!("Loading model from local path {}", &args.model);
|
||||
} else {
|
||||
download_model(&args.model, true).await;
|
||||
if let Some(chat_model) = &args.chat_model {
|
||||
download_model(chat_model, true).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("HTTP device is unstable and does not comply with semver expectations.")
|
||||
|
|
@ -144,7 +149,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
|
||||
let app = Router::new()
|
||||
.route("/", routing::get(ui::handler))
|
||||
.merge(api_router(args, config))
|
||||
.merge(api_router(args, config).await)
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
|
||||
.fallback(ui::handler);
|
||||
|
||||
|
|
@ -165,7 +170,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||
}
|
||||
|
||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
let index_server = Arc::new(IndexServer::new());
|
||||
let completion_state = {
|
||||
let (
|
||||
|
|
@ -173,7 +178,7 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
EngineInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = create_engine(&args.model, args);
|
||||
) = create_engine(&args.model, args).await;
|
||||
let engine = Arc::new(engine);
|
||||
let state = completions::CompletionState::new(
|
||||
engine.clone(),
|
||||
|
|
@ -184,7 +189,7 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
};
|
||||
|
||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args);
|
||||
let (engine, EngineInfo { chat_template, .. }) = create_engine(chat_model, args).await;
|
||||
let Some(chat_template) = chat_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
|
|
@ -262,13 +267,6 @@ fn start_heartbeat(args: &ServeArgs) {
|
|||
});
|
||||
}
|
||||
|
||||
async fn download_model(model: &str) {
|
||||
let downloader = Downloader::new(model, /* prefer_local_file= */ true);
|
||||
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
|
||||
let download_result = downloader.download_ggml_files().await;
|
||||
download_result.unwrap_or_else(handler);
|
||||
}
|
||||
|
||||
trait OpenApiOverride {
|
||||
fn override_doc(&mut self, args: &ServeArgs);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue