refactor: improve error handling, fix clippy warnings (#181)

* refactor: minor improvements on error handling

* refactor: cleanup error handling

* update

* update

* fix

* add clippy / test workflow

* fix clippy

* fix clippy

* update
support-coreml
Meng Zhang 2023-06-01 17:23:05 -07:00 committed by GitHub
parent 4c6f1338a8
commit 3cac2607e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 97 additions and 112 deletions

View File

@ -41,7 +41,7 @@ jobs:
run: cargo fmt --check run: cargo fmt --check
release-docker: release-docker:
needs: tests needs: release-binary
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: read contents: read
@ -134,6 +134,7 @@ jobs:
with: with:
toolchain: stable toolchain: stable
target: ${{ matrix.target }} target: ${{ matrix.target }}
components: clippy
- name: Sccache cache - name: Sccache cache
uses: mozilla-actions/sccache-action@v0.0.3 uses: mozilla-actions/sccache-action@v0.0.3
@ -151,6 +152,12 @@ jobs:
~/.cargo/registry ~/.cargo/registry
~/.cargo/git ~/.cargo/git
- name: Run cargo clippy check
run: cargo clippy -- -Dwarnings
- name: Run cargo tests
run: cargo test
- name: Build release binary - name: Build release binary
run: cargo build --release --target ${{ matrix.target }} run: cargo build --release --target ${{ matrix.target }}

View File

@ -58,11 +58,9 @@ fn link_static() -> PathBuf {
let dst = config.build(); let dst = config.build();
// Read static lib from generated deps. // Read static lib from generated deps.
let cmake_generated_libs_str = std::fs::read_to_string( let cmake_generated_libs_str =
&format!("/{}/build/cmake_generated_libs", dst.display()).to_string(), std::fs::read_to_string(format!("/{}/build/cmake_generated_libs", dst.display())).unwrap();
)
.unwrap();
read_cmake_generated(&cmake_generated_libs_str); read_cmake_generated(&cmake_generated_libs_str);
return dst; dst
} }

View File

@ -67,7 +67,7 @@ impl Event<'_> {
writer.by_ref(), writer.by_ref(),
) )
.unwrap(); .unwrap();
write!(writer, "\n").unwrap(); writeln!(writer).unwrap();
writer.flush().unwrap(); writer.flush().unwrap();
} }
} }

View File

