feat: add `tabby download` command (#157)
* simplify fmt-display * cleanup * move tabby-admin to reduce nest * add model downloader * get rid of model-type * improve commands * fix fmtadd-prefix-suffix
|
|
@ -2,6 +2,15 @@
|
|||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
|
||||
dependencies = [
|
||||
"gimli",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
|
|
@ -152,6 +161,21 @@ dependencies = [
|
|||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.67"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
|
||||
dependencies = [
|
||||
"addr2line",
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"miniz_oxide 0.6.2",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
|
|
@ -652,6 +676,16 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "error-chain"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "esaxx-rs"
|
||||
version = "0.1.8"
|
||||
|
|
@ -689,7 +723,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
"miniz_oxide 0.7.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -753,6 +787,17 @@ version = "0.3.28"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.28"
|
||||
|
|
@ -773,6 +818,8 @@ checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
|
|||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
|
|
@ -801,6 +848,12 @@ dependencies = [
|
|||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
|
|
@ -996,6 +1049,18 @@ dependencies = [
|
|||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.17.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cef509aa9bc73864d6756f0d34d35504af3cf0844373afe9b8669a5b8005a729"
|
||||
dependencies = [
|
||||
"console",
|
||||
"number_prefix 0.4.0",
|
||||
"portable-atomic 0.3.20",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inout"
|
||||
version = "0.1.3"
|
||||
|
|
@ -1190,6 +1255,15 @@ version = "0.2.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
|
||||
dependencies = [
|
||||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.7.1"
|
||||
|
|
@ -1282,6 +1356,15 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.30.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.17.1"
|
||||
|
|
@ -1450,6 +1533,21 @@ version = "0.3.27"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "0.3.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e30165d31df606f5726b090ec7592c308a0eaf61721ff64c9a3018e344a8753e"
|
||||
dependencies = [
|
||||
"portable-atomic 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc59d1bcc64fc5d021d67521f818db868368028108d37f0e98d74e33f68297b5"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
|
|
@ -1642,10 +1740,12 @@ dependencies = [
|
|||
"serde_urlencoded",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"winreg",
|
||||
]
|
||||
|
|
@ -1689,6 +1789,12 @@ dependencies = [
|
|||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.37.19"
|
||||
|
|
@ -1903,6 +2009,28 @@ version = "0.10.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.24.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f"
|
||||
dependencies = [
|
||||
"strum_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
version = "0.24.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
|
|
@ -1945,14 +2073,19 @@ dependencies = [
|
|||
"clap",
|
||||
"ctranslate2-bindings",
|
||||
"env_logger",
|
||||
"error-chain",
|
||||
"futures-util",
|
||||
"hyper",
|
||||
"indicatif 0.17.3",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"mime_guess",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"rust-embed",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-http",
|
||||
|
|
@ -2454,6 +2587,19 @@ version = "0.2.86"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-streams"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.63"
|
||||
|
|
|
|||
|
|
@ -91,9 +91,15 @@ std::unique_ptr<TextInferenceEngine> create_engine(
|
|||
loader.device_indices = std::vector<int>(device_indices.begin(), device_indices.end());
|
||||
loader.num_replicas_per_device = num_replicas_per_device;
|
||||
|
||||
if (model_type_str == "decoder") {
|
||||
if (loader.device == ctranslate2::Device::CPU) {
|
||||
loader.compute_type = ctranslate2::ComputeType::INT8;
|
||||
} else if (loader.device == ctranslate2::Device::CUDA) {
|
||||
loader.compute_type = ctranslate2::ComputeType::FLOAT16;
|
||||
}
|
||||
|
||||
if (model_type_str == "AutoModelForCausalLM") {
|
||||
return DecoderImpl::create(loader);
|
||||
} else if (model_type_str == "encoder-decoder") {
|
||||
} else if (model_type_str == "AutoModelForSeq2SeqLM") {
|
||||
return EncoderDecoderImpl::create(loader);
|
||||
} else {
|
||||
return nullptr;
|
||||
|
|
|
|||
|
Before Width: | Height: | Size: 115 KiB After Width: | Height: | Size: 115 KiB |
|
Before Width: | Height: | Size: 120 KiB After Width: | Height: | Size: 120 KiB |
|
Before Width: | Height: | Size: 116 KiB After Width: | Height: | Size: 116 KiB |
|
Before Width: | Height: | Size: 6.0 KiB After Width: | Height: | Size: 6.0 KiB |
|
Before Width: | Height: | Size: 115 KiB After Width: | Height: | Size: 115 KiB |
|
Before Width: | Height: | Size: 120 KiB After Width: | Height: | Size: 120 KiB |
|
Before Width: | Height: | Size: 116 KiB After Width: | Height: | Size: 116 KiB |
|
Before Width: | Height: | Size: 6.0 KiB After Width: | Height: | Size: 6.0 KiB |
|
|
@ -21,6 +21,11 @@ regex = "1.8.3"
|
|||
lazy_static = "1.4.0"
|
||||
rust-embed = "6.6.1"
|
||||
mime_guess = "2.0.4"
|
||||
strum = { version = "0.24", features = ["derive"] }
|
||||
reqwest = { version = "0.11.18", features = ["stream"] }
|
||||
error-chain = "0.12.4"
|
||||
indicatif = "0.17.3"
|
||||
futures-util = "0.3.28"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,99 @@
|
|||
use std::cmp;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use clap::Args;
|
||||
use error_chain::error_chain;
|
||||
use futures_util::StreamExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
|
||||
// Command-line arguments of the `download` subcommand (wired up as
// `Commands::Download` in main.rs).
#[derive(Args)]
|
||||
pub struct DownloadArgs {
|
||||
// The id is used both to build the remote URLs and as the sub-directory
// name under the local models directory (see `get_model_dir`).
/// model id to fetch.
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
}
|
||||
|
||||
// Declares this module's error types through the `error_chain!` macro; the
// `Result<()>` used by the functions in this file presumably comes from this
// invocation (standard error_chain behaviour — confirm against the crate docs).
// Each `foreign_links` entry lets `?` convert that external error type into
// the module error.
error_chain! {
|
||||
foreign_links {
|
||||
// Filesystem failures (directory creation, file writes).
Io(std::io::Error);
|
||||
// HTTP failures reported by reqwest.
HttpRequest(reqwest::Error);
|
||||
// Invalid progress-bar template passed to indicatif.
TemplateError(indicatif::style::TemplateError);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn main(args: &DownloadArgs) -> Result<()> {
|
||||
download_model(&args.model).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download_model(model_id: &str) -> Result<()> {
|
||||
download_metadata(model_id).await?;
|
||||
download_model_file(model_id, "tokenizer.json").await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/config.json")).await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/vocabulary.txt")).await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/shared_vocabulary.txt")).await?;
|
||||
download_model_file(model_id, &format!("ctranslate2/model.bin")).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the local directory for `model_id`: `$HOME/.tabby/models/<model_id>`.
///
/// `HOME` is read with `var_os`, so a home directory that is not valid UTF-8
/// is accepted, and the path is assembled with `PathBuf::join` rather than
/// string formatting.
///
/// # Panics
/// Panics with an explanatory message when `HOME` is not set.
fn get_model_dir(model_id: &str) -> PathBuf {
    let home = std::env::var_os("HOME")
        .expect("HOME environment variable must be set to locate the tabby root directory");
    PathBuf::from(home)
        .join(".tabby")
        .join("models")
        .join(model_id)
}
|
||||
|
||||
async fn download_metadata(model_id: &str) -> Result<()> {
|
||||
let url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||
let fname = "metadata.json";
|
||||
let filepath = get_model_dir(model_id).join(fname).display().to_string();
|
||||
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
|
||||
}
|
||||
|
||||
async fn download_model_file(model_id: &str, fname: &str) -> Result<()> {
|
||||
// Create url.
|
||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, fname);
|
||||
|
||||
// Create destination path.
|
||||
let filepath = get_model_dir(model_id).join(fname).display().to_string();
|
||||
download_file(&format!("{}/{}", model_id, fname), &url, &filepath).await
|
||||
}
|
||||
|
||||
async fn download_file(name: &str, url: &str, path: &str) -> Result<()> {
|
||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
||||
|
||||
// Reqwest setup
|
||||
let res = reqwest::get(url)
|
||||
.await
|
||||
.or(Err(format!("Failed to GET from '{}'", url)))?;
|
||||
|
||||
let total_size = res
|
||||
.content_length()
|
||||
.ok_or(format!("Failed to get content length from '{}'", url))?;
|
||||
|
||||
// Indicatif setup
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(ProgressStyle::default_bar()
|
||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
||||
.progress_chars("#>-"));
|
||||
pb.set_message(format!("Downloading {}", &name));
|
||||
|
||||
// download chunks
|
||||
let mut file = fs::File::create(&path).or(Err(format!("Failed to create file '{}'", &path)))?;
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut stream = res.bytes_stream();
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
let chunk = item.or(Err(format!("Error while downloading file")))?;
|
||||
file.write_all(&chunk)
|
||||
.or(Err(format!("Error while writing to file")))?;
|
||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
||||
downloaded = new;
|
||||
pb.set_position(new);
|
||||
}
|
||||
|
||||
pb.finish_with_message(format!("Downloaded {}", &name));
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
mod download;
|
||||
mod serve;
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
#[derive(Parser)]
|
||||
|
|
@ -12,21 +15,25 @@ struct Cli {
|
|||
pub enum Commands {
|
||||
/// Serve the model
|
||||
Serve(serve::ServeArgs),
|
||||
}
|
||||
|
||||
mod serve;
|
||||
/// Download the model
|
||||
Download(download::DownloadArgs),
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// You can check for the existence of subcommands, and if found use their
|
||||
// matches just as you would the top level cmd
|
||||
match &cli.command {
|
||||
Commands::Serve(args) => {
|
||||
serve::main(args)
|
||||
.await
|
||||
.expect("Error happens during the serve");
|
||||
}
|
||||
Commands::Download(args) => {
|
||||
download::main(args)
|
||||
.await
|
||||
.expect("Error happens during the download");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use axum::{
|
|||
};
|
||||
|
||||
#[derive(rust_embed::RustEmbed)]
|
||||
#[folder = "src/serve/admin/dist/"]
|
||||
#[folder = "../tabby-admin/dist/"]
|
||||
struct AdminAssets;
|
||||
|
||||
struct AdminStaticFile<T>(pub T);
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
use axum::{extract::State, Json};
|
||||
use ctranslate2_bindings::{
|
||||
TextInferenceEngine, TextInferenceEngineCreateOptions, TextInferenceOptionsBuilder,
|
||||
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
|
|
@ -64,7 +65,22 @@ pub struct CompletionState {
|
|||
}
|
||||
|
||||
impl CompletionState {
|
||||
pub fn new(options: TextInferenceEngineCreateOptions) -> Self {
|
||||
pub fn new(args: &crate::serve::ServeArgs) -> Self {
|
||||
let home = std::env::var("HOME").unwrap();
|
||||
let tabby_root = format!("{}/.tabby", home);
|
||||
let model_dir = Path::new(&tabby_root).join("models").join(&args.model);
|
||||
let metadata = read_metadata(&model_dir);
|
||||
|
||||
let device = format!("{}", args.device);
|
||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||
.model_path(model_dir.join("ctranslate2").display().to_string())
|
||||
.tokenizer_path(model_dir.join("tokenizer.json").display().to_string())
|
||||
.device(device)
|
||||
.model_type(metadata.transformers_info.auto_model)
|
||||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.build()
|
||||
.unwrap();
|
||||
let engine = TextInferenceEngine::create(options);
|
||||
Self { engine }
|
||||
}
|
||||
|
|
@ -78,3 +94,20 @@ fn timestamp() -> u64 {
|
|||
.expect("Time went backwards")
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
// The part of a model's `metadata.json` that `read_metadata` deserializes.
// `rename_all = "camelCase"` makes the field below map to the JSON key
// `transformersInfo`.
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Metadata {
|
||||
// Its `auto_model` value is passed as the engine's model type.
transformers_info: TransformersInfo,
|
||||
}
|
||||
|
||||
// Nested `transformers_info` section of `Metadata`. There is no rename
// attribute here, so the field is read from the JSON key `auto_model` as-is.
#[derive(Deserialize)]
|
||||
struct TransformersInfo {
|
||||
// Model class name, e.g. "AutoModelForCausalLM" — presumably the value the
// engine factory matches on; verify against the bindings.
auto_model: String,
|
||||
}
|
||||
|
||||
fn read_metadata(model_dir: &std::path::PathBuf) -> Metadata {
|
||||
let file = std::fs::File::open(model_dir.join("metadata.json")).unwrap();
|
||||
let reader = std::io::BufReader::new(file);
|
||||
serde_json::from_reader(reader).unwrap()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,20 +1,19 @@
|
|||
mod admin;
|
||||
mod completions;
|
||||
mod events;
|
||||
|
||||
use crate::Cli;
|
||||
use axum::{routing, Router, Server};
|
||||
use clap::{error::ErrorKind, Args, CommandFactory};
|
||||
use hyper::Error;
|
||||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::{routing, Router, Server};
|
||||
use clap::Args;
|
||||
use ctranslate2_bindings::TextInferenceEngineCreateOptionsBuilder;
|
||||
use hyper::Error;
|
||||
use std::path::Path;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
mod completions;
|
||||
mod events;
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
paths(events::log_event, completions::completion,),
|
||||
|
|
@ -27,70 +26,47 @@ mod events;
|
|||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
#[derive(clap::ValueEnum, Clone)]
|
||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||
pub enum Device {
|
||||
#[strum(serialize = "cpu")]
|
||||
CPU,
|
||||
|
||||
#[strum(serialize = "cuda")]
|
||||
CUDA,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Device {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let printable = match *self {
|
||||
Device::CPU => "cpu",
|
||||
Device::CUDA => "cuda",
|
||||
};
|
||||
write!(f, "{}", printable)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Clone)]
|
||||
pub enum ModelType {
|
||||
EncoderDecoder,
|
||||
Decoder,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ModelType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let printable = match *self {
|
||||
ModelType::EncoderDecoder => "encoder-decoder",
|
||||
ModelType::Decoder => "decoder",
|
||||
};
|
||||
write!(f, "{}", printable)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct ServeArgs {
|
||||
/// path to model for serving
|
||||
/// Model id for serving.
|
||||
#[clap(long)]
|
||||
model: String,
|
||||
|
||||
/// model type for serving
|
||||
#[clap(long, default_value_t=ModelType::Decoder)]
|
||||
model_type: ModelType,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
/// Device to run model inference.
|
||||
#[clap(long, default_value_t=Device::CPU)]
|
||||
device: Device,
|
||||
|
||||
/// GPU indices to run models, only applicable for CUDA.
|
||||
#[clap(long, default_values_t=[0])]
|
||||
device_indices: Vec<i32>,
|
||||
|
||||
/// num_replicas_per_device
|
||||
/// Number of replicas per device, only applicable for CPU.
|
||||
#[clap(long, default_value_t = 1)]
|
||||
num_replicas_per_device: usize,
|
||||
|
||||
/// *INTERNAL ONLY*
|
||||
#[clap(long, default_value_t = false)]
|
||||
experimental_admin_panel: bool,
|
||||
}
|
||||
|
||||
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
||||
valid_args(args);
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||
.nest("/v1", api_router(args))
|
||||
.fallback(fallback(args));
|
||||
.fallback(fallback(args.experimental_admin_panel));
|
||||
|
||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||
println!("Listening at {}", address);
|
||||
|
|
@ -100,41 +76,39 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
|||
fn api_router(args: &ServeArgs) -> Router {
|
||||
Router::new()
|
||||
.route("/events", routing::post(events::log_event))
|
||||
.route("/completions", routing::post(completions::completion))
|
||||
.with_state(Arc::new(new_completion_state(args)))
|
||||
.route(
|
||||
"/completions",
|
||||
routing::post(completions::completion)
|
||||
.with_state(Arc::new(completions::CompletionState::new(args))),
|
||||
)
|
||||
.layer(CorsLayer::permissive())
|
||||
}
|
||||
|
||||
mod admin;
|
||||
fn fallback(args: &ServeArgs) -> routing::MethodRouter {
|
||||
if args.experimental_admin_panel {
|
||||
fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
|
||||
if experimental_admin_panel {
|
||||
routing::get(admin::handler)
|
||||
} else {
|
||||
routing::get(|| async { axum::response::Redirect::temporary("/swagger-ui") })
|
||||
}
|
||||
}
|
||||
|
||||
fn new_completion_state(args: &ServeArgs) -> completions::CompletionState {
|
||||
let device = format!("{}", args.device);
|
||||
let options = TextInferenceEngineCreateOptionsBuilder::default()
|
||||
.model_path(
|
||||
Path::new(&args.model)
|
||||
.join("ctranslate2")
|
||||
.join(device.clone())
|
||||
.display()
|
||||
.to_string(),
|
||||
fn valid_args(args: &ServeArgs) {
|
||||
if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
|
||||
Cli::command()
|
||||
.error(
|
||||
ErrorKind::ValueValidation,
|
||||
"CUDA device only supports 1 replicas per device",
|
||||
)
|
||||
.tokenizer_path(
|
||||
Path::new(&args.model)
|
||||
.join("tokenizer.json")
|
||||
.display()
|
||||
.to_string(),
|
||||
.exit();
|
||||
}
|
||||
|
||||
if args.device == Device::CPU && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
||||
{
|
||||
Cli::command()
|
||||
.error(
|
||||
ErrorKind::ValueValidation,
|
||||
"CPU device only supports device indices = [0]",
|
||||
)
|
||||
.device(device)
|
||||
.model_type(format!("{}", args.model_type))
|
||||
.device_indices(args.device_indices.clone())
|
||||
.num_replicas_per_device(args.num_replicas_per_device)
|
||||
.build()
|
||||
.unwrap();
|
||||
completions::CompletionState::new(options)
|
||||
.exit();
|
||||
}
|
||||
}
|
||||
|
|
|
|||