chore: migrate completion to new metadata format (#179)

support-coreml
Meng Zhang 2023-06-01 00:08:09 -07:00 committed by GitHub
parent e8dbd36663
commit 9131567257
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 12 deletions

View File

@ -86,7 +86,7 @@ impl CompletionState {
.model_path(model_dir.ctranslate2_dir()) .model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file()) .tokenizer_path(model_dir.tokenizer_file())
.device(device) .device(device)
.model_type(metadata.transformers_info.auto_model) .model_type(metadata.auto_model)
.device_indices(args.device_indices.clone()) .device_indices(args.device_indices.clone())
.num_replicas_per_device(args.num_replicas_per_device) .num_replicas_per_device(args.num_replicas_per_device)
.build() .build()
@ -105,13 +105,7 @@ fn get_model_dir(model: &str) -> ModelDir {
} }
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct Metadata { struct Metadata {
transformers_info: TransformersInfo,
}
#[derive(Deserialize)]
struct TransformersInfo {
auto_model: String, auto_model: String,
} }

View File

@ -3,9 +3,9 @@ mod completions;
mod events; mod events;
use crate::Cli; use crate::Cli;
use anyhow::Result;
use axum::{routing, Router, Server}; use axum::{routing, Router, Server};
use clap::{error::ErrorKind, Args, CommandFactory}; use clap::{error::ErrorKind, Args, CommandFactory};
use hyper::Error;
use std::{ use std::{
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
sync::Arc, sync::Arc,
@ -61,8 +61,8 @@ pub struct ServeArgs {
experimental_admin_panel: bool, experimental_admin_panel: bool,
} }
pub async fn main(args: &ServeArgs) -> Result<(), Error> { pub async fn main(args: &ServeArgs) -> Result<()> {
valid_args(args); valid_args(args)?;
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.nest("/v1", api_router(args)) .nest("/v1", api_router(args))
@ -71,7 +71,10 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
println!("Listening at {}", address); println!("Listening at {}", address);
Server::bind(&address).serve(app.into_make_service()).await Server::bind(&address)
.serve(app.into_make_service())
.await?;
Ok(())
} }
fn api_router(args: &ServeArgs) -> Router { fn api_router(args: &ServeArgs) -> Router {
@ -92,7 +95,7 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
} }
} }
fn valid_args(args: &ServeArgs) { fn valid_args(args: &ServeArgs) -> Result<()> {
if args.device == Device::CUDA && args.num_replicas_per_device != 1 { if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
Cli::command() Cli::command()
.error( .error(
@ -111,4 +114,6 @@ fn valid_args(args: &ServeArgs) {
) )
.exit(); .exit();
} }
Ok(())
} }