diff --git a/Cargo.lock b/Cargo.lock index 475c904..5338028 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2086,6 +2086,7 @@ dependencies = [ "serde", "serde_json", "strum", + "tabby-common", "tokio", "tower", "tower-http", @@ -2094,6 +2095,10 @@ dependencies = [ "uuid", ] +[[package]] +name = "tabby-common" +version = "0.1.0" + [[package]] name = "tar" version = "0.4.38" diff --git a/Cargo.toml b/Cargo.toml index 1385297..22fae40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/tabby", + "crates/tabby-common", "crates/ctranslate2-bindings", "crates/rust-cxx-cmake-bridge", ] diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml new file mode 100644 index 0000000..e457ec7 --- /dev/null +++ b/crates/tabby-common/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tabby-common" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs new file mode 100644 index 0000000..4da9789 --- /dev/null +++ b/crates/tabby-common/src/lib.rs @@ -0,0 +1 @@ +pub mod path; diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs new file mode 100644 index 0000000..333b182 --- /dev/null +++ b/crates/tabby-common/src/path.rs @@ -0,0 +1,41 @@ +use std::env; +use std::path::PathBuf; + +fn get_root_dir() -> PathBuf { + match env::var("TABBY_ROOT") { + Ok(x) => PathBuf::from(x), + Err(_) => PathBuf::from(env::var("HOME").unwrap()).join(".tabby"), + } +} + +pub struct ModelDir(PathBuf); + +impl ModelDir { + pub fn new(model: &str) -> Self { + Self(get_root_dir().join("models").join(model)) + } + + pub fn from(path: &str) -> Self { + Self(PathBuf::from(path)) + } + + pub fn path(&self) -> &PathBuf { + &self.0 + } + + pub fn path_string(&self, name: &str) -> String { + self.0.join(name).display().to_string() + } + + pub fn metadata_file(&self) -> String { + return self.path_string("metadata.json"); + } + + pub fn tokenizer_file(&self) -> String { + return self.path_string("tokenizer.json"); + } + + pub fn ctranslate2_dir(&self) -> String { + self.path_string("ctranslate2") + } +} diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 7ff8399..97f6ec7 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -26,6 +26,7 @@ reqwest = { version = "0.11.18", features = ["stream"] } error-chain = "0.12.4" indicatif = "0.17.3" futures-util = "0.3.28" +tabby-common = { path = "../tabby-common" } [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs index 7b8ec6f..c458a44 100644 --- a/crates/tabby/src/download.rs +++ b/crates/tabby/src/download.rs @@ -1,12 +1,13 @@ use std::cmp; use std::fs; use std::io::Write; -use std::path::{Path, PathBuf}; +use std::path::Path; use clap::Args; use error_chain::error_chain; use futures_util::StreamExt; use indicatif::{ProgressBar, ProgressStyle}; +use tabby_common::path::ModelDir; #[derive(Args)] pub struct DownloadArgs { @@ -38,18 +39,10 @@ async fn download_model(model_id: &str) -> Result<()> { Ok(()) } -fn get_model_dir(model_id: &str) -> PathBuf { - let home = std::env::var("HOME").unwrap(); - let tabby_root = format!("{}/.tabby", home); - let model_dir = Path::new(&tabby_root).join("models").join(model_id); - model_dir -} - async fn download_metadata(model_id: &str) -> Result<()> { let url = format!("https://huggingface.co/api/models/{}", model_id); - let fname = "metadata.json"; - let filepath = get_model_dir(model_id).join(fname).display().to_string(); - download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await + let filepath = ModelDir::new(model_id).metadata_file(); + download_file(&url, &filepath).await } async fn download_model_file(model_id: &str, fname: &str) -> Result<()> { @@ -57,11 +50,11 @@ async fn download_model_file(model_id: &str, fname: &str) -> Result<()> { let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname); // Create destination path. - let filepath = get_model_dir(model_id).join(fname).display().to_string(); - download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await + let filepath = ModelDir::new(model_id).path_string(fname); + download_file(&url, &filepath).await } -async fn download_file(name: &str, url: &str, path: &str) -> Result<()> { +async fn download_file(url: &str, path: &str) -> Result<()> { fs::create_dir_all(Path::new(path).parent().unwrap())?; // Reqwest setup @@ -78,7 +71,7 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> { pb.set_style(ProgressStyle::default_bar() .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")? .progress_chars("#>-")); - pb.set_message(format!("Downloading {}", &name)); + pb.set_message(format!("Downloading {}", path)); // download chunks let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?; @@ -94,6 +87,6 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> { pb.set_position(new); } - pb.finish_with_message(format!("Downloaded {}", &name)); + pb.finish_with_message(format!("Downloaded {}", path)); return Ok(()); } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 7747f78..746e12e 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -5,6 +5,7 @@ use ctranslate2_bindings::{ use serde::{Deserialize, Serialize}; use std::path::Path; use std::sync::Arc; +use tabby_common::path::ModelDir; use utoipa::ToSchema; mod languages; @@ -72,8 +73,8 @@ impl CompletionState { let device = format!("{}", args.device); let options = TextInferenceEngineCreateOptionsBuilder::default() - .model_path(model_dir.join("ctranslate2").display().to_string()) - .tokenizer_path(model_dir.join("tokenizer.json").display().to_string()) + .model_path(model_dir.ctranslate2_dir()) + .tokenizer_path(model_dir.tokenizer_file()) .device(device) .model_type(metadata.transformers_info.auto_model) .device_indices(args.device_indices.clone()) @@ -85,13 +86,11 @@ impl CompletionState { } } -fn get_model_dir(model: &str) -> std::path::PathBuf { +fn get_model_dir(model: &str) -> ModelDir { if Path::new(model).exists() { - Path::new(model).to_path_buf() + ModelDir::from(model) } else { - let home = std::env::var("HOME").unwrap(); - let tabby_root = format!("{}/.tabby", home); - Path::new(&tabby_root).join("models").join(model) + ModelDir::new(model) } } @@ -115,8 +114,8 @@ struct TransformersInfo { auto_model: String, } -fn read_metadata(model_dir: &std::path::PathBuf) -> Metadata { - let file = std::fs::File::open(model_dir.join("metadata.json")).unwrap(); +fn read_metadata(model_dir: &ModelDir) -> Metadata { + let file = std::fs::File::open(model_dir.metadata_file()).unwrap(); let reader = std::io::BufReader::new(file); serde_json::from_reader(reader).unwrap() }