parent
781e6a7db3
commit
d8cee4adac
|
|
@ -2086,6 +2086,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"tabby-common",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-http",
|
||||
|
|
@ -2094,6 +2095,10 @@ dependencies = [
|
|||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabby-common"
|
||||
version = "0.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.38"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"crates/tabby",
|
||||
"crates/tabby-common",
|
||||
"crates/ctranslate2-bindings",
|
||||
"crates/rust-cxx-cmake-bridge",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
[package]
|
||||
name = "tabby-common"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
|
|
@ -0,0 +1 @@
|
|||
pub mod path;
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn get_root_dir() -> PathBuf {
|
||||
match env::var("TABBY_ROOT") {
|
||||
Ok(x) => PathBuf::from(x),
|
||||
Err(_) => PathBuf::from(env::var("HOME").unwrap()).join(".tabby"),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ModelDir(PathBuf);
|
||||
|
||||
impl ModelDir {
|
||||
pub fn new(model: &str) -> Self {
|
||||
Self(get_root_dir().join("models").join(model))
|
||||
}
|
||||
|
||||
pub fn from(path: &str) -> Self {
|
||||
Self(PathBuf::from(path))
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &PathBuf {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub fn path_string(&self, name: &str) -> String {
|
||||
self.0.join(name).display().to_string()
|
||||
}
|
||||
|
||||
pub fn metadata_file(&self) -> String {
|
||||
return self.path_string("metadata.json");
|
||||
}
|
||||
|
||||
pub fn tokenizer_file(&self) -> String {
|
||||
return self.path_string("tokenizer.json");
|
||||
}
|
||||
|
||||
pub fn ctranslate2_dir(&self) -> String {
|
||||
self.path_string("ctranslate2")
|
||||
}
|
||||
}
|
||||
|
|
@ -26,6 +26,7 @@ reqwest = { version = "0.11.18", features = ["stream"] }
|
|||
error-chain = "0.12.4"
|
||||
indicatif = "0.17.3"
|
||||
futures-util = "0.3.28"
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
use std::cmp;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::Path;
|
||||
|
||||
use clap::Args;
|
||||
use error_chain::error_chain;
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use tabby_common::path::ModelDir;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct DownloadArgs {
|
||||
|
|
@ -38,18 +39,10 @@ async fn download_model(model_id: &str) -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_model_dir(model_id: &str) -> PathBuf {
|
||||
let home = std::env::var("HOME").unwrap();
|
||||
let tabby_root = format!("{}/.tabby", home);
|
||||
let model_dir = Path::new(&tabby_root).join("models").join(model_id);
|
||||
model_dir
|
||||
}
|
||||
|
||||
async fn download_metadata(model_id: &str) -> Result<()> {
|
||||
let url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||
let fname = "metadata.json";
|
||||
let filepath = get_model_dir(model_id).join(fname).display().to_string();
|
||||
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
|
||||
let filepath = ModelDir::new(model_id).metadata_file();
|
||||
download_file(&url, &filepath).await
|
||||
}
|
||||
|
||||
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
||||
|
|
@ -57,11 +50,11 @@ async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
|||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
|
||||
|
||||
// Create destination path.
|
||||
let filepath = get_model_dir(model_id).join(fname).display().to_string();
|
||||
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
|
||||
let filepath = ModelDir::new(model_id).path_string(fname);
|
||||
download_file(&url, &filepath).await
|
||||
}
|
||||
|
||||
async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
||||
async fn download_file(url: &str, path: &str) -> Result<()> {
|
||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||
|
||||
// Reqwest setup
|
||||
|
|
@ -78,7 +71,7 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
|||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", &name));
|
||||
pb.set_message(format!("Downloading {}", path));
|
||||
|
||||
// download chunks
|
||||
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
|
||||
|
|
@ -94,6 +87,6 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
|||
pb.set_position(new);
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!("Downloaded {}", &name));
|
||||
pb.finish_with_message(format!("Downloaded {}", path));
|
||||
return Ok(());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use ctranslate2_bindings::{
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tabby_common::path::ModelDir;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
mod languages;
|
||||
|
|
@ -72,8 +73,8 @@ impl CompletionState {
|
|||
|
||||
let device = format!("{}", args.device);
|
||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||
.model_path(model_dir.join("ctranslate2").display().to_string())
|
||||
.tokenizer_path(model_dir.join("tokenizer.json").display().to_string())
|
||||
.model_path(model_dir.ctranslate2_dir())
|
||||
.tokenizer_path(model_dir.tokenizer_file())
|
||||
.device(device)
|
||||
.model_type(metadata.transformers_info.auto_model)
|
||||
.device_indices(args.device_indices.clone())
|
||||
|
|
@ -85,13 +86,11 @@ impl CompletionState {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_model_dir(model: &str) -> std::path::PathBuf {
|
||||
fn get_model_dir(model: &str) -> ModelDir {
|
||||
if Path::new(model).exists() {
|
||||
Path::new(model).to_path_buf()
|
||||
ModelDir::from(model)
|
||||
} else {
|
||||
let home = std::env::var("HOME").unwrap();
|
||||
let tabby_root = format!("{}/.tabby", home);
|
||||
Path::new(&tabby_root).join("models").join(model)
|
||||
ModelDir::new(model)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -115,8 +114,8 @@ struct TransformersInfo {
|
|||
auto_model: String,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &std::path::PathBuf) -> Metadata {
|
||||
let file = std::fs::File::open(model_dir.join("metadata.json")).unwrap();
|
||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||
let file = std::fs::File::open(model_dir.metadata_file()).unwrap();
|
||||
let reader = std::io::BufReader::new(file);
|
||||
serde_json::from_reader(reader).unwrap()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue