feat: returns more information in /v1/health

improve-workflow
Meng Zhang 2023-06-13 13:11:07 -07:00
parent df67b13639
commit b2734aed59
2 changed files with 45 additions and 13 deletions

View File

@ -0,0 +1,34 @@
use std::sync::Arc;
use axum::{extract::State, Json};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct HealthState {
model: String,
device: String,
compute_type: String,
}
impl HealthState {
pub fn new(args: &super::ServeArgs) -> Self {
Self {
model: args.model.clone(),
device: args.device.to_string(),
compute_type: args.compute_type.to_string(),
}
}
}
#[utoipa::path(
post,
path = "/v1/health",
tag = "v1",
responses(
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
)
)]
pub async fn health(State(state): State<Arc<HealthState>>) -> Json<HealthState> {
Json(state.as_ref().clone())
}

View File

@ -1,5 +1,6 @@
mod completions; mod completions;
mod events; mod events;
mod health;
use std::{ use std::{
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
@ -29,13 +30,14 @@ OpenAPI documentation for [tabby](https://github.com/TabbyML/tabby), a self-host
(url = "https://tabbyml.app.tabbyml.com/tabby", description = "Local server"), (url = "https://tabbyml.app.tabbyml.com/tabby", description = "Local server"),
(url = "http://localhost:8080", description = "Local server"), (url = "http://localhost:8080", description = "Local server"),
), ),
paths(events::log_event, completions::completion, health), paths(events::log_event, completions::completion, health::health),
components(schemas( components(schemas(
events::LogEventRequest, events::LogEventRequest,
completions::CompletionRequest, completions::CompletionRequest,
completions::CompletionResponse, completions::CompletionResponse,
completions::Segments, completions::Segments,
completions::Choice, completions::Choice,
health::HealthState,
)) ))
)] )]
struct ApiDoc; struct ApiDoc;
@ -134,7 +136,10 @@ pub async fn main(args: &ServeArgs) {
fn api_router(args: &ServeArgs) -> Router { fn api_router(args: &ServeArgs) -> Router {
Router::new() Router::new()
.route("/events", routing::post(events::log_event)) .route("/events", routing::post(events::log_event))
.route("/health", routing::post(health)) .route(
"/health",
routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))),
)
.route( .route(
"/completions", "/completions",
routing::post(completions::completion) routing::post(completions::completion)
@ -159,16 +164,9 @@ fn valid_args(args: &ServeArgs) {
} }
if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 { if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 {
fatal!("CPU device only supports int8 compute type"); match args.compute_type {
ComputeType::Auto | ComputeType::Int8 => {}
_ => fatal!("CPU device only supports int8 compute type"),
}
} }
} }
#[utoipa::path(
post,
path = "/v1/health",
tag = "v1",
responses(
(status = 200, description = "Health"),
)
)]
async fn health() {}