feat: add `tabby download` command (#157)

* simplify fmt-display

* cleanup

* move tabby-admin to reduce nest

* add model downloader

* get rid of model-type

* improve commands

* fix fmt
add-prefix-suffix
Meng Zhang 2023-05-28 14:36:11 -07:00 committed by GitHub
parent 80d1dd1ca6
commit 48796ecd77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
75 changed files with 351 additions and 81 deletions

148
Cargo.lock generated
View File

@ -2,6 +2,15 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
version = "1.0.2"
@ -152,6 +161,21 @@ dependencies = [
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide 0.6.2",
"object",
"rustc-demangle",
]
[[package]]
name = "base64"
version = "0.13.1"
@ -652,6 +676,16 @@ dependencies = [
"libc",
]
[[package]]
name = "error-chain"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc"
dependencies = [
"backtrace",
"version_check",
]
[[package]]
name = "esaxx-rs"
version = "0.1.8"
@ -689,7 +723,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743"
dependencies = [
"crc32fast",
"miniz_oxide",
"miniz_oxide 0.7.1",
]
[[package]]
@ -753,6 +787,17 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]]
name = "futures-macro"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.18",
]
[[package]]
name = "futures-sink"
version = "0.3.28"
@ -773,6 +818,8 @@ checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
@ -801,6 +848,12 @@ dependencies = [
"wasi",
]
[[package]]
name = "gimli"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
[[package]]
name = "glob"
version = "0.3.1"
@ -996,6 +1049,18 @@ dependencies = [
"regex",
]
[[package]]
name = "indicatif"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cef509aa9bc73864d6756f0d34d35504af3cf0844373afe9b8669a5b8005a729"
dependencies = [
"console",
"number_prefix 0.4.0",
"portable-atomic 0.3.20",
"unicode-width",
]
[[package]]
name = "inout"
version = "0.1.3"
@ -1190,6 +1255,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
dependencies = [
"adler",
]
[[package]]
name = "miniz_oxide"
version = "0.7.1"
@ -1282,6 +1356,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "object"
version = "0.30.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.17.1"
@ -1450,6 +1533,21 @@ version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
[[package]]
name = "portable-atomic"
version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e30165d31df606f5726b090ec7592c308a0eaf61721ff64c9a3018e344a8753e"
dependencies = [
"portable-atomic 1.3.2",
]
[[package]]
name = "portable-atomic"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc59d1bcc64fc5d021d67521f818db868368028108d37f0e98d74e33f68297b5"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@ -1642,10 +1740,12 @@ dependencies = [
"serde_urlencoded",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"winreg",
]
@ -1689,6 +1789,12 @@ dependencies = [
"walkdir",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustix"
version = "0.37.19"
@ -1903,6 +2009,28 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "strum"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.24.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn 1.0.109",
]
[[package]]
name = "subtle"
version = "2.5.0"
@ -1945,14 +2073,19 @@ dependencies = [
"clap",
"ctranslate2-bindings",
"env_logger",
"error-chain",
"futures-util",
"hyper",
"indicatif 0.17.3",
"lazy_static",
"log",
"mime_guess",
"regex",
"reqwest",
"rust-embed",
"serde",
"serde_json",
"strum",
"tokio",
"tower",
"tower-http",
@ -2454,6 +2587,19 @@ version = "0.2.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93"
[[package]]
name = "wasm-streams"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.63"

View File

@ -91,9 +91,15 @@ std::unique_ptr<TextInferenceEngine> create_engine(
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
loader.num_replicas_per_device = num_replicas_per_device;
if (model_type_str == "decoder") {
if (loader.device == ctranslate2::Device::CPU) {
loader.compute_type = ctranslate2::ComputeType::INT8;
} else if (loader.device == ctranslate2::Device::CUDA) {
loader.compute_type = ctranslate2::ComputeType::FLOAT16;
}
if (model_type_str == "AutoModelForCausalLM") {
return DecoderImpl::create(loader);
} else if (model_type_str == "encoder-decoder") {
} else if (model_type_str == "AutoModelForSeq2SeqLM") {
return EncoderDecoderImpl::create(loader);
} else {
return nullptr;

View File

Before

Width:  |  Height:  |  Size: 115 KiB

After

Width:  |  Height:  |  Size: 115 KiB

View File

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 120 KiB

View File

Before

Width:  |  Height:  |  Size: 116 KiB

After

Width:  |  Height:  |  Size: 116 KiB

View File

Before

Width:  |  Height:  |  Size: 6.0 KiB

After

Width:  |  Height:  |  Size: 6.0 KiB

View File

Before

Width:  |  Height:  |  Size: 120 KiB

After

Width:  |  Height:  |  Size: 120 KiB

View File

Before

Width:  |  Height:  |  Size: 6.0 KiB

After

Width:  |  Height:  |  Size: 6.0 KiB

View File

@ -21,6 +21,11 @@ regex = "1.8.3"
lazy_static = "1.4.0"
rust-embed = "6.6.1"
mime_guess = "2.0.4"
strum = { version = "0.24", features = ["derive"] }
reqwest = { version = "0.11.18", features = ["stream"] }
error-chain = "0.12.4"
indicatif = "0.17.3"
futures-util = "0.3.28"
[dependencies.uuid]
version = "1.3.3"

View File

@ -0,0 +1,99 @@
use std::cmp;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use clap::Args;
use error_chain::error_chain;
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
// Command-line arguments of the `download` subcommand.
// NOTE(review): plain `//` comments are used here on purpose; clap derives
// user-visible help text from `///` comments, so adding those would change output.
#[derive(Args)]
pub struct DownloadArgs {
// The value is interpolated into the huggingface.co URLs and into the local
// model directory path by the functions in this module.
/// model id to fetch.
#[clap(long)]
model: String,
}
// Declares this module's error types through the `error_chain!` macro; the
// `Result<()>` return type used throughout this module comes from it.
error_chain! {
foreign_links {
// Each entry wraps an external error type so it can be propagated with `?`.
// Filesystem errors (directory/file creation, writes).
Io(std::io::Error);
// Errors reported by reqwest while fetching.
HttpRequest(reqwest::Error);
// Error returned when the progress-bar template fails to parse.
TemplateError(indicatif::style::TemplateError);
}
}
/// Entry point of the `download` subcommand: downloads every file of the
/// model identified by `args.model`.
///
/// # Errors
/// Returns the first error raised while downloading. The error is propagated
/// to the caller instead of being unwrapped here, so the caller decides how
/// to report it.
pub async fn main(args: &DownloadArgs) -> Result<()> {
    download_model(&args.model).await
}
/// Downloads the metadata and then every model artifact for `model_id`,
/// sequentially and in a fixed order.
///
/// # Errors
/// Stops at, and returns, the first download that fails.
async fn download_model(model_id: &str) -> Result<()> {
    // Paths of the artifacts, relative to the root of the model repository.
    const MODEL_FILES: [&str; 5] = [
        "tokenizer.json",
        "ctranslate2/config.json",
        "ctranslate2/vocabulary.txt",
        "ctranslate2/shared_vocabulary.txt",
        "ctranslate2/model.bin",
    ];

    download_metadata(model_id).await?;
    for fname in MODEL_FILES {
        download_model_file(model_id, fname).await?;
    }
    Ok(())
}
/// Returns the local directory for `model_id`: `$HOME/.tabby/models/<model_id>`.
///
/// # Panics
/// Panics when the `HOME` environment variable is not set. The value is read
/// as an OS string, so a home directory that is not valid Unicode is accepted.
fn get_model_dir(model_id: &str) -> PathBuf {
    let home = std::env::var_os("HOME")
        .expect("the HOME environment variable must be set to locate the .tabby directory");
    PathBuf::from(home)
        .join(".tabby")
        .join("models")
        .join(model_id)
}
/// Fetches the model description from the huggingface.co models API and
/// stores it as `metadata.json` inside the local model directory.
async fn download_metadata(model_id: &str) -> Result<()> {
    const FNAME: &str = "metadata.json";

    // Label shown next to the progress bar.
    let label = format!("{}/{}", model_id, FNAME);
    let api_url = format!("https://huggingface.co/api/models/{}", model_id);
    let destination = get_model_dir(model_id).join(FNAME).display().to_string();

    download_file(&label, &api_url, &destination).await
}
/// Downloads `fname` from the `main` revision of the `model_id` repository on
/// huggingface.co into the same relative location under the local model directory.
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
    // Label shown next to the progress bar.
    let label = format!("{}/{}", model_id, fname);
    // Where the file is written locally.
    let destination = get_model_dir(model_id).join(fname).display().to_string();
    // Where the file is fetched from.
    let source_url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);

    download_file(&label, &source_url, &destination).await
}
async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
fs::create_dir_all(Path::new(path).parent().unwrap())?;
// Reqwest setup
let res = reqwest::get(url)
.await
.or(Err(format!("Failed to GET from '{}'", url)))?;
let total_size = res
.content_length()
.ok_or(format!("Failed to get content length from '{}'", url))?;
// Indicatif setup
let pb = ProgressBar::new(total_size);
pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.progress_chars("#>-"));
pb.set_message(format!("Downloading {}", &name));
// download chunks
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
let mut downloaded: u64 = 0;
let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let chunk = item.or(Err(format!("Error while downloading file")))?;
file.write_all(&chunk)
.or(Err(format!("Error while writing to file")))?;
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
downloaded = new;
pb.set_position(new);
}
pb.finish_with_message(format!("Downloaded {}", &name));
return Ok(());
}

