feat: improve error handling and messages [TAB-58] (#213)

* add fatal macro

* switch expect to fatal

* improve error handling of serve

* improve error handling on download module

* improve error handling in scheduler

* improve error handling

* fmt

* fmt
improve-workflow
Meng Zhang 2023-06-06 19:02:58 -07:00 committed by GitHub
parent c0106ad774
commit 4cb672ec39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 179 additions and 123 deletions

View File

@ -172,8 +172,10 @@ fn reverse(s: String) -> String {
s.chars().rev().collect() s.chars().rev().collect()
} }
fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex { fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex {
let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap(); let encodings = tokenizer
.encode_batch(stop_words.to_owned(), false)
.unwrap();
let stop_tokens: Vec<String> = encodings let stop_tokens: Vec<String> = encodings
.iter() .iter()
.map(|x| x.get_tokens().join("")) .map(|x| x.get_tokens().join(""))

View File

@ -1,9 +1,12 @@
use std::path::PathBuf; use std::{
io::{Error, ErrorKind},
path::PathBuf,
};
use filenamify::filenamify; use filenamify::filenamify;
use serde::Deserialize; use serde::Deserialize;
use crate::path::repositories_dir; use crate::path::{config_file, repositories_dir};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Config { pub struct Config {
@ -11,8 +14,14 @@ pub struct Config {
} }
impl Config { impl Config {
pub fn load() -> Result<Self, serdeconv::Error> { pub fn load() -> Result<Self, Error> {
serdeconv::from_toml_file(crate::path::config_file().as_path()) let file = serdeconv::from_toml_file(crate::path::config_file().as_path());
file.map_err(|_| {
Error::new(
ErrorKind::InvalidData,
format!("Config {:?} doesn't exist or is not valid", config_file()),
)
})
} }
} }

View File

@ -1,6 +1,6 @@
use std::{collections::HashMap, fs, path::Path}; use std::{collections::HashMap, fs, path::Path};
use anyhow::Result; use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::path::ModelDir; use tabby_common::path::ModelDir;
@ -33,12 +33,13 @@ impl CacheInfo {
self.etags.get(path).map(|x| x.as_str()) self.etags.get(path).map(|x| x.as_str())
} }
pub fn remote_cache_key(res: &reqwest::Response) -> &str { pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> {
res.headers() let key = res
.headers()
.get("etag") .get("etag")
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url())) .ok_or(anyhow!("etag key missing"))?
.to_str() .to_str()?;
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url())) Ok(key)
} }
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) { pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {

View File

@ -2,13 +2,19 @@ mod cache_info;
use std::{cmp, fs, io::Write, path::Path}; use std::{cmp, fs, io::Write, path::Path};
use anyhow::{anyhow, Result};
use cache_info::CacheInfo; use cache_info::CacheInfo;
use futures_util::StreamExt; use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle}; use indicatif::{ProgressBar, ProgressStyle};
use tabby_common::path::ModelDir; use tabby_common::path::ModelDir;
impl CacheInfo { impl CacheInfo {
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) { async fn download(
&mut self,
model_id: &str,
path: &str,
prefer_local_file: bool,
) -> Result<()> {
// Create url. // Create url.
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
@ -25,89 +31,89 @@ impl CacheInfo {
} }
if !local_file_ready { if !local_file_ready {
let etag = download_file(&url, &filepath, local_cache_key).await; let etag = download_file(&url, &filepath, local_cache_key).await?;
self.set_local_cache_key(path, &etag).await self.set_local_cache_key(path, &etag).await;
} }
Ok(())
} }
} }
pub async fn download_model(model_id: &str, prefer_local_file: bool) { pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
if fs::metadata(model_id).is_ok() { if fs::metadata(model_id).is_ok() {
// Local path, no need for downloading. // Local path, no need for downloading.
return; return Ok(());
} }
let mut cache_info = CacheInfo::from(model_id).await; let mut cache_info = CacheInfo::from(model_id).await;
cache_info cache_info
.download(model_id, "tabby.json", prefer_local_file) .download(model_id, "tabby.json", prefer_local_file)
.await; .await?;
cache_info cache_info
.download(model_id, "tokenizer.json", prefer_local_file) .download(model_id, "tokenizer.json", prefer_local_file)
.await; .await?;
cache_info cache_info
.download(model_id, "ctranslate2/config.json", prefer_local_file) .download(model_id, "ctranslate2/config.json", prefer_local_file)
.await; .await?;
cache_info cache_info
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file) .download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
.await; .await?;
cache_info cache_info
.download( .download(
model_id, model_id,
"ctranslate2/shared_vocabulary.txt", "ctranslate2/shared_vocabulary.txt",
prefer_local_file, prefer_local_file,
) )
.await; .await?;
cache_info cache_info
.download(model_id, "ctranslate2/model.bin", prefer_local_file) .download(model_id, "ctranslate2/model.bin", prefer_local_file)
.await; .await?;
cache_info cache_info.save(model_id)?;
.save(model_id)
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id)); Ok(())
} }
async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> String { async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> Result<String> {
fs::create_dir_all(Path::new(path).parent().unwrap()) fs::create_dir_all(Path::new(path).parent().unwrap())?;
.unwrap_or_else(|_| panic!("Failed to create path '{}'", path));
// Reqwest setup // Reqwest setup
let res = reqwest::get(url) let res = reqwest::get(url).await?;
.await
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string(); if !res.status().is_success() {
return Err(anyhow!(format!("Invalid url: {}", url)));
}
let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string();
if let Some(local_cache_key) = local_cache_key { if let Some(local_cache_key) = local_cache_key {
if local_cache_key == remote_cache_key { if local_cache_key == remote_cache_key {
return remote_cache_key; return Ok(remote_cache_key);
} }
} }
let total_size = res let total_size = res
.content_length() .content_length()
.unwrap_or_else(|| panic!("Failed to get content length from '{}'", url)); .ok_or(anyhow!("No content length in headers"))?;
// Indicatif setup // Indicatif setup
let pb = ProgressBar::new(total_size); let pb = ProgressBar::new(total_size);
pb.set_style(ProgressStyle::default_bar() pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})") .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
.expect("Invalid progress style")
.progress_chars("#>-")); .progress_chars("#>-"));
pb.set_message(format!("Downloading {}", path)); pb.set_message(format!("Downloading {}", path));
// download chunks // download chunks
let mut file = let mut file = fs::File::create(path)?;
fs::File::create(path).unwrap_or_else(|_| panic!("Failed to create file '{}'", path));
let mut downloaded: u64 = 0; let mut downloaded: u64 = 0;
let mut stream = res.bytes_stream(); let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
let chunk = item.expect("Error while downloading file"); let chunk = item?;
file.write_all(&chunk).expect("Error while writing to file"); file.write_all(&chunk)?;
let new = cmp::min(downloaded + (chunk.len() as u64), total_size); let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
downloaded = new; downloaded = new;
pb.set_position(new); pb.set_position(new);
} }
pb.finish_with_message(format!("Downloaded {}", path)); pb.finish_with_message(format!("Downloaded {}", path));
remote_cache_key Ok(remote_cache_key)
} }