@ -33,11 +33,11 @@ impl ModelDir {
} }
pub fn metadata_file(&self) -> String { pub fn metadata_file(&self) -> String {
return self.path_string("metadata.json"); self.path_string("metadata.json")
} }
pub fn tokenizer_file(&self) -> String { pub fn tokenizer_file(&self) -> String {
return self.path_string("tokenizer.json"); self.path_string("tokenizer.json")
} }
pub fn ctranslate2_dir(&self) -> String { pub fn ctranslate2_dir(&self) -> String {

View File

@ -1,4 +1,4 @@
use anyhow::{anyhow, Result}; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs; use std::fs;
@ -17,10 +17,14 @@ struct HFMetadata {
} }
impl HFMetadata { impl HFMetadata {
async fn from(model_id: &str) -> Result<HFMetadata> { async fn from(model_id: &str) -> HFMetadata {
let api_url = format!("https://huggingface.co/api/models/{}", model_id); let api_url = format!("https://huggingface.co/api/models/{}", model_id);
let metadata = reqwest::get(api_url).await?.json::<HFMetadata>().await?; reqwest::get(&api_url)
Ok(metadata) .await
.unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url))
.json::<HFMetadata>()
.await
.unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url))
} }
} }
@ -31,16 +35,15 @@ pub struct Metadata {
} }
impl Metadata { impl Metadata {
pub async fn from(model_id: &str) -> Result<Metadata> { pub async fn from(model_id: &str) -> Metadata {
if let Some(metadata) = Self::from_local(model_id) { if let Some(metadata) = Self::from_local(model_id) {
Ok(metadata) metadata
} else { } else {
let hf_metadata = HFMetadata::from(model_id).await?; let hf_metadata = HFMetadata::from(model_id).await;
let metadata = Metadata { Metadata {
auto_model: hf_metadata.transformers_info.auto_model, auto_model: hf_metadata.transformers_info.auto_model,
etags: HashMap::new(), etags: HashMap::new(),
}; }
Ok(metadata)
} }
} }
@ -54,28 +57,20 @@ impl Metadata {
} }
} }
pub fn has_etag(&self, url: &str) -> bool { pub fn local_cache_key(&self, path: &str) -> Option<&str> {
self.etags.get(url).is_some() self.etags.get(path).map(|x| x.as_str())
} }
pub async fn match_etag(&self, url: &str, path: &str) -> Result<bool> { pub fn remote_cache_key(res: &reqwest::Response) -> &str {
let etag = self res.headers()
.etags
.get(url)
.ok_or(anyhow!("Path doesn't exist: {}", path))?;
let etag_from_header = reqwest::get(url)
.await?
.headers()
.get("etag") .get("etag")
.ok_or(anyhow!("URL doesn't have etag header: '{}'", url))? .unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
.to_str()? .to_str()
.to_owned(); .unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
Ok(etag == &etag_from_header)
} }
pub async fn update_etag(&mut self, url: &str, path: &str) { pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
self.etags.insert(url.to_owned(), path.to_owned()); self.etags.insert(path.to_string(), cache_key.to_string());
} }
pub fn save(&self, model_id: &str) -> Result<()> { pub fn save(&self, model_id: &str) -> Result<()> {
@ -92,7 +87,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_hf() { async fn test_hf() {
let hf_metadata = HFMetadata::from("TabbyML/J-350M").await.unwrap(); let hf_metadata = HFMetadata::from("TabbyML/J-350M").await;
assert_eq!( assert_eq!(
hf_metadata.transformers_info.auto_model, hf_metadata.transformers_info.auto_model,
"AutoModelForCausalLM" "AutoModelForCausalLM"

View File

@ -1,6 +1,5 @@
mod metadata; mod metadata;
use anyhow::{anyhow, Result};
use std::cmp; use std::cmp;
use std::fs; use std::fs;
use std::io::Write; use std::io::Write;
@ -18,111 +17,108 @@ pub struct DownloadArgs {
model: String, model: String,
/// If true, skip checking for remote model file. /// If true, skip checking for remote model file.
#[clap(long, default_value_t = true)] #[clap(long, default_value_t = false)]
prefer_local_file: bool, prefer_local_file: bool,
} }
pub async fn main(args: &DownloadArgs) -> Result<()> { pub async fn main(args: &DownloadArgs) {
download_model(&args.model, args.prefer_local_file).await?; download_model(&args.model, args.prefer_local_file).await;
Ok(()) println!("model '{}' is ready", args.model);
} }
impl metadata::Metadata { impl metadata::Metadata {
async fn download( async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
&mut self,
model_id: &str,
path: &str,
prefer_local_file: bool,
) -> Result<()> {
// Create url. // Create url.
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
// Get cache key.
let local_cache_key = self.local_cache_key(path);
// Create destination path. // Create destination path.
let filepath = ModelDir::new(model_id).path_string(path); let filepath = ModelDir::new(model_id).path_string(path);
// Cache hit. // Cache hit.
let mut cache_hit = false; let mut local_file_ready = false;
if fs::metadata(&filepath).is_ok() && self.has_etag(&url) { if !prefer_local_file && local_cache_key.is_some() && fs::metadata(&filepath).is_ok() {
if prefer_local_file || self.match_etag(&url, path).await? { local_file_ready = true;
cache_hit = true }
if !local_file_ready {
let etag = download_file(&url, &filepath, local_cache_key).await;
self.set_local_cache_key(path, &etag).await
}
} }
} }
if !cache_hit { pub async fn download_model(model_id: &str, prefer_local_file: bool) {
let etag = download_file(&url, &filepath).await?; let mut metadata = metadata::Metadata::from(model_id).await;
self.update_etag(&url, &etag).await
}
Ok(())
}
}
pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
let mut metadata = metadata::Metadata::from(model_id).await?;
metadata metadata
.download(model_id, "tokenizer.json", prefer_local_file) .download(model_id, "tokenizer.json", prefer_local_file)
.await?; .await;
metadata metadata
.download(model_id, "ctranslate2/config.json", prefer_local_file) .download(model_id, "ctranslate2/config.json", prefer_local_file)
.await?; .await;
metadata metadata
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file) .download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
.await?; .await;
metadata metadata
.download( .download(
model_id, model_id,
"ctranslate2/shared_vocabulary.txt", "ctranslate2/shared_vocabulary.txt",
prefer_local_file, prefer_local_file,
) )
.await?; .await;
metadata metadata
.download(model_id, "ctranslate2/model.bin", prefer_local_file) .download(model_id, "ctranslate2/model.bin", prefer_local_file)
.await?; .await;
metadata.save(model_id)?; metadata
Ok(()) .save(model_id)
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
} }
async fn download_file(url: &str, path: &str) -> Result<String> { async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> String {
fs::create_dir_all(Path::new(path).parent().unwrap())?; fs::create_dir_all(Path::new(path).parent().unwrap())
.unwrap_or_else(|_| panic!("Failed to create path '{}'", path));
// Reqwest setup // Reqwest setup
let res = reqwest::get(url) let res = reqwest::get(url)
.await .await
.or(Err(anyhow!("Failed to GET from '{}'", url)))?; .unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
let etag = res let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string();
.headers() if let Some(local_cache_key) = local_cache_key {
.get("etag") if local_cache_key == remote_cache_key {
.ok_or(anyhow!("Failed to get etag from '{}", url))? return remote_cache_key;
.to_str()? }
.to_string(); }
let total_size = res let total_size = res
.content_length() .content_length()
.ok_or(anyhow!("Failed to get content length from '{}'", url))?; .unwrap_or_else(|| panic!("Failed to get content length from '{}'", url));
// Indicatif setup // Indicatif setup
let pb = ProgressBar::new(total_size); let pb = ProgressBar::new(total_size);
pb.set_style(ProgressStyle::default_bar() pb.set_style(ProgressStyle::default_bar()
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")? .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
.expect("Invalid progress style")
.progress_chars("#>-")); .progress_chars("#>-"));
pb.set_message(format!("Downloading {}", path)); pb.set_message(format!("Downloading {}", path));
// download chunks // download chunks
let mut file = fs::File::create(&path).or(Err(anyhow!("Failed to create file '{}'", &path)))?; let mut file =
fs::File::create(path).unwrap_or_else(|_| panic!("Failed to create file '{}'", path));
let mut downloaded: u64 = 0; let mut downloaded: u64 = 0;
let mut stream = res.bytes_stream(); let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
let chunk = item.or(Err(anyhow!("Error while downloading file")))?; let chunk = item.expect("Error while downloading file");
file.write_all(&chunk) file.write_all(&chunk).expect("Error while writing to file");
.or(Err(anyhow!("Error while writing to file")))?;
let new = cmp::min(downloaded + (chunk.len() as u64), total_size); let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
downloaded = new; downloaded = new;
pb.set_position(new); pb.set_position(new);
} }
pb.finish_with_message(format!("Downloaded {}", path)); pb.finish_with_message(format!("Downloaded {}", path));
return Ok(etag); remote_cache_key
} }

