chore: set max timeout for /v1/completions handler (#526)

* chore: set max timeout for /v1/completions handler

* refactor: extract sub routers

* fix
r0.3
Meng Zhang 2023-10-09 18:44:55 -07:00 committed by GitHub
parent 24eadf0de8
commit d21a4de79c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 35 deletions

1
Cargo.lock generated
View File

@ -3722,6 +3722,7 @@ dependencies = [
"http-body", "http-body",
"http-range-header", "http-range-header",
"pin-project-lite", "pin-project-lite",
"tokio",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
] ]

View File

@ -18,7 +18,7 @@ utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
serde = { workspace = true } serde = { workspace = true }
serdeconv = { workspace = true } serdeconv = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
tower-http = { version = "0.4.0", features = ["cors"] } tower-http = { version = "0.4.0", features = ["cors", "timeout"] }
clap = { version = "4.3.0", features = ["derive"] } clap = { version = "4.3.0", features = ["derive"] }
lazy_static = { workspace = true } lazy_static = { workspace = true }
rust-embed = "8.0.0" rust-embed = "8.0.0"

View File

@ -21,7 +21,7 @@ use tabby_common::{
}; };
use tabby_download::Downloader; use tabby_download::Downloader;
use tokio::time::sleep; use tokio::time::sleep;
use tower_http::cors::CorsLayer; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi}; use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
@ -201,10 +201,12 @@ fn api_router(args: &ServeArgs) -> Router {
None None
}; };
let mut routers = vec![];
let health_state = Arc::new(health::HealthState::new(args)); let health_state = Arc::new(health::HealthState::new(args));
let router = Router::new() routers.push({
Router::new()
.route("/v1/events", routing::post(events::log_event)) .route("/v1/events", routing::post(events::log_event))
/* Remove POST /v1/health route in next major version release. */
.route( .route(
"/v1/health", "/v1/health",
routing::post(health::health).with_state(health_state.clone()), routing::post(health::health).with_state(health_state.clone()),
@ -213,32 +215,41 @@ fn api_router(args: &ServeArgs) -> Router {
"/v1/health", "/v1/health",
routing::get(health::health).with_state(health_state), routing::get(health::health).with_state(health_state),
) )
});
routers.push({
Router::new()
.route( .route(
"/v1/completions", "/v1/completions",
routing::post(completions::completions).with_state(completion_state), routing::post(completions::completions).with_state(completion_state),
); )
.layer(TimeoutLayer::new(Duration::from_secs(3)))
});
let router = if let Some(chat_state) = chat_state { if let Some(chat_state) = chat_state {
router.route( routers.push({
Router::new().route(
"/v1beta/chat/completions", "/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state), routing::post(chat::completions).with_state(chat_state),
) )
} else { })
router }
};
let router = if let Some(index_server) = index_server { if let Some(index_server) = index_server {
info!("Index is ready, enabling /v1beta/search API route"); info!("Index is ready, enabling /v1beta/search API route");
router.route( routers.push({
Router::new().route(
"/v1beta/search", "/v1beta/search",
routing::get(search::search).with_state(index_server), routing::get(search::search).with_state(index_server),
) )
} else { })
router }
};
router let mut root = Router::new();
.layer(CorsLayer::permissive()) for router in routers {
root = root.merge(router);
}
root.layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer()) .layer(opentelemetry_tracing_layer())
} }