View File

@ -15,11 +15,11 @@ use tracing::{info, warn};
use walkdir::{DirEntry, WalkDir}; use walkdir::{DirEntry, WalkDir};
trait RepositoryExt { trait RepositoryExt {
fn index(&self, schema: &Schema, writer: &mut IndexWriter); fn index(&self, schema: &Schema, writer: &mut IndexWriter) -> Result<()>;
} }
impl RepositoryExt for Repository { impl RepositoryExt for Repository {
fn index(&self, schema: &Schema, writer: &mut IndexWriter) { fn index(&self, schema: &Schema, writer: &mut IndexWriter) -> Result<()> {
let git_url = schema.get_field("git_url").unwrap(); let git_url = schema.get_field("git_url").unwrap();
let filepath = schema.get_field("filepath").unwrap(); let filepath = schema.get_field("filepath").unwrap();
let content = schema.get_field("content").unwrap(); let content = schema.get_field("content").unwrap();
@ -36,17 +36,17 @@ impl RepositoryExt for Repository {
let relative_path = entry.path().strip_prefix(dir.as_path()).unwrap(); let relative_path = entry.path().strip_prefix(dir.as_path()).unwrap();
if let Ok(file_content) = read_to_string(entry.path()) { if let Ok(file_content) = read_to_string(entry.path()) {
info!("Indexing {:?}", relative_path); info!("Indexing {:?}", relative_path);
writer writer.add_document(doc!(
.add_document(doc!( git_url => self.git_url.clone(),
git_url => self.git_url.clone(), filepath => relative_path.display().to_string(),
filepath => relative_path.display().to_string(), content => file_content,
content => file_content, ))?;
))
.unwrap();
} else { } else {
warn!("Skip {:?}", relative_path); warn!("Skip {:?}", relative_path);
} }
} }
Ok(())
} }
} }
@ -66,18 +66,20 @@ fn create_schema() -> Schema {
builder.build() builder.build()
} }
pub fn index_repositories(config: &Config) { pub fn index_repositories(config: &Config) -> Result<()> {
let schema = create_schema(); let schema = create_schema();
fs::create_dir_all(index_dir()).unwrap(); fs::create_dir_all(index_dir())?;
let directory = MmapDirectory::open(index_dir()).unwrap(); let directory = MmapDirectory::open(index_dir())?;
let index = Index::open_or_create(directory, schema.clone()).unwrap(); let index = Index::open_or_create(directory, schema.clone())?;
let mut writer = index.writer(10_000_000).unwrap(); let mut writer = index.writer(10_000_000)?;
writer.delete_all_documents().unwrap(); writer.delete_all_documents()?;
for repository in config.repositories.as_slice() { for repository in config.repositories.as_slice() {
repository.index(&schema, &mut writer); repository.index(&schema, &mut writer)?;
} }
writer.commit().unwrap(); writer.commit()?;
Ok(())
} }