View File

@ -25,15 +25,7 @@ async fn main() {
let cli = Cli::parse(); let cli = Cli::parse();
match &cli.command { match &cli.command {
Commands::Serve(args) => { Commands::Serve(args) => serve::main(args).await,
serve::main(args) Commands::Download(args) => download::main(args).await,
.await
.expect("Error happens during the serve");
}
Commands::Download(args) => {
download::main(args)
.await
.expect("Error happens during the download");
}
} }
} }

View File

@ -26,12 +26,12 @@ where
Response::builder() Response::builder()
.header(header::CONTENT_TYPE, mime.as_ref()) .header(header::CONTENT_TYPE, mime.as_ref())
.body(body) .body(body)
.unwrap() .expect("Invalid response")
} }
None => Response::builder() None => Response::builder()
.status(StatusCode::NOT_FOUND) .status(StatusCode::NOT_FOUND)
.body(boxed(Full::from("404"))) .body(boxed(Full::from("404")))
.unwrap(), .expect("Invalid response"),
} }
} }
} }

View File

@ -45,7 +45,7 @@ pub async fn completion(
.max_decoding_length(64) .max_decoding_length(64)
.sampling_temperature(0.2) .sampling_temperature(0.2)
.build() .build()
.unwrap(); .expect("Invalid TextInferenceOptions");
let text = state.engine.inference(&request.prompt, options); let text = state.engine.inference(&request.prompt, options);
let language = request.language.unwrap_or("unknown".into()); let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text); let filtered_text = languages::remove_stop_words(&language, &text);
@ -90,7 +90,7 @@ impl CompletionState {
.device_indices(args.device_indices.clone()) .device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device) .num_replicas_per_device(args.num_replicas_per_device)
.build() .build()
.unwrap(); .expect("Invalid TextInferenceEngineCreateOptions");
let engine = TextInferenceEngine::create(options); let engine = TextInferenceEngine::create(options);
Self { engine } Self { engine }
} }
@ -110,5 +110,5 @@ struct Metadata {
} }
fn read_metadata(model_dir: &ModelDir) -> Metadata { fn read_metadata(model_dir: &ModelDir) -> Metadata {
serdeconv::from_json_file(model_dir.metadata_file()).unwrap() serdeconv::from_json_file(model_dir.metadata_file()).expect("Invalid metadata")
} }

