parent
781e6a7db3
commit
d8cee4adac
|
|
@ -2086,6 +2086,7 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"strum",
|
"strum",
|
||||||
|
"tabby-common",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
|
@ -2094,6 +2095,10 @@ dependencies = [
|
||||||
"uuid",
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tabby-common"
|
||||||
|
version = "0.1.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tar"
|
name = "tar"
|
||||||
version = "0.4.38"
|
version = "0.4.38"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"crates/tabby",
|
"crates/tabby",
|
||||||
|
"crates/tabby-common",
|
||||||
"crates/ctranslate2-bindings",
|
"crates/ctranslate2-bindings",
|
||||||
"crates/rust-cxx-cmake-bridge",
|
"crates/rust-cxx-cmake-bridge",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
[package]
|
||||||
|
name = "tabby-common"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
pub mod path;
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
use std::env;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
fn get_root_dir() -> PathBuf {
|
||||||
|
match env::var("TABBY_ROOT") {
|
||||||
|
Ok(x) => PathBuf::from(x),
|
||||||
|
Err(_) => PathBuf::from(env::var("HOME").unwrap()).join(".tabby"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ModelDir(PathBuf);
|
||||||
|
|
||||||
|
impl ModelDir {
|
||||||
|
pub fn new(model: &str) -> Self {
|
||||||
|
Self(get_root_dir().join("models").join(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from(path: &str) -> Self {
|
||||||
|
Self(PathBuf::from(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn path(&self) -> &PathBuf {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn path_string(&self, name: &str) -> String {
|
||||||
|
self.0.join(name).display().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn metadata_file(&self) -> String {
|
||||||
|
return self.path_string("metadata.json");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tokenizer_file(&self) -> String {
|
||||||
|
return self.path_string("tokenizer.json");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ctranslate2_dir(&self) -> String {
|
||||||
|
self.path_string("ctranslate2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -26,6 +26,7 @@ reqwest = { version = "0.11.18", features = ["stream"] }
|
||||||
error-chain = "0.12.4"
|
error-chain = "0.12.4"
|
||||||
indicatif = "0.17.3"
|
indicatif = "0.17.3"
|
||||||
futures-util = "0.3.28"
|
futures-util = "0.3.28"
|
||||||
|
tabby-common = { path = "../tabby-common" }
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
version = "1.3.3"
|
version = "1.3.3"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::Path;
|
||||||
|
|
||||||
use clap::Args;
|
use clap::Args;
|
||||||
use error_chain::error_chain;
|
use error_chain::error_chain;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
use indicatif::{ProgressBar, ProgressStyle};
|
||||||
|
use tabby_common::path::ModelDir;
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
pub struct DownloadArgs {
|
pub struct DownloadArgs {
|
||||||
|
|
@ -38,18 +39,10 @@ async fn download_model(model_id: &str) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_model_dir(model_id: &str) -> PathBuf {
|
|
||||||
let home = std::env::var("HOME").unwrap();
|
|
||||||
let tabby_root = format!("{}/.tabby", home);
|
|
||||||
let model_dir = Path::new(&tabby_root).join("models").join(model_id);
|
|
||||||
model_dir
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn download_metadata(model_id: &str) -> Result<()> {
|
async fn download_metadata(model_id: &str) -> Result<()> {
|
||||||
let url = format!("https://huggingface.co/api/models/{}", model_id);
|
let url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||||
let fname = "metadata.json";
|
let filepath = ModelDir::new(model_id).metadata_file();
|
||||||
let filepath = get_model_dir(model_id).join(fname).display().to_string();
|
download_file(&url, &filepath).await
|
||||||
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
||||||
|
|
@ -57,11 +50,11 @@ async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
|
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
|
||||||
|
|
||||||
// Create destination path.
|
// Create destination path.
|
||||||
let filepath = get_model_dir(model_id).join(fname).display().to_string();
|
let filepath = ModelDir::new(model_id).path_string(fname);
|
||||||
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
|
download_file(&url, &filepath).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
async fn download_file(url: &str, path: &str) -> Result<()> {
|
||||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||||
|
|
||||||
// Reqwest setup
|
// Reqwest setup
|
||||||
|
|
@ -78,7 +71,7 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
||||||
pb.set_style(ProgressStyle::default_bar()
|
pb.set_style(ProgressStyle::default_bar()
|
||||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||||
.progress_chars("#>-"));
|
.progress_chars("#>-"));
|
||||||
pb.set_message(format!("Downloading {}", &name));
|
pb.set_message(format!("Downloading {}", path));
|
||||||
|
|
||||||
// download chunks
|
// download chunks
|
||||||
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
|
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
|
||||||
|
|
@ -94,6 +87,6 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
||||||
pb.set_position(new);
|
pb.set_position(new);
|
||||||
}
|
}
|
||||||
|
|
||||||
pb.finish_with_message(format!("Downloaded {}", &name));
|
pb.finish_with_message(format!("Downloaded {}", path));
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ use ctranslate2_bindings::{
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use tabby_common::path::ModelDir;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
mod languages;
|
mod languages;
|
||||||
|
|
@ -72,8 +73,8 @@ impl CompletionState {
|
||||||
|
|
||||||
let device = format!("{}", args.device);
|
let device = format!("{}", args.device);
|
||||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||||
.model_path(model_dir.join("ctranslate2").display().to_string())
|
.model_path(model_dir.ctranslate2_dir())
|
||||||
.tokenizer_path(model_dir.join("tokenizer.json").display().to_string())
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
.device(device)
|
.device(device)
|
||||||
.model_type(metadata.transformers_info.auto_model)
|
.model_type(metadata.transformers_info.auto_model)
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
|
|
@ -85,13 +86,11 @@ impl CompletionState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_model_dir(model: &str) -> std::path::PathBuf {
|
fn get_model_dir(model: &str) -> ModelDir {
|
||||||
if Path::new(model).exists() {
|
if Path::new(model).exists() {
|
||||||
Path::new(model).to_path_buf()
|
ModelDir::from(model)
|
||||||
} else {
|
} else {
|
||||||
let home = std::env::var("HOME").unwrap();
|
ModelDir::new(model)
|
||||||
let tabby_root = format!("{}/.tabby", home);
|
|
||||||
Path::new(&tabby_root).join("models").join(model)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -115,8 +114,8 @@ struct TransformersInfo {
|
||||||
auto_model: String,
|
auto_model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_metadata(model_dir: &std::path::PathBuf) -> Metadata {
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
let file = std::fs::File::open(model_dir.join("metadata.json")).unwrap();
|
let file = std::fs::File::open(model_dir.metadata_file()).unwrap();
|
||||||
let reader = std::io::BufReader::new(file);
|
let reader = std::io::BufReader::new(file);
|
||||||
serde_json::from_reader(reader).unwrap()
|
serde_json::from_reader(reader).unwrap()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue