chore: set max timeout for /v1/completions handler (#526)

* chore: set max timeout for /v1/completions handler

* refactor: extract sub routers

* fix
r0.3
Meng Zhang 2023-10-09 18:44:55 -07:00 committed by GitHub
parent 24eadf0de8
commit d21a4de79c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 35 deletions

1
Cargo.lock generated
View File

@@ -3722,6 +3722,7 @@ dependencies = [
"http-body",
"http-range-header",
"pin-project-lite",
+"tokio",
"tower-layer",
"tower-service",
]

View File

@@ -18,7 +18,7 @@ utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
serde = { workspace = true }
serdeconv = { workspace = true }
serde_json = { workspace = true }
-tower-http = { version = "0.4.0", features = ["cors"] }
+tower-http = { version = "0.4.0", features = ["cors", "timeout"] }
clap = { version = "4.3.0", features = ["derive"] }
lazy_static = { workspace = true }
rust-embed = "8.0.0"

View File

@@ -21,7 +21,7 @@ use tabby_common::{
};
use tabby_download::Downloader;
use tokio::time::sleep;
-use tower_http::cors::CorsLayer;
+use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::{debug, info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi;
@@ -201,44 +201,55 @@ fn api_router(args: &ServeArgs) -> Router {
None
};
let mut routers = vec![];
let health_state = Arc::new(health::HealthState::new(args));
let router = Router::new()
.route("/v1/events", routing::post(events::log_event))
/* Remove POST /v1/health route in next major version release. */
.route(
"/v1/health",
routing::post(health::health).with_state(health_state.clone()),
)
.route(
"/v1/health",
routing::get(health::health).with_state(health_state),
)
.route(
"/v1/completions",
routing::post(completions::completions).with_state(completion_state),
);
routers.push({
Router::new()
.route("/v1/events", routing::post(events::log_event))
.route(
"/v1/health",
routing::post(health::health).with_state(health_state.clone()),
)
.route(
"/v1/health",
routing::get(health::health).with_state(health_state),
)
});
let router = if let Some(chat_state) = chat_state {
router.route(
"/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state),
)
} else {
router
};
routers.push({
Router::new()
.route(
"/v1/completions",
routing::post(completions::completions).with_state(completion_state),
)
.layer(TimeoutLayer::new(Duration::from_secs(3)))
});
let router = if let Some(index_server) = index_server {
if let Some(chat_state) = chat_state {
routers.push({
Router::new().route(
"/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state),
)
})
}
if let Some(index_server) = index_server {
info!("Index is ready, enabling /v1beta/search API route");
router.route(
"/v1beta/search",
routing::get(search::search).with_state(index_server),
)
} else {
router
};
routers.push({
Router::new().route(
"/v1beta/search",
routing::get(search::search).with_state(index_server),
)
})
}
router
.layer(CorsLayer::permissive())
let mut root = Router::new();
for router in routers {
root = root.merge(router);
}
root.layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer())
}