View File

@ -3,7 +3,6 @@ mod completions;
mod events; mod events;
use crate::Cli; use crate::Cli;
use anyhow::Result;
use axum::{routing, Router, Server}; use axum::{routing, Router, Server};
use clap::{error::ErrorKind, Args, CommandFactory}; use clap::{error::ErrorKind, Args, CommandFactory};
use std::{ use std::{
@ -29,10 +28,10 @@ struct ApiDoc;
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
pub enum Device { pub enum Device {
#[strum(serialize = "cpu")] #[strum(serialize = "cpu")]
CPU, Cpu,
#[strum(serialize = "cuda")] #[strum(serialize = "cuda")]
CUDA, Cuda,
} }
#[derive(Args)] #[derive(Args)]
@ -45,7 +44,7 @@ pub struct ServeArgs {
port: u16, port: u16,
/// Device to run model inference. /// Device to run model inference.
#[clap(long, default_value_t=Device::CPU)] #[clap(long, default_value_t=Device::Cpu)]
device: Device, device: Device,
/// GPU indices to run models, only applicable for CUDA. /// GPU indices to run models, only applicable for CUDA.
@ -61,11 +60,11 @@ pub struct ServeArgs {
experimental_admin_panel: bool, experimental_admin_panel: bool,
} }
pub async fn main(args: &ServeArgs) -> Result<()> { pub async fn main(args: &ServeArgs) {
valid_args(args)?; valid_args(args);
// Ensure model exists. // Ensure model exists.
crate::download::download_model(&args.model, true).await?; crate::download::download_model(&args.model, true).await;
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
@ -77,8 +76,8 @@ pub async fn main(args: &ServeArgs) -> Result<()> {
println!("Listening at {}", address); println!("Listening at {}", address);
Server::bind(&address) Server::bind(&address)
.serve(app.into_make_service()) .serve(app.into_make_service())
.await?; .await
Ok(()) .expect("Error happends during model serving")
} }
fn api_router(args: &ServeArgs) -> Router { fn api_router(args: &ServeArgs) -> Router {
@ -99,8 +98,8 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
} }
} }
fn valid_args(args: &ServeArgs) -> Result<()> { fn valid_args(args: &ServeArgs) {
if args.device == Device::CUDA && args.num_replicas_per_device != 1 { if args.device == Device::Cuda && args.num_replicas_per_device != 1 {
Cli::command() Cli::command()
.error( .error(
ErrorKind::ValueValidation, ErrorKind::ValueValidation,
@ -109,7 +108,7 @@ fn valid_args(args: &ServeArgs) -> Result<()> {
.exit(); .exit();
} }
if args.device == Device::CPU && (args.device_indices.len() != 1 || args.device_indices[0] != 0) if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
{ {
Cli::command() Cli::command()
.error( .error(
@ -118,6 +117,4 @@ fn valid_args(args: &ServeArgs) -> Result<()> {
) )
.exit(); .exit();
} }
Ok(())
} }