chore: migrate completion to new metadata format (#179)

support-coreml
Meng Zhang 2023-06-01 00:08:09 -07:00 committed by GitHub
parent e8dbd36663
commit 9131567257
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 12 deletions

View File

@ -86,7 +86,7 @@ impl CompletionState {
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.transformers_info.auto_model)
.model_type(metadata.auto_model)
.device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device)
.build()
@ -105,13 +105,7 @@ fn get_model_dir(model: &str) -> ModelDir {
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct Metadata {
transformers_info: TransformersInfo,
}
#[derive(Deserialize)]
struct TransformersInfo {
auto_model: String,
}

View File

@ -3,9 +3,9 @@ mod completions;
mod events;
use crate::Cli;
use anyhow::Result;
use axum::{routing, Router, Server};
use clap::{error::ErrorKind, Args, CommandFactory};
use hyper::Error;
use std::{
net::{Ipv4Addr, SocketAddr},
sync::Arc,
@ -61,8 +61,8 @@ pub struct ServeArgs {
experimental_admin_panel: bool,
}
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
valid_args(args);
pub async fn main(args: &ServeArgs) -> Result<()> {
valid_args(args)?;
let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.nest("/v1", api_router(args))
@ -71,7 +71,10 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
println!("Listening at {}", address);
Server::bind(&address).serve(app.into_make_service()).await
Server::bind(&address)
.serve(app.into_make_service())
.await?;
Ok(())
}
fn api_router(args: &ServeArgs) -> Router {
@ -92,7 +95,7 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
}
}
fn valid_args(args: &ServeArgs) {
fn valid_args(args: &ServeArgs) -> Result<()> {
if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
Cli::command()
.error(
@ -111,4 +114,6 @@ fn valid_args(args: &ServeArgs) {
)
.exit();
}
Ok(())
}