View File

@ -1,30 +1,32 @@
mod index; mod index;
mod repository; mod repository;
use anyhow::Result;
use job_scheduler::{Job, JobScheduler}; use job_scheduler::{Job, JobScheduler};
use tabby_common::config::Config; use tabby_common::config::Config;
use tracing::{error, info}; use tracing::{error, info};
pub fn scheduler(now: bool) { pub async fn scheduler(now: bool) -> Result<()> {
let config = Config::load(); let config = Config::load()?;
if config.is_err() {
error!("Please create config.toml before using scheduler");
return;
}
let config = config.unwrap();
let mut scheduler = JobScheduler::new(); let mut scheduler = JobScheduler::new();
let job = || { let job = || {
info!("Syncing repositories..."); info!("Syncing repositories...");
repository::sync_repositories(&config); let ret = repository::sync_repositories(&config);
if let Err(err) = ret {
error!("Failed to sync repositories, err: '{}'", err);
return;
}
info!("Indexing repositories..."); info!("Indexing repositories...");
index::index_repositories(&config); let ret = index::index_repositories(&config);
if let Err(err) = ret {
error!("Failed to index repositories, err: '{}'", err);
}
}; };
if now { if now {
job() job();
} else { } else {
// Every 5 hours. // Every 5 hours.
scheduler.add(Job::new("0 0 1/5 * * * *".parse().unwrap(), job)); scheduler.add(Job::new("0 0 1/5 * * * *".parse().unwrap(), job));
@ -37,6 +39,8 @@ pub fn scheduler(now: bool) {
std::thread::sleep(duration); std::thread::sleep(duration);
} }
} }
Ok(())
} }
#[cfg(test)] #[cfg(test)]
@ -61,7 +65,7 @@ mod tests {
}], }],
}; };
repository::sync_repositories(&config); repository::sync_repositories(&config).unwrap();
index::index_repositories(&config); index::index_repositories(&config).unwrap();
} }
} }

View File

