chore: extract tabby-common (#169)

* chore: extract tabby-common

* simplify
support-coreml
Meng Zhang 2023-05-29 23:39:02 -07:00 committed by GitHub
parent 781e6a7db3
commit d8cee4adac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 74 additions and 25 deletions

5
Cargo.lock generated
View File

@ -2086,6 +2086,7 @@ dependencies = [
"serde",
"serde_json",
"strum",
"tabby-common",
"tokio",
"tower",
"tower-http",
@ -2094,6 +2095,10 @@ dependencies = [
"uuid",
]
[[package]]
name = "tabby-common"
version = "0.1.0"
[[package]]
name = "tar"
version = "0.4.38"

View File

@ -1,6 +1,7 @@
[workspace]
members = [
"crates/tabby",
"crates/tabby-common",
"crates/ctranslate2-bindings",
"crates/rust-cxx-cmake-bridge",
]

View File

@ -0,0 +1,8 @@
[package]
name = "tabby-common"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@ -0,0 +1 @@
pub mod path;

View File

@ -0,0 +1,41 @@
use std::env;
use std::path::PathBuf;
fn get_root_dir() -> PathBuf {
match env::var("TABBY_ROOT") {
Ok(x) => PathBuf::from(x),
Err(_) => PathBuf::from(env::var("HOME").unwrap()).join(".tabby"),
}
}
pub struct ModelDir(PathBuf);
impl ModelDir {
pub fn new(model: &str) -> Self {
Self(get_root_dir().join("models").join(model))
}
pub fn from(path: &str) -> Self {
Self(PathBuf::from(path))
}
pub fn path(&self) -> &PathBuf {
&self.0
}
pub fn path_string(&self, name: &str) -> String {
self.0.join(name).display().to_string()
}
pub fn metadata_file(&self) -> String {
return self.path_string("metadata.json");
}
pub fn tokenizer_file(&self) -> String {
return self.path_string("tokenizer.json");
}
pub fn ctranslate2_dir(&self) -> String {
self.path_string("ctranslate2")
}
}

View File

@ -26,6 +26,7 @@ reqwest = { version = "0.11.18", features = ["stream"] }
error-chain = "0.12.4"
indicatif = "0.17.3"
futures-util = "0.3.28"
tabby-common = { path = "../tabby-common" }
[dependencies.uuid]
version = "1.3.3"

View File

@ -1,12 +1,13 @@
use std::cmp;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::path::Path;
use clap::Args;
use error_chain::error_chain;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use tabby_common::path::ModelDir;
#[derive(Args)]
pub struct DownloadArgs {
@ -38,18 +39,10 @@ async fn download_model(model_id: &str) -> Result<()> {
Ok(())
}
fn get_model_dir(model_id: &str) -> PathBuf {
let home = std::env::var("HOME").unwrap();
let tabby_root = format!("{}/.tabby", home);
let model_dir = Path::new(&tabby_root).join("models").join(model_id);
model_dir
}
async fn download_metadata(model_id: &str) -> Result<()> {
let url = format!("https://huggingface.co/api/models/{}", model_id);
let fname = "metadata.json";
let filepath = get_model_dir(model_id).join(fname).display().to_string();
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
let filepath = ModelDir::new(model_id).metadata_file();
download_file(&url, &filepath).await
}
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
@ -57,11 +50,11 @@ async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
// Create destination path.
let filepath = get_model_dir(model_id).join(fname).display().to_string();
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
let filepath = ModelDir::new(model_id).path_string(fname);
download_file(&url, &filepath).await
}
async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
async fn download_file(url: &str, path: &str) -> Result<()> {
fs::create_dir_all(Path::new(path).parent().unwrap())?;
// Reqwest setup
@ -78,7 +71,7 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message(format!("Downloading {}", &name));
pb.set_message(format!("Downloading {}", path));
// download chunks
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
@ -94,6 +87,6 @@ async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
pb.set_position(new);
}
pb.finish_with_message(format!("Downloaded {}", &name));
pb.finish_with_message(format!("Downloaded {}", path));
return Ok(());
}

View File

@ -5,6 +5,7 @@ use ctranslate2_bindings::{
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use tabby_common::path::ModelDir;
use utoipa::ToSchema;
mod languages;
@ -72,8 +73,8 @@ impl CompletionState {
let device = format!("{}", args.device);
let options = TextInferenceEngineCreateOptionsBuilder::default()
.model_path(model_dir.join("ctranslate2").display().to_string())
.tokenizer_path(model_dir.join("tokenizer.json").display().to_string())
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.transformers_info.auto_model)
.device_indices(args.device_indices.clone())
@ -85,13 +86,11 @@ impl CompletionState {
}
}
fn get_model_dir(model: &str) -> std::path::PathBuf {
fn get_model_dir(model: &str) -> ModelDir {
if Path::new(model).exists() {
Path::new(model).to_path_buf()
ModelDir::from(model)
} else {
let home = std::env::var("HOME").unwrap();
let tabby_root = format!("{}/.tabby", home);
Path::new(&tabby_root).join("models").join(model)
ModelDir::new(model)
}
}
@ -115,8 +114,8 @@ struct TransformersInfo {
auto_model: String,
}
fn read_metadata(model_dir: &std::path::PathBuf) -> Metadata {
let file = std::fs::File::open(model_dir.join("metadata.json")).unwrap();
fn read_metadata(model_dir: &ModelDir) -> Metadata {
let file = std::fs::File::open(model_dir.metadata_file()).unwrap();
let reader = std::io::BufReader::new(file);
serde_json::from_reader(reader).unwrap()
}