From 2b8a07baa012205f9146bbd1eeae43d295e141b2 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 8 Dec 2023 15:10:57 +0800 Subject: [PATCH] implement auth in swagger --- crates/tabby/src/routes/chat.rs | 3 ++ crates/tabby/src/routes/completions.rs | 3 ++ crates/tabby/src/routes/events.rs | 3 ++ crates/tabby/src/routes/health.rs | 3 ++ crates/tabby/src/routes/search.rs | 7 +++- crates/tabby/src/serve.rs | 26 ++++++++++++- ee/tabby-webserver/src/schema/auth.rs | 7 +++- ee/tabby-webserver/src/schema/mod.rs | 2 +- ee/tabby-webserver/src/service/auth.rs | 2 +- ee/tabby-webserver/src/service/mod.rs | 53 +++++++++++++++++++++++++- 10 files changed, 99 insertions(+), 10 deletions(-) diff --git a/crates/tabby/src/routes/chat.rs b/crates/tabby/src/routes/chat.rs index 1b54333..86453f4 100644 --- a/crates/tabby/src/routes/chat.rs +++ b/crates/tabby/src/routes/chat.rs @@ -20,6 +20,9 @@ use crate::services::chat::{ChatCompletionRequest, ChatService}; responses( (status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"), (status = 405, description = "When chat model is not specified, the endpoint will returns 405 Method Not Allowed"), + ), + security( + ("auth_token" = []) ) )] #[instrument(skip(state, request))] diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index d394d18..8cb04ff 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -15,6 +15,9 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet responses( (status = 200, description = "Success", body = CompletionResponse, content_type = "application/json"), (status = 400, description = "Bad Request") + ), + security( + ("auth_token" = []) ) )] #[instrument(skip(state, request))] diff --git a/crates/tabby/src/routes/events.rs b/crates/tabby/src/routes/events.rs index c8f747d..8088551 100644 --- a/crates/tabby/src/routes/events.rs +++ b/crates/tabby/src/routes/events.rs @@ -16,6 +16,9 @@ use tabby_common::api::event::{Event, EventLogger, LogEventRequest, SelectKind}; responses( (status = 200, description = "Success"), (status = 400, description = "Bad Request") + ), + security( + ("auth_token" = []) ) )] pub async fn log_event( diff --git a/crates/tabby/src/routes/health.rs b/crates/tabby/src/routes/health.rs index 9483ac2..3e7adfb 100644 --- a/crates/tabby/src/routes/health.rs +++ b/crates/tabby/src/routes/health.rs @@ -10,6 +10,9 @@ use crate::services::health; tag = "v1", responses( (status = 200, description = "Success", body = HealthState, content_type = "application/json"), + ), + security( + ("auth_token" = []) ) )] pub async fn health(State(state): State>) -> Json { diff --git a/crates/tabby/src/routes/search.rs b/crates/tabby/src/routes/search.rs index 99edade..e0d51a3 100644 --- a/crates/tabby/src/routes/search.rs +++ b/crates/tabby/src/routes/search.rs @@ -32,8 +32,11 @@ pub struct SearchQuery { responses( (status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"), (status = 501, description = "When code search is not enabled, the endpoint will returns 501 Not Implemented"), - ) - )] + ), + security( + ("auth_token" = []) + ) +)] #[instrument(skip(state, query))] pub async fn search( State(state): State>, diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index fa54cb4..c764a45 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -12,7 +12,10 @@ use tabby_common::{ use tokio::time::sleep; use tower_http::timeout::TimeoutLayer; use tracing::info; -use utoipa::OpenApi; +use utoipa::{ + openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme}, + Modify, OpenApi, +}; use utoipa_swagger_ui::SwaggerUi; use crate::{ @@ -63,7 +66,8 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi api::code::SearchResponse, api::code::Hit, api::code::HitDocument - )) + )), + modifiers(&SecurityAddon), )] struct ApiDoc; @@ -245,3 +249,21 @@ fn start_heartbeat(args: &ServeArgs) { } }); } + +struct SecurityAddon; + +impl Modify for SecurityAddon { + fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { + if let Some(components) = &mut openapi.components { + components.add_security_scheme( + "auth_token", + SecurityScheme::Http( + HttpBuilder::new() + .scheme(HttpAuthScheme::Bearer) + .bearer_format("uuid") + .build(), + ), + ) + } + } +} diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index e4ca305..f8b7529 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -16,6 +16,7 @@ use super::from_validation_errors; lazy_static! { static ref JWT_TOKEN_SECRET: String = jwt_token_secret(); + static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret( JWT_TOKEN_SECRET.as_bytes() ); @@ -41,7 +42,9 @@ fn jwt_token_secret() -> String { let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") { Ok(x) => x, Err(_) => { - warn!(r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes."); + warn!( + r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes." + ); Uuid::new_v4().to_string() } }; @@ -277,7 +280,7 @@ pub trait AuthenticationService: Send + Sync { &self, refresh_token: String, ) -> std::result::Result; - async fn verify_token(&self, access_token: String) -> Result; + async fn verify_access_token(&self, access_token: String) -> Result; async fn is_admin_initialized(&self) -> Result; async fn create_invitation(&self, email: String) -> Result; diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 798a56f..502d8ca 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -142,7 +142,7 @@ impl Mutation { } async fn verify_token(ctx: &Context, token: String) -> Result { - Ok(ctx.locator.auth().verify_token(token).await?) + Ok(ctx.locator.auth().verify_access_token(token).await?) } async fn refresh_token( diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 6cccb87..269a598 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -220,7 +220,7 @@ impl AuthenticationService for DbConn { Ok(resp) } - async fn verify_token(&self, access_token: String) -> Result { + async fn verify_access_token(&self, access_token: String) -> Result { let claims = validate_jwt(&access_token)?; let resp = VerifyTokenResponse::new(claims); Ok(resp) diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index c25795d..46f237f 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc}; use anyhow::Result; use async_trait::async_trait; -use axum::{http::Request, middleware::Next, response::IntoResponse}; +use axum::{ + http::{HeaderValue, Request}, + middleware::Next, + response::IntoResponse, +}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tracing::{info, warn}; @@ -41,6 +45,44 @@ impl ServerContext { code, } } + + async fn authorize_request(&self, request: &Request) -> bool { + let path = request.uri().path(); + if (path.starts_with("/v1/") || path.starts_with("/v1beta/")) + // Authorization is enabled + && self.db_conn.is_admin_initialized().await.unwrap_or(false) + { + let auth_token = { + let authorization = request + .headers() + .get("authorization") + .map(HeaderValue::to_str) + .and_then(Result::ok); + + if let Some(authorization) = authorization { + let split = authorization.split_once(' '); + match split { + // Found proper bearer + Some(("Bearer", contents)) => Some(contents), + _ => None, + } + } else { + None + } + }; + + if let Some(auth_token) = auth_token { + if !self.db_conn.verify_auth_token(auth_token).await { + return false; + } + } else { + // Admin system is initialized, but there's no valid token. + return false; + } + } + + true + } } #[async_trait] @@ -95,7 +137,13 @@ impl WorkerService for ServerContext { request: Request, next: Next, ) -> axum::response::Response { - let path = request.uri().path(); + if !self.authorize_request(&request).await { + return axum::response::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap() + .into_response(); + } let remote_addr = request .extensions() @@ -103,6 +151,7 @@ impl WorkerService for ServerContext { .map(|ci| ci.0) .expect("Unable to extract remote addr"); + let path = request.uri().path(); let worker = if path.starts_with("/v1/completions") { self.completion.select().await } else if path.starts_with("/v1beta/chat/completions") {