@ -1,33 +1,32 @@
use std::process::Command; use std::process::Command;
use anyhow::{anyhow, Result};
use tabby_common::config::{Config, Repository}; use tabby_common::config::{Config, Repository};
trait ConfigExt { trait ConfigExt {
fn sync_repositories(&self); fn sync_repositories(&self) -> Result<()>;
} }
impl ConfigExt for Config { impl ConfigExt for Config {
fn sync_repositories(&self) { fn sync_repositories(&self) -> Result<()> {
for repository in self.repositories.iter() { for repository in self.repositories.iter() {
repository.sync() repository.sync()?;
} }
Ok(())
} }
} }
trait RepositoryExt { trait RepositoryExt {
fn sync(&self); fn sync(&self) -> Result<()>;
} }
impl RepositoryExt for Repository { impl RepositoryExt for Repository {
fn sync(&self) { fn sync(&self) -> Result<()> {
let dir = self.dir(); let dir = self.dir();
let dir_string = dir.display().to_string(); let dir_string = dir.display().to_string();
let status = if dir.exists() { let status = if dir.exists() {
Command::new("git") Command::new("git").current_dir(&dir).arg("pull").status()
.current_dir(&dir)
.arg("pull")
.status()
.expect("git could not be executed")
} else { } else {
std::fs::create_dir_all(&dir) std::fs::create_dir_all(&dir)
.unwrap_or_else(|_| panic!("Failed to create dir {}", dir_string)); .unwrap_or_else(|_| panic!("Failed to create dir {}", dir_string));
@ -39,20 +38,22 @@ impl RepositoryExt for Repository {
.arg(&self.git_url) .arg(&self.git_url)
.arg(dir) .arg(dir)
.status() .status()
.expect("git could not be executed")
}; };
if let Some(code) = status.code() { if let Some(code) = status?.code() {
if code != 0 { if code != 0 {
panic!( return Err(anyhow!(
"Failed to pull remote '{}'\nConsider remove dir '{}' and retry", "Failed to pull remote '{}'. Consider remove dir '{}' and retry",
&self.git_url, &dir_string &self.git_url,
); &dir_string
));
} }
} }
Ok(())
} }
} }
pub fn sync_repositories(config: &Config) { pub fn sync_repositories(config: &Config) -> Result<()> {
config.sync_repositories(); config.sync_repositories()
} }

View File

@ -1,6 +1,8 @@
use clap::Args; use clap::Args;
use tracing::info; use tracing::info;
use crate::fatal;
#[derive(Args)] #[derive(Args)]
pub struct DownloadArgs { pub struct DownloadArgs {
/// model id to fetch. /// model id to fetch.
@ -13,6 +15,14 @@ pub struct DownloadArgs {
} }
pub async fn main(args: &DownloadArgs) { pub async fn main(args: &DownloadArgs) {
tabby_download::download_model(&args.model, args.prefer_local_file).await; tabby_download::download_model(&args.model, args.prefer_local_file)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
info!("model '{}' is ready", args.model); info!("model '{}' is ready", args.model);
} }

View File

@ -14,13 +14,13 @@ struct Cli {
#[derive(Subcommand)] #[derive(Subcommand)]
pub enum Commands { pub enum Commands {
/// Serve the model /// Starts the api endpoint for IDE / Editor extensions.
Serve(serve::ServeArgs), Serve(serve::ServeArgs),
/// Download the model /// Download the language model for serving.
Download(download::DownloadArgs), Download(download::DownloadArgs),
/// Starts the scheduler process. /// Run scheduler progress for cron jobs integrating external code repositories.
Scheduler(SchedulerArgs), Scheduler(SchedulerArgs),
} }
@ -42,6 +42,25 @@ async fn main() {
match &cli.command { match &cli.command {
Commands::Serve(args) => serve::main(args).await, Commands::Serve(args) => serve::main(args).await,
Commands::Download(args) => download::main(args).await, Commands::Download(args) => download::main(args).await,
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now), Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
.await
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
} }
} }
#[macro_export]
macro_rules! fatal {
($msg:expr) => {
({
tracing::error!($msg);
std::process::exit(1);
})
};
($fmt:expr, $($arg:tt)*) => {
({
tracing::error!($fmt, $($arg)*);
std::process::exit(1);
})
};
}

View File

@ -4,6 +4,8 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use crate::fatal;
#[derive(rust_embed::RustEmbed)] #[derive(rust_embed::RustEmbed)]
#[folder = "../tabby-admin/dist/"] #[folder = "../tabby-admin/dist/"]
struct AdminAssets; struct AdminAssets;
@ -26,12 +28,12 @@ where
Response::builder() Response::builder()
.header(header::CONTENT_TYPE, mime.as_ref()) .header(header::CONTENT_TYPE, mime.as_ref())
.body(body) .body(body)
.expect("Invalid response") .unwrap_or_else(|_| fatal!("Invalid response"))
} }
None => Response::builder() None => Response::builder()
.status(StatusCode::NOT_FOUND) .status(StatusCode::NOT_FOUND)
.body(boxed(Full::from("404"))) .body(boxed(Full::from("404")))
.expect("Invalid response"), .unwrap_or_else(|_| fatal!("Invalid response")),
} }
} }
} }

View File

