diff --git a/Cargo.lock b/Cargo.lock index 3e3434b..3cd2051 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3722,6 +3722,7 @@ dependencies = [ "http-body", "http-range-header", "pin-project-lite", + "tokio", "tower-layer", "tower-service", ] diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index b13bdab..ed4d3d4 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -18,7 +18,7 @@ utoipa-swagger-ui = { version = "3.1", features = ["axum"] } serde = { workspace = true } serdeconv = { workspace = true } serde_json = { workspace = true } -tower-http = { version = "0.4.0", features = ["cors"] } +tower-http = { version = "0.4.0", features = ["cors", "timeout"] } clap = { version = "4.3.0", features = ["derive"] } lazy_static = { workspace = true } rust-embed = "8.0.0" diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index ad595fe..c15e852 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -21,7 +21,7 @@ use tabby_common::{ }; use tabby_download::Downloader; use tokio::time::sleep; -use tower_http::cors::CorsLayer; +use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; use tracing::{debug, info, warn}; use utoipa::{openapi::ServerBuilder, OpenApi}; use utoipa_swagger_ui::SwaggerUi; @@ -201,44 +201,55 @@ fn api_router(args: &ServeArgs) -> Router { None }; + let mut routers = vec![]; + let health_state = Arc::new(health::HealthState::new(args)); - let router = Router::new() - .route("/v1/events", routing::post(events::log_event)) - /* Remove POST /v1/health route in next major version release. */ - .route( - "/v1/health", - routing::post(health::health).with_state(health_state.clone()), - ) - .route( - "/v1/health", - routing::get(health::health).with_state(health_state), - ) - .route( - "/v1/completions", - routing::post(completions::completions).with_state(completion_state), - ); + routers.push({ + Router::new() + .route("/v1/events", routing::post(events::log_event)) + .route( + "/v1/health", + routing::post(health::health).with_state(health_state.clone()), + ) + .route( + "/v1/health", + routing::get(health::health).with_state(health_state), + ) + }); - let router = if let Some(chat_state) = chat_state { - router.route( - "/v1beta/chat/completions", - routing::post(chat::completions).with_state(chat_state), - ) - } else { - router - }; + routers.push({ + Router::new() + .route( + "/v1/completions", + routing::post(completions::completions).with_state(completion_state), + ) + .layer(TimeoutLayer::new(Duration::from_secs(3))) + }); - let router = if let Some(index_server) = index_server { + if let Some(chat_state) = chat_state { + routers.push({ + Router::new().route( + "/v1beta/chat/completions", + routing::post(chat::completions).with_state(chat_state), + ) + }) + } + + if let Some(index_server) = index_server { info!("Index is ready, enabling /v1beta/search API route"); - router.route( - "/v1beta/search", - routing::get(search::search).with_state(index_server), - ) - } else { - router - }; + routers.push({ + Router::new().route( + "/v1beta/search", + routing::get(search::search).with_state(index_server), + ) + }) + } - router - .layer(CorsLayer::permissive()) + let mut root = Router::new(); + for router in routers { + root = root.merge(router); + } + root.layer(CorsLayer::permissive()) .layer(opentelemetry_tracing_layer()) }