View File

@ -1,3 +1,6 @@
mod download;
mod serve;
use clap::{Parser, Subcommand};
#[derive(Parser)]
@ -12,21 +15,25 @@ struct Cli {
pub enum Commands {
/// Serve the model
Serve(serve::ServeArgs),
}
mod serve;
/// Download the model
Download(download::DownloadArgs),
}
/// Parses the command line and runs the selected subcommand to completion.
#[tokio::main]
async fn main() {
    // Dispatch on the parsed subcommand; a failing subcommand aborts the
    // process through `expect` with a subcommand-specific message.
    match &Cli::parse().command {
        Commands::Download(args) => download::main(args)
            .await
            .expect("Error happens during the download"),
        Commands::Serve(args) => serve::main(args)
            .await
            .expect("Error happens during the serve"),
    }
}

View File

@ -5,7 +5,7 @@ use axum::{
};
#[derive(rust_embed::RustEmbed)]
#[folder = "src/serve/admin/dist/"]
#[folder = "../tabby-admin/dist/"]
struct AdminAssets;
struct AdminStaticFile<T>(pub T);

View File

@ -1,8 +1,9 @@
use axum::{extract::State, Json};
use ctranslate2_bindings::{
TextInferenceEngine, TextInferenceEngineCreateOptions, TextInferenceOptionsBuilder,
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use utoipa::ToSchema;
@ -64,7 +65,22 @@ pub struct CompletionState {
}
impl CompletionState {
pub fn new(options: TextInferenceEngineCreateOptions) -> Self {
pub fn new(args: &crate::serve::ServeArgs) -> Self {
let home = std::env::var("HOME").unwrap();
let tabby_root = format!("{}/.tabby", home);
let model_dir = Path::new(&tabby_root).join("models").join(&args.model);
let metadata = read_metadata(&model_dir);
let device = format!("{}", args.device);
let options = TextInferenceEngineCreateOptionsBuilder::default()
.model_path(model_dir.join("ctranslate2").display().to_string())
.tokenizer_path(model_dir.join("tokenizer.json").display().to_string())
.device(device)
.model_type(metadata.transformers_info.auto_model)
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.build()
.unwrap();
let engine = TextInferenceEngine::create(options);
Self { engine }
}
@ -78,3 +94,20 @@ fn timestamp() -> u64 {
.expect("Time went backwards")
.as_secs()
}
/// The part of a model's `metadata.json` that `read_metadata` deserializes.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct Metadata {
// Because of `rename_all = "camelCase"`, this is read from the
// `transformersInfo` key of the JSON document.
transformers_info: TransformersInfo,
}
/// Contents of the `transformers_info` field of `Metadata`.
#[derive(Deserialize)]
struct TransformersInfo {
// No rename rule applies to this struct, so the JSON key is `auto_model`.
// The value is handed to the engine options as `model_type`.
auto_model: String,
}
/// Loads and parses `metadata.json` from `model_dir`.
///
/// Takes `&Path`, so callers may pass either a `&Path` or a `&PathBuf`.
///
/// # Panics
/// Panics when the file cannot be opened or does not deserialize into
/// `Metadata`; the message names the offending file and the underlying error.
fn read_metadata(model_dir: &std::path::Path) -> Metadata {
    let metadata_path = model_dir.join("metadata.json");
    let file = std::fs::File::open(&metadata_path).unwrap_or_else(|err| {
        panic!(
            "Failed to open model metadata '{}': {}",
            metadata_path.display(),
            err
        )
    });
    serde_json::from_reader(std::io::BufReader::new(file)).unwrap_or_else(|err| {
        panic!(
            "Failed to parse model metadata '{}': {}",
            metadata_path.display(),
            err
        )
    })
}

View File

@ -1,20 +1,19 @@
mod admin;
mod completions;
mod events;
use crate::Cli;
use axum::{routing, Router, Server};
use clap::{error::ErrorKind, Args, CommandFactory};
use hyper::Error;
use std::{
net::{Ipv4Addr, SocketAddr},
sync::Arc,
};
use axum::{routing, Router, Server};
use clap::Args;
use ctranslate2_bindings::TextInferenceEngineCreateOptionsBuilder;
use hyper::Error;
use std::path::Path;
use tower_http::cors::CorsLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
mod completions;
mod events;
#[derive(OpenApi)]
#[openapi(
paths(events::log_event, completions::completion,),
@ -27,70 +26,47 @@ mod events;
)]
struct ApiDoc;
#[derive(clap::ValueEnum, Clone)]
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
pub enum Device {
#[strum(serialize = "cpu")]
CPU,
#[strum(serialize = "cuda")]
CUDA,
}
impl std::fmt::Display for Device {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let printable = match *self {
Device::CPU => "cpu",
Device::CUDA => "cuda",
};
write!(f, "{}", printable)
}
}
#[derive(clap::ValueEnum, Clone)]
pub enum ModelType {
EncoderDecoder,
Decoder,
}
impl std::fmt::Display for ModelType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let printable = match *self {
ModelType::EncoderDecoder => "encoder-decoder",
ModelType::Decoder => "decoder",
};
write!(f, "{}", printable)
}
}
#[derive(Args)]
pub struct ServeArgs {
/// path to model for serving
/// Model id for serving.
#[clap(long)]
model: String,
/// model type for serving
#[clap(long, default_value_t=ModelType::Decoder)]
model_type: ModelType,
#[clap(long, default_value_t = 8080)]
port: u16,
/// Device to run model inference.
#[clap(long, default_value_t=Device::CPU)]
device: Device,
/// GPU indices to run models, only applicable for CUDA.
#[clap(long, default_values_t=[0])]
device_indices: Vec<i32>,
/// num_replicas_per_device
/// Number of replicas per device, only applicable for CPU.
#[clap(long, default_value_t = 1)]
num_replicas_per_device: usize,
/// *INTERNAL ONLY*
#[clap(long, default_value_t = false)]
experimental_admin_panel: bool,
}
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
valid_args(args);
let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.nest("/v1", api_router(args))
.fallback(fallback(args));
.fallback(fallback(args.experimental_admin_panel));
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
println!("Listening at {}", address);
@ -100,41 +76,39 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
fn api_router(args: &ServeArgs) -> Router {
Router::new()
.route("/events", routing::post(events::log_event))
.route("/completions", routing::post(completions::completion))
.with_state(Arc::new(new_completion_state(args)))
.route(
"/completions",
routing::post(completions::completion)
.with_state(Arc::new(completions::CompletionState::new(args))),
)
.layer(CorsLayer::permissive())
}
mod admin;
fn fallback(args: &ServeArgs) -> routing::MethodRouter {
if args.experimental_admin_panel {
fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
if experimental_admin_panel {
routing::get(admin::handler)
} else {
routing::get(|| async { axum::response::Redirect::temporary("/swagger-ui") })
}
}
fn new_completion_state(args: &ServeArgs) -> completions::CompletionState {
let device = format!("{}", args.device);
let options = TextInferenceEngineCreateOptionsBuilder::default()
.model_path(
Path::new(&args.model)
.join("ctranslate2")
.join(device.clone())
.display()
.to_string(),
fn valid_args(args: &ServeArgs) {
if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
Cli::command()
.error(
ErrorKind::ValueValidation,
"CUDA device only supports 1 replicas per device",
)
.tokenizer_path(
Path::new(&args.model)
.join("tokenizer.json")
.display()
.to_string(),
)
.device(device)
.model_type(format!("{}", args.model_type))
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.build()
.unwrap();
completions::CompletionState::new(options)
.exit();
}
if args.device == Device::CPU && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
{
Cli::command()
.error(
ErrorKind::ValueValidation,
"CPU device only supports device indices = [0]",
)
.exit();
}
}