refactor: improve error handlings, fix clippy warnings (#181)
* refactor: minor improvements on error handling * refactor: cleanup error handlings * update * update * fix * add clippy / test workflow * fix clippy * fix clippy * update
support-coreml
parent
4c6f1338a8
commit
3cac2607e7
|
|
@ -41,7 +41,7 @@ jobs:
|
||||||
run: cargo fmt --check
|
run: cargo fmt --check
|
||||||
|
|
||||||
release-docker:
|
release-docker:
|
||||||
needs: tests
|
needs: release-binary
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
|
@ -134,6 +134,7 @@ jobs:
|
||||||
with:
|
with:
|
||||||
toolchain: stable
|
toolchain: stable
|
||||||
target: ${{ matrix.target }}
|
target: ${{ matrix.target }}
|
||||||
|
components: clippy
|
||||||
|
|
||||||
- name: Sccache cache
|
- name: Sccache cache
|
||||||
uses: mozilla-actions/sccache-action@v0.0.3
|
uses: mozilla-actions/sccache-action@v0.0.3
|
||||||
|
|
@ -151,6 +152,12 @@ jobs:
|
||||||
~/.cargo/registry
|
~/.cargo/registry
|
||||||
~/.cargo/git
|
~/.cargo/git
|
||||||
|
|
||||||
|
- name: Run cargo clippy check
|
||||||
|
run: cargo clippy -- -Dwarnings
|
||||||
|
|
||||||
|
- name: Run cargo tests
|
||||||
|
run: cargo test
|
||||||
|
|
||||||
- name: Bulid release binary
|
- name: Bulid release binary
|
||||||
run: cargo build --release --target ${{ matrix.target }}
|
run: cargo build --release --target ${{ matrix.target }}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -58,11 +58,9 @@ fn link_static() -> PathBuf {
|
||||||
let dst = config.build();
|
let dst = config.build();
|
||||||
|
|
||||||
// Read static lib from generated deps.
|
// Read static lib from generated deps.
|
||||||
let cmake_generated_libs_str = std::fs::read_to_string(
|
let cmake_generated_libs_str =
|
||||||
&format!("/{}/build/cmake_generated_libs", dst.display()).to_string(),
|
std::fs::read_to_string(format!("/{}/build/cmake_generated_libs", dst.display())).unwrap();
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
read_cmake_generated(&cmake_generated_libs_str);
|
read_cmake_generated(&cmake_generated_libs_str);
|
||||||
|
|
||||||
return dst;
|
dst
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ impl Event<'_> {
|
||||||
writer.by_ref(),
|
writer.by_ref(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
write!(writer, "\n").unwrap();
|
writeln!(writer).unwrap();
|
||||||
writer.flush().unwrap();
|
writer.flush().unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -33,11 +33,11 @@ impl ModelDir {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn metadata_file(&self) -> String {
|
pub fn metadata_file(&self) -> String {
|
||||||
return self.path_string("metadata.json");
|
self.path_string("metadata.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tokenizer_file(&self) -> String {
|
pub fn tokenizer_file(&self) -> String {
|
||||||
return self.path_string("tokenizer.json");
|
self.path_string("tokenizer.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ctranslate2_dir(&self) -> String {
|
pub fn ctranslate2_dir(&self) -> String {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::Result;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -17,10 +17,14 @@ struct HFMetadata {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HFMetadata {
|
impl HFMetadata {
|
||||||
async fn from(model_id: &str) -> Result<HFMetadata> {
|
async fn from(model_id: &str) -> HFMetadata {
|
||||||
let api_url = format!("https://huggingface.co/api/models/{}", model_id);
|
let api_url = format!("https://huggingface.co/api/models/{}", model_id);
|
||||||
let metadata = reqwest::get(api_url).await?.json::<HFMetadata>().await?;
|
reqwest::get(&api_url)
|
||||||
Ok(metadata)
|
.await
|
||||||
|
.unwrap_or_else(|_| panic!("Failed to GET url '{}'", api_url))
|
||||||
|
.json::<HFMetadata>()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| panic!("Failed to parse HFMetadata '{}'", api_url))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -31,16 +35,15 @@ pub struct Metadata {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Metadata {
|
impl Metadata {
|
||||||
pub async fn from(model_id: &str) -> Result<Metadata> {
|
pub async fn from(model_id: &str) -> Metadata {
|
||||||
if let Some(metadata) = Self::from_local(model_id) {
|
if let Some(metadata) = Self::from_local(model_id) {
|
||||||
Ok(metadata)
|
metadata
|
||||||
} else {
|
} else {
|
||||||
let hf_metadata = HFMetadata::from(model_id).await?;
|
let hf_metadata = HFMetadata::from(model_id).await;
|
||||||
let metadata = Metadata {
|
Metadata {
|
||||||
auto_model: hf_metadata.transformers_info.auto_model,
|
auto_model: hf_metadata.transformers_info.auto_model,
|
||||||
etags: HashMap::new(),
|
etags: HashMap::new(),
|
||||||
};
|
}
|
||||||
Ok(metadata)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -54,28 +57,20 @@ impl Metadata {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn has_etag(&self, url: &str) -> bool {
|
pub fn local_cache_key(&self, path: &str) -> Option<&str> {
|
||||||
self.etags.get(url).is_some()
|
self.etags.get(path).map(|x| x.as_str())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn match_etag(&self, url: &str, path: &str) -> Result<bool> {
|
pub fn remote_cache_key(res: &reqwest::Response) -> &str {
|
||||||
let etag = self
|
res.headers()
|
||||||
.etags
|
|
||||||
.get(url)
|
|
||||||
.ok_or(anyhow!("Path doesn't exist: {}", path))?;
|
|
||||||
let etag_from_header = reqwest::get(url)
|
|
||||||
.await?
|
|
||||||
.headers()
|
|
||||||
.get("etag")
|
.get("etag")
|
||||||
.ok_or(anyhow!("URL doesn't have etag header: '{}'", url))?
|
.unwrap_or_else(|| panic!("Failed to GET ETAG header from '{}'", res.url()))
|
||||||
.to_str()?
|
.to_str()
|
||||||
.to_owned();
|
.unwrap_or_else(|_| panic!("Failed to convert ETAG header into string '{}'", res.url()))
|
||||||
|
|
||||||
Ok(etag == &etag_from_header)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_etag(&mut self, url: &str, path: &str) {
|
pub async fn set_local_cache_key(&mut self, path: &str, cache_key: &str) {
|
||||||
self.etags.insert(url.to_owned(), path.to_owned());
|
self.etags.insert(path.to_string(), cache_key.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save(&self, model_id: &str) -> Result<()> {
|
pub fn save(&self, model_id: &str) -> Result<()> {
|
||||||
|
|
@ -92,7 +87,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_hf() {
|
async fn test_hf() {
|
||||||
let hf_metadata = HFMetadata::from("TabbyML/J-350M").await.unwrap();
|
let hf_metadata = HFMetadata::from("TabbyML/J-350M").await;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
hf_metadata.transformers_info.auto_model,
|
hf_metadata.transformers_info.auto_model,
|
||||||
"AutoModelForCausalLM"
|
"AutoModelForCausalLM"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
mod metadata;
|
mod metadata;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
@ -18,111 +17,108 @@ pub struct DownloadArgs {
|
||||||
model: String,
|
model: String,
|
||||||
|
|
||||||
/// If true, skip checking for remote model file.
|
/// If true, skip checking for remote model file.
|
||||||
#[clap(long, default_value_t = true)]
|
#[clap(long, default_value_t = false)]
|
||||||
prefer_local_file: bool,
|
prefer_local_file: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &DownloadArgs) -> Result<()> {
|
pub async fn main(args: &DownloadArgs) {
|
||||||
download_model(&args.model, args.prefer_local_file).await?;
|
download_model(&args.model, args.prefer_local_file).await;
|
||||||
Ok(())
|
println!("model '{}' is ready", args.model);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl metadata::Metadata {
|
impl metadata::Metadata {
|
||||||
async fn download(
|
async fn download(&mut self, model_id: &str, path: &str, prefer_local_file: bool) {
|
||||||
&mut self,
|
|
||||||
model_id: &str,
|
|
||||||
path: &str,
|
|
||||||
prefer_local_file: bool,
|
|
||||||
) -> Result<()> {
|
|
||||||
// Create url.
|
// Create url.
|
||||||
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
let url = format!("https://huggingface.co/{}/resolve/main/{}", model_id, path);
|
||||||
|
|
||||||
|
// Get cache key.
|
||||||
|
let local_cache_key = self.local_cache_key(path);
|
||||||
|
|
||||||
// Create destination path.
|
// Create destination path.
|
||||||
let filepath = ModelDir::new(model_id).path_string(path);
|
let filepath = ModelDir::new(model_id).path_string(path);
|
||||||
|
|
||||||
// Cache hit.
|
// Cache hit.
|
||||||
let mut cache_hit = false;
|
let mut local_file_ready = false;
|
||||||
if fs::metadata(&filepath).is_ok() && self.has_etag(&url) {
|
if !prefer_local_file && local_cache_key.is_some() && fs::metadata(&filepath).is_ok() {
|
||||||
if prefer_local_file || self.match_etag(&url, path).await? {
|
local_file_ready = true;
|
||||||
cache_hit = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cache_hit {
|
if !local_file_ready {
|
||||||
let etag = download_file(&url, &filepath).await?;
|
let etag = download_file(&url, &filepath, local_cache_key).await;
|
||||||
self.update_etag(&url, &etag).await
|
self.set_local_cache_key(path, &etag).await
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn download_model(model_id: &str, prefer_local_file: bool) -> Result<()> {
|
pub async fn download_model(model_id: &str, prefer_local_file: bool) {
|
||||||
let mut metadata = metadata::Metadata::from(model_id).await?;
|
let mut metadata = metadata::Metadata::from(model_id).await;
|
||||||
|
|
||||||
metadata
|
metadata
|
||||||
.download(model_id, "tokenizer.json", prefer_local_file)
|
.download(model_id, "tokenizer.json", prefer_local_file)
|
||||||
.await?;
|
.await;
|
||||||
metadata
|
metadata
|
||||||
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
.download(model_id, "ctranslate2/config.json", prefer_local_file)
|
||||||
.await?;
|
.await;
|
||||||
metadata
|
metadata
|
||||||
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
.download(model_id, "ctranslate2/vocabulary.txt", prefer_local_file)
|
||||||
.await?;
|
.await;
|
||||||
metadata
|
metadata
|
||||||
.download(
|
.download(
|
||||||
model_id,
|
model_id,
|
||||||
"ctranslate2/shared_vocabulary.txt",
|
"ctranslate2/shared_vocabulary.txt",
|
||||||
prefer_local_file,
|
prefer_local_file,
|
||||||
)
|
)
|
||||||
.await?;
|
.await;
|
||||||
metadata
|
metadata
|
||||||
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
.download(model_id, "ctranslate2/model.bin", prefer_local_file)
|
||||||
.await?;
|
.await;
|
||||||
metadata.save(model_id)?;
|
metadata
|
||||||
Ok(())
|
.save(model_id)
|
||||||
|
.unwrap_or_else(|_| panic!("Failed to save model_id '{}'", model_id));
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn download_file(url: &str, path: &str) -> Result<String> {
|
async fn download_file(url: &str, path: &str, local_cache_key: Option<&str>) -> String {
|
||||||
fs::create_dir_all(Path::new(path).parent().unwrap())?;
|
fs::create_dir_all(Path::new(path).parent().unwrap())
|
||||||
|
.unwrap_or_else(|_| panic!("Failed to create path '{}'", path));
|
||||||
|
|
||||||
// Reqwest setup
|
// Reqwest setup
|
||||||
let res = reqwest::get(url)
|
let res = reqwest::get(url)
|
||||||
.await
|
.await
|
||||||
.or(Err(anyhow!("Failed to GET from '{}'", url)))?;
|
.unwrap_or_else(|_| panic!("Failed to GET from '{}'", url));
|
||||||
|
|
||||||
let etag = res
|
let remote_cache_key = metadata::Metadata::remote_cache_key(&res).to_string();
|
||||||
.headers()
|
if let Some(local_cache_key) = local_cache_key {
|
||||||
.get("etag")
|
if local_cache_key == remote_cache_key {
|
||||||
.ok_or(anyhow!("Failed to get etag from '{}", url))?
|
return remote_cache_key;
|
||||||
.to_str()?
|
}
|
||||||
.to_string();
|
}
|
||||||
|
|
||||||
let total_size = res
|
let total_size = res
|
||||||
.content_length()
|
.content_length()
|
||||||
.ok_or(anyhow!("Failed to get content length from '{}'", url))?;
|
.unwrap_or_else(|| panic!("Failed to get content length from '{}'", url));
|
||||||
|
|
||||||
// Indicatif setup
|
// Indicatif setup
|
||||||
let pb = ProgressBar::new(total_size);
|
let pb = ProgressBar::new(total_size);
|
||||||
pb.set_style(ProgressStyle::default_bar()
|
pb.set_style(ProgressStyle::default_bar()
|
||||||
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")?
|
.template("{msg}\n{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
|
||||||
|
.expect("Invalid progress style")
|
||||||
.progress_chars("#>-"));
|
.progress_chars("#>-"));
|
||||||
pb.set_message(format!("Downloading {}", path));
|
pb.set_message(format!("Downloading {}", path));
|
||||||
|
|
||||||
// download chunks
|
// download chunks
|
||||||
let mut file = fs::File::create(&path).or(Err(anyhow!("Failed to create file '{}'", &path)))?;
|
let mut file =
|
||||||
|
fs::File::create(path).unwrap_or_else(|_| panic!("Failed to create file '{}'", path));
|
||||||
let mut downloaded: u64 = 0;
|
let mut downloaded: u64 = 0;
|
||||||
let mut stream = res.bytes_stream();
|
let mut stream = res.bytes_stream();
|
||||||
|
|
||||||
while let Some(item) = stream.next().await {
|
while let Some(item) = stream.next().await {
|
||||||
let chunk = item.or(Err(anyhow!("Error while downloading file")))?;
|
let chunk = item.expect("Error while downloading file");
|
||||||
file.write_all(&chunk)
|
file.write_all(&chunk).expect("Error while writing to file");
|
||||||
.or(Err(anyhow!("Error while writing to file")))?;
|
|
||||||
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
let new = cmp::min(downloaded + (chunk.len() as u64), total_size);
|
||||||
downloaded = new;
|
downloaded = new;
|
||||||
pb.set_position(new);
|
pb.set_position(new);
|
||||||
}
|
}
|
||||||
|
|
||||||
pb.finish_with_message(format!("Downloaded {}", path));
|
pb.finish_with_message(format!("Downloaded {}", path));
|
||||||
return Ok(etag);
|
remote_cache_key
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,15 +25,7 @@ async fn main() {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
match &cli.command {
|
match &cli.command {
|
||||||
Commands::Serve(args) => {
|
Commands::Serve(args) => serve::main(args).await,
|
||||||
serve::main(args)
|
Commands::Download(args) => download::main(args).await,
|
||||||
.await
|
|
||||||
.expect("Error happens during the serve");
|
|
||||||
}
|
|
||||||
Commands::Download(args) => {
|
|
||||||
download::main(args)
|
|
||||||
.await
|
|
||||||
.expect("Error happens during the download");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,12 @@ where
|
||||||
Response::builder()
|
Response::builder()
|
||||||
.header(header::CONTENT_TYPE, mime.as_ref())
|
.header(header::CONTENT_TYPE, mime.as_ref())
|
||||||
.body(body)
|
.body(body)
|
||||||
.unwrap()
|
.expect("Invalid response")
|
||||||
}
|
}
|
||||||
None => Response::builder()
|
None => Response::builder()
|
||||||
.status(StatusCode::NOT_FOUND)
|
.status(StatusCode::NOT_FOUND)
|
||||||
.body(boxed(Full::from("404")))
|
.body(boxed(Full::from("404")))
|
||||||
.unwrap(),
|
.expect("Invalid response"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ pub async fn completion(
|
||||||
.max_decoding_length(64)
|
.max_decoding_length(64)
|
||||||
.sampling_temperature(0.2)
|
.sampling_temperature(0.2)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.expect("Invalid TextInferenceOptions");
|
||||||
let text = state.engine.inference(&request.prompt, options);
|
let text = state.engine.inference(&request.prompt, options);
|
||||||
let language = request.language.unwrap_or("unknown".into());
|
let language = request.language.unwrap_or("unknown".into());
|
||||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||||
|
|
@ -90,7 +90,7 @@ impl CompletionState {
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.expect("Invalid TextInferenceEngineCreateOptions");
|
||||||
let engine = TextInferenceEngine::create(options);
|
let engine = TextInferenceEngine::create(options);
|
||||||
Self { engine }
|
Self { engine }
|
||||||
}
|
}
|
||||||
|
|
@ -110,5 +110,5 @@ struct Metadata {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
fn read_metadata(model_dir: &ModelDir) -> Metadata {
|
||||||
serdeconv::from_json_file(model_dir.metadata_file()).unwrap()
|
serdeconv::from_json_file(model_dir.metadata_file()).expect("Invalid metadata")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ mod completions;
|
||||||
mod events;
|
mod events;
|
||||||
|
|
||||||
use crate::Cli;
|
use crate::Cli;
|
||||||
use anyhow::Result;
|
|
||||||
use axum::{routing, Router, Server};
|
use axum::{routing, Router, Server};
|
||||||
use clap::{error::ErrorKind, Args, CommandFactory};
|
use clap::{error::ErrorKind, Args, CommandFactory};
|
||||||
use std::{
|
use std::{
|
||||||
|
|
@ -29,10 +28,10 @@ struct ApiDoc;
|
||||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||||
pub enum Device {
|
pub enum Device {
|
||||||
#[strum(serialize = "cpu")]
|
#[strum(serialize = "cpu")]
|
||||||
CPU,
|
Cpu,
|
||||||
|
|
||||||
#[strum(serialize = "cuda")]
|
#[strum(serialize = "cuda")]
|
||||||
CUDA,
|
Cuda,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Args)]
|
#[derive(Args)]
|
||||||
|
|
@ -45,7 +44,7 @@ pub struct ServeArgs {
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
/// Device to run model inference.
|
/// Device to run model inference.
|
||||||
#[clap(long, default_value_t=Device::CPU)]
|
#[clap(long, default_value_t=Device::Cpu)]
|
||||||
device: Device,
|
device: Device,
|
||||||
|
|
||||||
/// GPU indices to run models, only applicable for CUDA.
|
/// GPU indices to run models, only applicable for CUDA.
|
||||||
|
|
@ -61,11 +60,11 @@ pub struct ServeArgs {
|
||||||
experimental_admin_panel: bool,
|
experimental_admin_panel: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &ServeArgs) -> Result<()> {
|
pub async fn main(args: &ServeArgs) {
|
||||||
valid_args(args)?;
|
valid_args(args);
|
||||||
|
|
||||||
// Ensure model exists.
|
// Ensure model exists.
|
||||||
crate::download::download_model(&args.model, true).await?;
|
crate::download::download_model(&args.model, true).await;
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||||
|
|
@ -77,8 +76,8 @@ pub async fn main(args: &ServeArgs) -> Result<()> {
|
||||||
println!("Listening at {}", address);
|
println!("Listening at {}", address);
|
||||||
Server::bind(&address)
|
Server::bind(&address)
|
||||||
.serve(app.into_make_service())
|
.serve(app.into_make_service())
|
||||||
.await?;
|
.await
|
||||||
Ok(())
|
.expect("Error happends during model serving")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs) -> Router {
|
fn api_router(args: &ServeArgs) -> Router {
|
||||||
|
|
@ -99,8 +98,8 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn valid_args(args: &ServeArgs) -> Result<()> {
|
fn valid_args(args: &ServeArgs) {
|
||||||
if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
|
if args.device == Device::Cuda && args.num_replicas_per_device != 1 {
|
||||||
Cli::command()
|
Cli::command()
|
||||||
.error(
|
.error(
|
||||||
ErrorKind::ValueValidation,
|
ErrorKind::ValueValidation,
|
||||||
|
|
@ -109,7 +108,7 @@ fn valid_args(args: &ServeArgs) -> Result<()> {
|
||||||
.exit();
|
.exit();
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.device == Device::CPU && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
|
||||||
{
|
{
|
||||||
Cli::command()
|
Cli::command()
|
||||||
.error(
|
.error(
|
||||||
|
|
@ -118,6 +117,4 @@ fn valid_args(args: &ServeArgs) -> Result<()> {
|
||||||
)
|
)
|
||||||
.exit();
|
.exit();
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue