chore: migrate completion to new metadata format (#179)
parent
e8dbd36663
commit
9131567257
|
|
@ -86,7 +86,7 @@ impl CompletionState {
|
||||||
.model_path(model_dir.ctranslate2_dir())
|
.model_path(model_dir.ctranslate2_dir())
|
||||||
.tokenizer_path(model_dir.tokenizer_file())
|
.tokenizer_path(model_dir.tokenizer_file())
|
||||||
.device(device)
|
.device(device)
|
||||||
.model_type(metadata.transformers_info.auto_model)
|
.model_type(metadata.auto_model)
|
||||||
.device_indices(args.device_indices.clone())
|
.device_indices(args.device_indices.clone())
|
||||||
.num_replicas_per_device(args.num_replicas_per_device)
|
.num_replicas_per_device(args.num_replicas_per_device)
|
||||||
.build()
|
.build()
|
||||||
|
|
@ -105,13 +105,7 @@ fn get_model_dir(model: &str) -> ModelDir {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
struct Metadata {
|
struct Metadata {
|
||||||
transformers_info: TransformersInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct TransformersInfo {
|
|
||||||
auto_model: String,
|
auto_model: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ mod completions;
|
||||||
mod events;
|
mod events;
|
||||||
|
|
||||||
use crate::Cli;
|
use crate::Cli;
|
||||||
|
use anyhow::Result;
|
||||||
use axum::{routing, Router, Server};
|
use axum::{routing, Router, Server};
|
||||||
use clap::{error::ErrorKind, Args, CommandFactory};
|
use clap::{error::ErrorKind, Args, CommandFactory};
|
||||||
use hyper::Error;
|
|
||||||
use std::{
|
use std::{
|
||||||
net::{Ipv4Addr, SocketAddr},
|
net::{Ipv4Addr, SocketAddr},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
|
|
@ -61,8 +61,8 @@ pub struct ServeArgs {
|
||||||
experimental_admin_panel: bool,
|
experimental_admin_panel: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
pub async fn main(args: &ServeArgs) -> Result<()> {
|
||||||
valid_args(args);
|
valid_args(args)?;
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
|
||||||
.nest("/v1", api_router(args))
|
.nest("/v1", api_router(args))
|
||||||
|
|
@ -71,7 +71,10 @@ pub async fn main(args: &ServeArgs) -> Result<(), Error> {
|
||||||
|
|
||||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||||
println!("Listening at {}", address);
|
println!("Listening at {}", address);
|
||||||
Server::bind(&address).serve(app.into_make_service()).await
|
Server::bind(&address)
|
||||||
|
.serve(app.into_make_service())
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs) -> Router {
|
fn api_router(args: &ServeArgs) -> Router {
|
||||||
|
|
@ -92,7 +95,7 @@ fn fallback(experimental_admin_panel: bool) -> routing::MethodRouter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn valid_args(args: &ServeArgs) {
|
fn valid_args(args: &ServeArgs) -> Result<()> {
|
||||||
if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
|
if args.device == Device::CUDA && args.num_replicas_per_device != 1 {
|
||||||
Cli::command()
|
Cli::command()
|
||||||
.error(
|
.error(
|
||||||
|
|
@ -111,4 +114,6 @@ fn valid_args(args: &ServeArgs) {
|
||||||
)
|
)
|
||||||
.exit();
|
.exit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue