diff --git a/.github/workflows/docker.rust.yml b/.github/workflows/docker.rust.yml index fea6d5e..0c94ef0 100644 --- a/.github/workflows/docker.rust.yml +++ b/.github/workflows/docker.rust.yml @@ -41,7 +41,7 @@ jobs: run: cargo fmt --check release-docker: - needs: tests + needs: release-binary runs-on: ubuntu-latest permissions: contents: read @@ -134,6 +134,7 @@ jobs: with: toolchain: stable target: ${{ matrix.target }} + components: clippy - name: Sccache cache uses: mozilla-actions/sccache-action@v0.0.3 @@ -151,6 +152,12 @@ jobs: ~/.cargo/registry ~/.cargo/git + - name: Run cargo clippy check + run: cargo clippy -- -Dwarnings + + - name: Run cargo tests + run: cargo test + - name: Bulid release binary run: cargo build --release --target ${{ matrix.target }} diff --git a/crates/ctranslate2-bindings/build.rs b/crates/ctranslate2-bindings/build.rs index 6ac5585..6529bf6 100644 --- a/crates/ctranslate2-bindings/build.rs +++ b/crates/ctranslate2-bindings/build.rs @@ -58,11 +58,9 @@ fn link_static() -> PathBuf { let dst = config.build(); // Read static lib from generated deps. - let cmake_generated_libs_str = std::fs::read_to_string( - &format!("/{}/build/cmake_generated_libs", dst.display()).to_string(), - ) - .unwrap(); + let cmake_generated_libs_str = + std::fs::read_to_string(format!("/{}/build/cmake_generated_libs", dst.display())).unwrap(); read_cmake_generated(&cmake_generated_libs_str); - return dst; + dst } diff --git a/crates/tabby-common/src/events.rs b/crates/tabby-common/src/events.rs index d454f82..6ee30bc 100644 --- a/crates/tabby-common/src/events.rs +++ b/crates/tabby-common/src/events.rs @@ -67,7 +67,7 @@ impl Event<'_> { writer.by_ref(), ) .unwrap(); - write!(writer, "\n").unwrap(); + writeln!(writer).unwrap(); writer.flush().unwrap(); } } diff --git a/crates/tabby-common/src/path.rs b/crates/tabby-common/src/path.rs index a837d75..67459ca 100644 --- a/crates/tabby-common/src/path.rs +++ b/crates/tabby-common/src/path.rs @@ -33,11 +33,11 @@ impl ModelDir { } pub fn metadata_file(&self) -> String { - return self.path_string("metadata.json"); + self.path_string("metadata.json") } pub fn tokenizer_file(&self) -> String { - return self.path_string("tokenizer.json"); + self.path_string("tokenizer.json") } pub fn ctranslate2_dir(&self) -> String { diff --git a/crates/tabby/src/download/metadata.rs b/crates/tabby/src/download/metadata.rs index 9f21d93..d46d3b8 100644 --- a/crates/tabby/src/download/metadata.rs +++ b/crates/tabby/src/download/metadata.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::Result; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; @@ -17,10 +17,14 @@ struct HFMetadata { } impl HFMetadata { - async fn from(model_id: &str) -> Result { + async fn from(model_id: &str) -> HFMetadata { let api_url = format!("https://huggingface.co/api/models/{}", model_id); - let metadata = reqwest::get(api_url).await?.json::().await?; - Ok(metadata) + reqwest::get(&api_url) + .await + .unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url)) + .json::() + .await + .unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url)) } } @@ -31,16 +35,15 @@ pub struct Metadata { } impl Metadata { - pub async fn from(model_id: &str) -> Result { + pub async fn from(model_id: &str) -> Metadata { if let Some(metadata) = Self::from_local(model_id) { - Ok(metadata) + metadata } else { - let hf_metadata = HFMetadata::from(model_id).await?; - let metadata = Metadata { + let hf_metadata = HFMetadata::from(model_id).await; + Metadata { auto_model: hf_metadata.transformers_info.auto_model, etags: HashMap::new(), - }; - Ok(metadata) + } } } @@ -54,28 +57,20 @@ impl Metadata { } } - pub fn has_etag(&self, url: &str) -> bool { - self.etags.get(url).is_some() + pub fn local_cache_key(&self, path: &str) -> Option<&str> { + self.etags.get(path).map(|x| x.as_str()) } - pub async fn match_etag(&self, url: &str, path: &str) -> Result { - let etag = self - .etags - .get(url) - .ok_or(anyhow!("Path doesn't exist: {}", path))?; - let etag_from_header = reqwest::get(url) - .await? - .headers() + pub fn remote_cache_key(res: &reqwest::Response) -> &str { + res.headers() .get("etag") - .ok_or(anyhow!("URL doesn't have etag header: '{}'", url))? - .to_str()? - .to_owned(); - - Ok(etag == &etag_from_header) + .unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url())) + .to_str() + .unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url())) } - pub async fn update_etag(&mut self, url: &str, path: &str) { - self.etags.insert(url.to_owned(), path.to_owned()); + pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) { + self.etags.insert(path.to_string(), cache_key.to_string()); } pub fn save(&self, model_id: &str) -> Result<()> { @@ -92,7 +87,7 @@ mod tests { #[tokio::test] async fn test_hf() { - let hf_metadata = HFMetadata::from("TabbyML/J-350M").await.unwrap(); + let hf_metadata = HFMetadata::from("TabbyML/J-350M").await; assert_eq!( hf_metadata.transformers_info.auto_model, "AutoModelForCausalLM" diff --git a/crates/tabby/src/download/mod.rs b/crates/tabby/src/download/mod.rs index f3e48d6..49fc258 100644 --- a/crates/tabby/src/download/mod.rs +++ b/crates/tabby/src/download/mod.rs @@ -1,6 +1,5 @@ mod metadata; -use anyhow::{anyhow, Result}; use std::cmp; use std::fs; use std::io::Write; @@ -18,111 +17,108 @@ pub struct DownloadArgs { model: String, /// If true, skip checking for remote model file. - #[clap(long, default_value_t = true)] + #[clap(long, default_value_t = false)] prefer_local_file: bool, } -pub async fn main(args: &DownloadArgs) -> Result<()> { - download_model(&args.model, args.prefer_local_file).await?; - Ok(()) +pub async fn main(args: &DownloadArgs) { + download_model(&args.model, args.prefer_local_file).await; + println!("model '{}' is ready", args.model); } impl metadata::Metadata { - async fn download( - &mut self, - model_id: &str, - path: &str, - prefer_local_file: bool, - ) -> Result<()> { + async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) { // Create url. let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path); + // Get cache key. + let local_cache_key = self.local_cache_key(path); + // Create destination path. let filepath = ModelDir::new(model_id).path_string(path); // Cache hit. - let mut cache_hit = false; - if fs::metadata(&filepath).is_ok() && self.has_etag(&url) { - if prefer_local_file || self.match_etag(&url, path).await? { - cache_hit = true - } + let mut local_file_ready = false; + if !prefer_local_file && local_cache_key.is_some() && fs::metadata(&filepath).is_ok() { + local_file_ready = true; } - if !cache_hit { - let etag = download_file(&url, &filepath).await?; - self.update_etag(&url, &etag).await + if !local_file_ready { + let etag = download_file(&url, &filepath, local_cache_key).await; + self.set_local_cache_key(path, &etag).await } - - Ok(()) } } -pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> { - let mut metadata = metadata::Metadata::from(model_id).await?; +pub async fn download_model(model_id: &str, prefer_local_file: bool) { + let mut metadata = metadata::Metadata::from(model_id).await; metadata .download(model_id, "tokenizer.json", prefer_local_file) - .await?; + .await; metadata .download(model_id, "ctranslate2/config.json", prefer_local_file) - .await?; + .await; metadata .download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file) - .await?; + .await; metadata .download( model_id, "ctranslate2/shared_vocabulary.txt", prefer_local_file, ) - .await?; + .await; metadata .download(model_id, "ctranslate2/model.bin", prefer_local_file) - .await?; - metadata.save(model_id)?; - Ok(()) + .await; + metadata + .save(model_id) + .unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id)); } -async fn download_file(url: &str, path: &str) -> Result { - fs::create_dir_all(Path::new(path).parent().unwrap())?; +async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> String { + fs::create_dir_all(Path::new(path).parent().unwrap()) + .unwrap_or_else(|_| panic!("Failed to create path '{}'", path)); // Reqwest setup let res = reqwest::get(url) .await - .or(Err(anyhow!("Failed to GET from '{}'", url)))?; + .unwrap_or_else(|_| panic!("Failed to GET from '{}'", url)); - let etag = res - .headers() - .get("etag") - .ok_or(anyhow!("Failed to get etag from '{}", url))? - .to_str()? - .to_string(); + let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string(); + if let Some(local_cache_key) = local_cache_key { + if local_cache_key == remote_cache_key { + return remote_cache_key; + } + } let total_size = res .content_length() - .ok_or(anyhow!("Failed to get content length from '{}'", url))?; + .unwrap_or_else(|| panic!("Failed to get content length from '{}'", url)); // Indicatif setup let pb = ProgressBar::new(total_size); pb.set_style(ProgressStyle::default_bar() - .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")? + .template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})") + .expect("Invalid progress style") .progress_chars("#>-")); pb.set_message(format!("Downloading {}", path)); // download chunks - let mut file = fs::File::create(&path).or(Err(anyhow!("Failed to create file '{}'", &path)))?; + let mut file = + fs::File::create(path).unwrap_or_else(|_| panic!("Failed to create file '{}'", path)); let mut downloaded: u64 = 0; let mut stream = res.bytes_stream(); while let Some(item) = stream.next().await { - let chunk = item.or(Err(anyhow!("Error while downloading file")))?; - file.write_all(&chunk) - .or(Err(anyhow!("Error while writing to file")))?; + let chunk = item.expect("Error while downloading file"); + file.write_all(&chunk).expect("Error while writing to file"); let new = cmp::min(downloaded + (chunk.len() as u64), total_size); downloaded = new; pb.set_position(new); } pb.finish_with_message(format!("Downloaded {}", path)); - return Ok(etag); + remote_cache_key } diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index d7d9d80..2d10ef4 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -25,15 +25,7 @@ async fn main() { let cli = Cli::parse(); match &cli.command { - Commands::Serve(args) => { - serve::main(args) - .await - .expect("Error happens during the serve"); - } - Commands::Download(args) => { - download::main(args) - .await - .expect("Error happens during the download"); - } + Commands::Serve(args) => serve::main(args).await, + Commands::Download(args) => download::main(args).await, } } diff --git a/crates/tabby/src/serve/admin.rs b/crates/tabby/src/serve/admin.rs index c98faa5..357c05a 100644 --- a/crates/tabby/src/serve/admin.rs +++ b/crates/tabby/src/serve/admin.rs @@ -26,12 +26,12 @@ where Response::builder() .header(header::CONTENT_TYPE, mime.as_ref()) .body(body) - .unwrap() + .expect("Invalid response") } None => Response::builder() .status(StatusCode::NOT_FOUND) .body(boxed(Full::from("404"))) - .unwrap(), + .expect("Invalid response"), } } } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 46d79ea..fe0ffea 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -45,7 +45,7 @@ pub async fn completion( .max_decoding_length(64) .sampling_temperature(0.2) .build() - .unwrap(); + .expect("Invalid TextInferenceOptions"); let text = state.engine.inference(&request.prompt, options); let language = request.language.unwrap_or("unknown".into()); let filtered_text = languages::remove_stop_words(&language, &text); @@ -90,7 +90,7 @@ impl CompletionState { .device_indices(args.device_indices.clone()) .num_replicas_per_device(args.num_replicas_per_device) .build() - .unwrap(); + .expect("Invalid TextInferenceEngineCreateOptions"); let engine = TextInferenceEngine::create(options); Self { engine } } @@ -110,5 +110,5 @@ struct Metadata { } fn read_metadata(model_dir: &ModelDir) -> Metadata { - serdeconv::from_json_file(model_dir.metadata_file()).unwrap() + serdeconv::from_json_file(model_dir.metadata_file()).expect("Invalid metadata") } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 98402de..895a133 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -3,7 +3,6 @@ mod completions; mod events; use crate::Cli; -use anyhow::Result; use axum::{routing, Router, Server}; use clap::{error::ErrorKind, Args, CommandFactory}; use std::{ @@ -29,10 +28,10 @@ struct ApiDoc; #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] pub enum Device { #[strum(serialize = "cpu")] - CPU, + Cpu, #[strum(serialize = "cuda")] - CUDA, + Cuda, } #[derive(Args)] @@ -45,7 +44,7 @@ pub struct ServeArgs { port: u16, /// Device to run model inference. - #[clap(long, default_value_t=Device::CPU)] + #[clap(long, default_value_t=Device::Cpu)] device: Device, /// GPU indices to run models, only applicable for CUDA. @@ -61,11 +60,11 @@ pub struct ServeArgs { experimental_admin_panel: bool, } -pub async fn main(args: &ServeArgs) -> Result<()> { - valid_args(args)?; +pub async fn main(args: &ServeArgs) { + valid_args(args); // Ensure model exists. - crate::download::download_model(&args.model, true).await?; + crate::download::download_model(&args.model, true).await; let app = Router::new() .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) @@ -77,8 +76,8 @@ pub async fn main(args: &ServeArgs) -> Result<()> { println!("Listening at {}", address); Server::bind(&address) .serve(app.into_make_service()) - .await?; - Ok(()) + .await + .expect("Error happends during model serving") } fn api_router(args: &ServeArgs) -> Router { @@ -99,8 +98,8 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter { } } -fn valid_args(args: &ServeArgs) -> Result<()> { - if args.device == Device::CUDA && args.num_replicas_per_device != 1 { +fn valid_args(args: &ServeArgs) { + if args.device == Device::Cuda && args.num_replicas_per_device != 1 { Cli::command() .error( ErrorKind::ValueValidation, @@ -109,7 +108,7 @@ fn valid_args(args: &ServeArgs) -> Result<()> { .exit(); } - if args.device == Device::CPU && (args.device_indices.len() != 1 || args.device_indices[0] != 0) + if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0) { Cli::command() .error( @@ -118,6 +117,4 @@ fn valid_args(args: &ServeArgs) -> Result<()> { ) .exit(); } - - Ok(()) }