@ -4,12 +4,14 @@ use axum::{extract::State, Json};
use ctranslate2_bindings::{ use ctranslate2_bindings::{
TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder, TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder,
}; };
use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use strfmt::{strfmt, strfmt_builder}; use strfmt::{strfmt, strfmt_builder};
use tabby_common::{events, path::ModelDir}; use tabby_common::{events, path::ModelDir};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::languages::get_stop_words; use self::languages::get_stop_words;
use crate::fatal;
mod languages; mod languages;
@ -58,20 +60,19 @@ pub struct CompletionResponse {
pub async fn completion( pub async fn completion(
State(state): State<Arc<CompletionState>>, State(state): State<Arc<CompletionState>>,
Json(request): Json<CompletionRequest>, Json(request): Json<CompletionRequest>,
) -> Json<CompletionResponse> { ) -> Result<Json<CompletionResponse>, StatusCode> {
let language = request.language.unwrap_or("unknown".into()); let language = request.language.unwrap_or("unknown".to_string());
let options = TextInferenceOptionsBuilder::default() let options = TextInferenceOptionsBuilder::default()
.max_decoding_length(128) .max_decoding_length(128)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.stop_words(get_stop_words(&language)) .stop_words(get_stop_words(&language))
.build() .build()
.expect("Invalid TextInferenceOptions"); .unwrap();
let prompt = if let Some(Segments { prefix, suffix }) = request.segments { let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
if let Some(prompt_template) = &state.prompt_template { if let Some(prompt_template) = &state.prompt_template {
if let Some(suffix) = suffix { if let Some(suffix) = suffix {
strfmt!(prompt_template, prefix => prefix, suffix => suffix) strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
.expect("Failed to format prompt")
} else { } else {
// If suffix is empty, just returns prefix. // If suffix is empty, just returns prefix.
prefix prefix
@ -81,7 +82,7 @@ pub async fn completion(
prefix prefix
} }
} else { } else {
request.prompt.expect("No prompt is set") return Err(StatusCode::BAD_REQUEST);
}; };
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
@ -98,10 +99,10 @@ pub async fn completion(
} }
.log(); .log();
Json(CompletionResponse { Ok(Json(CompletionResponse {
id: completion_id, id: completion_id,
choices: vec![Choice { index: 0, text }], choices: vec![Choice { index: 0, text }],
}) }))
} }
pub struct CompletionState { pub struct CompletionState {
@ -123,7 +124,7 @@ impl CompletionState {
.device_indices(args.device_indices.clone()) .device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device) .num_replicas_per_device(args.num_replicas_per_device)
.build() .build()
.expect("Invalid TextInferenceEngineCreateOptions"); .unwrap();
let engine = TextInferenceEngine::create(options); let engine = TextInferenceEngine::create(options);
Self { Self {
engine, engine,
@ -147,5 +148,6 @@ struct Metadata {
} }
fn read_metadata(model_dir: &ModelDir) -> Metadata { fn read_metadata(model_dir: &ModelDir) -> Metadata {
serdeconv::from_json_file(model_dir.metadata_file()).expect("Invalid metadata") serdeconv::from_json_file(model_dir.metadata_file())
.unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file()))
} }

View File

@ -8,13 +8,13 @@ use std::{
}; };
use axum::{routing, Router, Server}; use axum::{routing, Router, Server};
use clap::{error::ErrorKind, Args, CommandFactory}; use clap::Args;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use tracing::info; use tracing::info;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use crate::Cli; use crate::fatal;
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
@ -68,7 +68,15 @@ pub async fn main(args: &ServeArgs) {
valid_args(args); valid_args(args);
// Ensure model exists. // Ensure model exists.
tabby_download::download_model(&args.model, true).await; tabby_download::download_model(&args.model, true)
.await
.unwrap_or_else(|err| {
fatal!(
"Failed to fetch model due to '{}', is '{}' a valid model id?",
err,
args.model
)
});
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
@ -81,7 +89,7 @@ pub async fn main(args: &ServeArgs) {
Server::bind(&address) Server::bind(&address)
.serve(app.into_make_service()) .serve(app.into_make_service())
.await .await
.expect("Error happends during model serving") .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
} }
fn api_router(args: &ServeArgs) -> Router { fn api_router(args: &ServeArgs) -> Router {
@ -104,21 +112,11 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
fn valid_args(args: &ServeArgs) { fn valid_args(args: &ServeArgs) {
if args.device == Device::Cuda && args.num_replicas_per_device != 1 { if args.device == Device::Cuda && args.num_replicas_per_device != 1 {
Cli::command() fatal!("CUDA device only supports 1 replicas per device");
.error(
ErrorKind::ValueValidation,
"CUDA device only supports 1 replicas per device",
)
.exit();
} }
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0) if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
{ {
Cli::command() fatal!("CPU device only supports device indices = [0]");
.error(
ErrorKind::ValueValidation,
"CPU device only supports device indices = [0]",
)
.exit();
} }
} }