From 4cb672ec39841d70f02a1574071bc89322037127 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 6 Jun 2023 19:02:58 -0700 Subject: [PATCH] feat: improve error handling and messages [TAB-58] (#213) * add fatal macro * switch expect to fatal * improve error handling of serve * improve error handling on download module * improve error handling in scheduler * improve error handling * fmt * fmt --- crates/ctranslate2-bindings/src/lib.rs | 6 ++- crates/tabby-common/src/config.rs | 17 ++++-- crates/tabby-download/src/cache_info.rs | 13 ++--- crates/tabby-download/src/lib.rs | 66 +++++++++++++----------- crates/tabby-scheduler/src/index.rs | 36 +++++++------ crates/tabby-scheduler/src/lib.rs | 30 ++++++----- crates/tabby-scheduler/src/repository.rs | 37 ++++++------- crates/tabby/src/download.rs | 12 ++++- crates/tabby/src/main.rs | 27 ++++++++-- crates/tabby/src/serve/admin.rs | 6 ++- crates/tabby/src/serve/completions.rs | 22 ++++---- crates/tabby/src/serve/mod.rs | 30 +++++------ 12 files changed, 179 insertions(+), 123 deletions(-) diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 62ed19b..152c6d9 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -172,8 +172,10 @@ fn reverse(s: String) -> String { s.chars().rev().collect() } -fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &Vec<&str>) -> Regex { - let encodings = tokenizer.encode_batch(stop_words.clone(), false).unwrap(); +fn create_stop_regex(tokenizer: &Tokenizer, stop_words: &[&str]) -> Regex { + let encodings = tokenizer + .encode_batch(stop_words.to_owned(), false) + .unwrap(); let stop_tokens: Vec = encodings .iter() .map(|x| x.get_tokens().join("")) diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index 8b4df9a..f13b932 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -1,9 +1,12 @@ -use std::path::PathBuf; +use std::{ + io::{Error, ErrorKind}, + path::PathBuf, +}; use filenamify::filenamify; use serde::Deserialize; -use crate::path::repositories_dir; +use crate::path::{config_file, repositories_dir}; #[derive(Deserialize)] pub struct Config { @@ -11,8 +14,14 @@ pub struct Config { } impl Config { - pub fn load() -> Result { - serdeconv::from_toml_file(crate::path::config_file().as_path()) + pub fn load() -> Result { + let file = serdeconv::from_toml_file(crate::path::config_file().as_path()); + file.map_err(|_| { + Error::new( + ErrorKind::InvalidData, + format!("Config {:?} doesn't exist or is not valid", config_file()), + ) + }) } } diff --git a/crates/tabby-download/src/cache_info.rs b/crates/tabby-download/src/cache_info.rs index 27029e5..8b308a0 100644 --- a/crates/tabby-download/src/cache_info.rs +++ b/crates/tabby-download/src/cache_info.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, fs, path::Path}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use serde::{Deserialize, Serialize}; use tabby_common::path::ModelDir; @@ -33,12 +33,13 @@ impl CacheInfo { self.etags.get(path).map(|x| x.as_str()) } - pub fn remote_cache_key(res: &reqwest::Response) -> &str { - res.headers() + pub fn remote_cache_key(res: &reqwest::Response) -> Result<&str> { + let key = res + .headers() .get("etag") - .unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url())) - .to_str() - .unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url())) + .ok_or(anyhow!("etag key missing"))? + .to_str()?; + Ok(key) } pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) { diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index f582c90..1e1a6c3 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -2,13 +2,19 @@ mod cache_info; use std::{cmp, fs, io::Write, path::Path}; +use anyhow::{anyhow, Result}; use cache_info::CacheInfo; use futures_util::StreamExt; use indicatif::{ProgressBar, ProgressStyle}; use tabby_common::path::ModelDir; impl CacheInfo { - async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) { + async fn download( + &mut self, + model_id: &str, + path: &str, + prefer_local_file: bool, + ) -> Result<()> { // Create url. let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); @@ -25,89 +31,89 @@ impl CacheInfo { } if !local_file_ready { - let etag = download_file(&url, &filepath, local_cache_key).await; - self.set_local_cache_key(path, &etag).await + let etag = download_file(&url, &filepath, local_cache_key).await?; + self.set_local_cache_key(path, &etag).await; } + Ok(()) } } -pub async fn download_model(model_id: &str, prefer_local_file: bool) { +pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> { if fs::metadata(model_id).is_ok() { // Local path, no need for downloading. - return; + return Ok(()); } let mut cache_info = CacheInfo::from(model_id).await; cache_info .download(model_id, "tabby.json", prefer_local_file) - .await; + .await?; cache_info .download(model_id, "tokenizer.json", prefer_local_file) - .await; + .await?; cache_info .download(model_id, "ctranslate2/config.json", prefer_local_file) - .await; + .await?; cache_info .download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file) - .await; + .await?; cache_info .download( model_id, "ctranslate2/shared_vocabulary.txt", prefer_local_file, ) - .await; + .await?; cache_info .download(model_id, "ctranslate2/model.bin", prefer_local_file) - .await; - cache_info - .save(model_id) - .unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id)); + .await?; + cache_info.save(model_id)?; + + Ok(()) } -async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> String { - fs::create_dir_all(Path::new(path).parent().unwrap()) - .unwrap_or_else(|_| panic!("Failed to create path '{}'", path)); +async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> Result { + fs::create_dir_all(Path::new(path).parent().unwrap())?; // Reqwest setup - let res = reqwest::get(url) - .await - .unwrap_or_else(|_| panic!("Failed to GET from '{}'", url)); + let res = reqwest::get(url).await?; - let remote_cache_key = CacheInfo::remote_cache_key(&res).to_string(); + if !res.status().is_success() { + return Err(anyhow!(format!("Invalid url: {}", url))); + } + + let remote_cache_key = CacheInfo::remote_cache_key(&res)?.to_string(); if let Some(local_cache_key) = local_cache_key { if local_cache_key == remote_cache_key { - return remote_cache_key; + return Ok(remote_cache_key); } } let total_size = res .content_length() - .unwrap_or_else(|| panic!("Failed to get content length from '{}'", url)); + .ok_or(anyhow!("No content length in headers"))?; // Indicatif setup let pb = ProgressBar::new(total_size); pb.set_style(ProgressStyle::default_bar() - .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})") - .expect("Invalid progress style") + .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")? .progress_chars("#>-")); pb.set_message(format!("Downloading {}", path)); // download chunks - let mut file = - fs::File::create(path).unwrap_or_else(|_| panic!("Failed to create file '{}'", path)); + let mut file = fs::File::create(path)?; let mut downloaded: u64 = 0; let mut stream = res.bytes_stream(); while let Some(item) = stream.next().await { - let chunk = item.expect("Error while downloading file"); - file.write_all(&chunk).expect("Error while writing to file"); + let chunk = item?; + file.write_all(&chunk)?; let new = cmp::min(downloaded + (chunk.len() as u64), total_size); downloaded = new; pb.set_position(new); } pb.finish_with_message(format!("Downloaded {}", path)); - remote_cache_key + Ok(remote_cache_key) } diff --git a/crates/tabby-scheduler/src/index.rs b/crates/tabby-scheduler/src/index.rs index 9aeb40a..8ac2499 100644 --- a/crates/tabby-scheduler/src/index.rs +++ b/crates/tabby-scheduler/src/index.rs @@ -15,11 +15,11 @@ use tracing::{info, warn}; use walkdir::{DirEntry, WalkDir}; trait RepositoryExt { - fn index(&self, schema: &Schema, writer: &mut IndexWriter); + fn index(&self, schema: &Schema, writer: &mut IndexWriter) -> Result<()>; } impl RepositoryExt for Repository { - fn index(&self, schema: &Schema, writer: &mut IndexWriter) { + fn index(&self, schema: &Schema, writer: &mut IndexWriter) -> Result<()> { let git_url = schema.get_field("git_url").unwrap(); let filepath = schema.get_field("filepath").unwrap(); let content = schema.get_field("content").unwrap(); @@ -36,17 +36,17 @@ impl RepositoryExt for Repository { let relative_path = entry.path().strip_prefix(dir.as_path()).unwrap(); if let Ok(file_content) = read_to_string(entry.path()) { info!("Indexing {:?}", relative_path); - writer - .add_document(doc!( - git_url => self.git_url.clone(), - filepath => relative_path.display().to_string(), - content => file_content, - )) - .unwrap(); + writer.add_document(doc!( + git_url => self.git_url.clone(), + filepath => relative_path.display().to_string(), + content => file_content, + ))?; } else { warn!("Skip {:?}", relative_path); } } + + Ok(()) } } @@ -66,18 +66,20 @@ fn create_schema() -> Schema { builder.build() } -pub fn index_repositories(config: &Config) { +pub fn index_repositories(config: &Config) -> Result<()> { let schema = create_schema(); - fs::create_dir_all(index_dir()).unwrap(); - let directory = MmapDirectory::open(index_dir()).unwrap(); - let index = Index::open_or_create(directory, schema.clone()).unwrap(); - let mut writer = index.writer(10_000_000).unwrap(); + fs::create_dir_all(index_dir())?; + let directory = MmapDirectory::open(index_dir())?; + let index = Index::open_or_create(directory, schema.clone())?; + let mut writer = index.writer(10_000_000)?; - writer.delete_all_documents().unwrap(); + writer.delete_all_documents()?; for repository in config.repositories.as_slice() { - repository.index(&schema, &mut writer); + repository.index(&schema, &mut writer)?; } - writer.commit().unwrap(); + writer.commit()?; + + Ok(()) } diff --git a/crates/tabby-scheduler/src/lib.rs b/crates/tabby-scheduler/src/lib.rs index 9706fab..6e39a34 100644 --- a/crates/tabby-scheduler/src/lib.rs +++ b/crates/tabby-scheduler/src/lib.rs @@ -1,30 +1,32 @@ mod index; mod repository; +use anyhow::Result; use job_scheduler::{Job, JobScheduler}; use tabby_common::config::Config; use tracing::{error, info}; -pub fn scheduler(now: bool) { - let config = Config::load(); - if config.is_err() { - error!("Please create config.toml before using scheduler"); - return; - } - - let config = config.unwrap(); +pub async fn scheduler(now: bool) -> Result<()> { + let config = Config::load()?; let mut scheduler = JobScheduler::new(); let job = || { info!("Syncing repositories..."); - repository::sync_repositories(&config); + let ret = repository::sync_repositories(&config); + if let Err(err) = ret { + error!("Failed to sync repositories, err: '{}'", err); + return; + } info!("Indexing repositories..."); - index::index_repositories(&config); + let ret = index::index_repositories(&config); + if let Err(err) = ret { + error!("Failed to index repositories, err: '{}'", err); + } }; if now { - job() + job(); } else { // Every 5 hours. scheduler.add(Job::new("0 0 1/5 * * * *".parse().unwrap(), job)); @@ -37,6 +39,8 @@ pub fn scheduler(now: bool) { std::thread::sleep(duration); } } + + Ok(()) } #[cfg(test)] @@ -61,7 +65,7 @@ mod tests { }], }; - repository::sync_repositories(&config); - index::index_repositories(&config); + repository::sync_repositories(&config).unwrap(); + index::index_repositories(&config).unwrap(); } } diff --git a/crates/tabby-scheduler/src/repository.rs b/crates/tabby-scheduler/src/repository.rs index 4c14f9c..48a52a9 100644 --- a/crates/tabby-scheduler/src/repository.rs +++ b/crates/tabby-scheduler/src/repository.rs @@ -1,33 +1,32 @@ use std::process::Command; +use anyhow::{anyhow, Result}; use tabby_common::config::{Config, Repository}; trait ConfigExt { - fn sync_repositories(&self); + fn sync_repositories(&self) -> Result<()>; } impl ConfigExt for Config { - fn sync_repositories(&self) { + fn sync_repositories(&self) -> Result<()> { for repository in self.repositories.iter() { - repository.sync() + repository.sync()?; } + + Ok(()) } } trait RepositoryExt { - fn sync(&self); + fn sync(&self) -> Result<()>; } impl RepositoryExt for Repository { - fn sync(&self) { + fn sync(&self) -> Result<()> { let dir = self.dir(); let dir_string = dir.display().to_string(); let status = if dir.exists() { - Command::new("git") - .current_dir(&dir) - .arg("pull") - .status() - .expect("git could not be executed") + Command::new("git").current_dir(&dir).arg("pull").status() } else { std::fs::create_dir_all(&dir) .unwrap_or_else(|_| panic!("Failed to create dir {}", dir_string)); @@ -39,20 +38,22 @@ impl RepositoryExt for Repository { .arg(&self.git_url) .arg(dir) .status() - .expect("git could not be executed") }; - if let Some(code) = status.code() { + if let Some(code) = status?.code() { if code != 0 { - panic!( - "Failed to pull remote '{}'\nConsider remove dir '{}' and retry", - &self.git_url, &dir_string - ); + return Err(anyhow!( + "Failed to pull remote '{}'. Consider remove dir '{}' and retry", + &self.git_url, + &dir_string + )); } } + + Ok(()) } } -pub fn sync_repositories(config: &Config) { - config.sync_repositories(); +pub fn sync_repositories(config: &Config) -> Result<()> { + config.sync_repositories() } diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs index d154f25..7ed994f 100644 --- a/crates/tabby/src/download.rs +++ b/crates/tabby/src/download.rs @@ -1,6 +1,8 @@ use clap::Args; use tracing::info; +use crate::fatal; + #[derive(Args)] pub struct DownloadArgs { /// model id to fetch. @@ -13,6 +15,14 @@ pub struct DownloadArgs { } pub async fn main(args: &DownloadArgs) { - tabby_download::download_model(&args.model, args.prefer_local_file).await; + tabby_download::download_model(&args.model, args.prefer_local_file) + .await + .unwrap_or_else(|err| { + fatal!( + "Failed to fetch model due to '{}', is '{}' a valid model id?", + err, + args.model + ) + }); info!("model '{}' is ready", args.model); } diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 0150175..d48a9e1 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -14,13 +14,13 @@ struct Cli { #[derive(Subcommand)] pub enum Commands { - /// Serve the model + /// Starts the api endpoint for IDE / Editor extensions. Serve(serve::ServeArgs), - /// Download the model + /// Download the language model for serving. Download(download::DownloadArgs), - /// Starts the scheduler process. + /// Run scheduler progress for cron jobs integrating external code repositories. Scheduler(SchedulerArgs), } @@ -42,6 +42,25 @@ async fn main() { match &cli.command { Commands::Serve(args) => serve::main(args).await, Commands::Download(args) => download::main(args).await, - Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now), + Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now) + .await + .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)), } } + +#[macro_export] +macro_rules! fatal { + ($msg:expr) => { + ({ + tracing::error!($msg); + std::process::exit(1); + }) + }; + + ($fmt:expr, $($arg:tt)*) => { + ({ + tracing::error!($fmt, $($arg)*); + std::process::exit(1); + }) + }; +} diff --git a/crates/tabby/src/serve/admin.rs b/crates/tabby/src/serve/admin.rs index 357c05a..b6c2c40 100644 --- a/crates/tabby/src/serve/admin.rs +++ b/crates/tabby/src/serve/admin.rs @@ -4,6 +4,8 @@ use axum::{ response::{IntoResponse, Response}, }; +use crate::fatal; + #[derive(rust_embed::RustEmbed)] #[folder = "../tabby-admin/dist/"] struct AdminAssets; @@ -26,12 +28,12 @@ where Response::builder() .header(header::CONTENT_TYPE, mime.as_ref()) .body(body) - .expect("Invalid response") + .unwrap_or_else(|_| fatal!("Invalid response")) } None => Response::builder() .status(StatusCode::NOT_FOUND) .body(boxed(Full::from("404"))) - .expect("Invalid response"), + .unwrap_or_else(|_| fatal!("Invalid response")), } } } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 0e4c136..74296ae 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -4,12 +4,14 @@ use axum::{extract::State, Json}; use ctranslate2_bindings::{ TextInferenceEngine, TextInferenceEngineCreateOptionsBuilder, TextInferenceOptionsBuilder, }; +use hyper::StatusCode; use serde::{Deserialize, Serialize}; use strfmt::{strfmt, strfmt_builder}; use tabby_common::{events, path::ModelDir}; use utoipa::ToSchema; use self::languages::get_stop_words; +use crate::fatal; mod languages; @@ -58,20 +60,19 @@ pub struct CompletionResponse { pub async fn completion( State(state): State>, Json(request): Json, -) -> Json { - let language = request.language.unwrap_or("unknown".into()); +) -> Result, StatusCode> { + let language = request.language.unwrap_or("unknown".to_string()); let options = TextInferenceOptionsBuilder::default() .max_decoding_length(128) .sampling_temperature(0.1) .stop_words(get_stop_words(&language)) .build() - .expect("Invalid TextInferenceOptions"); + .unwrap(); let prompt = if let Some(Segments { prefix, suffix }) = request.segments { if let Some(prompt_template) = &state.prompt_template { if let Some(suffix) = suffix { - strfmt!(prompt_template, prefix => prefix, suffix => suffix) - .expect("Failed to format prompt") + strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap() } else { // If suffix is empty, just returns prefix. prefix @@ -81,7 +82,7 @@ pub async fn completion( prefix } } else { - request.prompt.expect("No prompt is set") + return Err(StatusCode::BAD_REQUEST); }; let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); @@ -98,10 +99,10 @@ pub async fn completion( } .log(); - Json(CompletionResponse { + Ok(Json(CompletionResponse { id: completion_id, choices: vec![Choice { index: 0, text }], - }) + })) } pub struct CompletionState { @@ -123,7 +124,7 @@ impl CompletionState { .device_indices(args.device_indices.clone()) .num_replicas_per_device(args.num_replicas_per_device) .build() - .expect("Invalid TextInferenceEngineCreateOptions"); + .unwrap(); let engine = TextInferenceEngine::create(options); Self { engine, @@ -147,5 +148,6 @@ struct Metadata { } fn read_metadata(model_dir: &ModelDir) -> Metadata { - serdeconv::from_json_file(model_dir.metadata_file()).expect("Invalid metadata") + serdeconv::from_json_file(model_dir.metadata_file()) + .unwrap_or_else(|_| fatal!("Invalid metadata file: {}", model_dir.metadata_file())) } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 81b8da7..39cc746 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -8,13 +8,13 @@ use std::{ }; use axum::{routing, Router, Server}; -use clap::{error::ErrorKind, Args, CommandFactory}; +use clap::Args; use tower_http::cors::CorsLayer; use tracing::info; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use crate::Cli; +use crate::fatal; #[derive(OpenApi)] #[openapi( @@ -68,7 +68,15 @@ pub async fn main(args: &ServeArgs) { valid_args(args); // Ensure model exists. - tabby_download::download_model(&args.model, true).await; + tabby_download::download_model(&args.model, true) + .await + .unwrap_or_else(|err| { + fatal!( + "Failed to fetch model due to '{}', is '{}' a valid model id?", + err, + args.model + ) + }); let app = Router::new() .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) @@ -81,7 +89,7 @@ pub async fn main(args: &ServeArgs) { Server::bind(&address) .serve(app.into_make_service()) .await - .expect("Error happends during model serving") + .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } fn api_router(args: &ServeArgs) -> Router { @@ -104,21 +112,11 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter { fn valid_args(args: &ServeArgs) { if args.device == Device::Cuda && args.num_replicas_per_device != 1 { - Cli::command() - .error( - ErrorKind::ValueValidation, - "CUDA device only supports 1 replicas per device", - ) - .exit(); + fatal!("CUDA device only supports 1 replicas per device"); } if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0) { - Cli::command() - .error( - ErrorKind::ValueValidation, - "CPU device only supports device indices = [0]", - ) - .exit(); + fatal!("CPU device only supports device indices = [0]"); } }