diff --git a/crates/tabby/src/serve/health.rs b/crates/tabby/src/serve/health.rs new file mode 100644 index 0000000..283eb60 --- /dev/null +++ b/crates/tabby/src/serve/health.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use axum::{extract::State, Json}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct HealthState { + model: String, + device: String, + compute_type: String, +} + +impl HealthState { + pub fn new(args: &super::ServeArgs) -> Self { + Self { + model: args.model.clone(), + device: args.device.to_string(), + compute_type: args.compute_type.to_string(), + } + } +} + +#[utoipa::path( + post, + path = "/v1/health", + tag = "v1", + responses( + (status = 200, description = "Success", body = HealthState, content_type = "application/json"), + ) +)] +pub async fn health(State(state): State<Arc<HealthState>>) -> Json<HealthState> { + Json(state.as_ref().clone()) +} diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 85841f1..2302e44 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -1,5 +1,6 @@ mod completions; mod events; +mod health; use std::{ net::{Ipv4Addr, SocketAddr}, @@ -29,13 +30,14 @@ OpenAPI documentation for [tabby](https://github.com/TabbyML/tabby), a self-host (url = "https://tabbyml.app.tabbyml.com/tabby", description = "Local server"), (url = "http://localhost:8080", description = "Local server"), ), - paths(events::log_event, completions::completion, health), + paths(events::log_event, completions::completion, health::health), components(schemas( events::LogEventRequest, completions::CompletionRequest, completions::CompletionResponse, completions::Segments, completions::Choice, + health::HealthState, )) )] struct ApiDoc; @@ -134,7 +136,10 @@ pub async fn main(args: &ServeArgs) { fn api_router(args: &ServeArgs) -> Router { Router::new() .route("/events", routing::post(events::log_event)) - .route("/health", routing::post(health)) + 
.route( + "/health", + routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))), + ) .route( "/completions", routing::post(completions::completion) @@ -159,16 +164,9 @@ fn valid_args(args: &ServeArgs) { } if args.device == Device::Cpu && args.compute_type != ComputeType::Int8 { - fatal!("CPU device only supports int8 compute type"); + match args.compute_type { + ComputeType::Auto | ComputeType::Int8 => {} + _ => fatal!("CPU device only supports int8 compute type"), + } } } - -#[utoipa::path( - post, - path = "/v1/health", - tag = "v1", - responses( - (status = 200, description = "Health"), - ) -)] -async fn health() {}