implement auth in swagger

support-auth-token
Meng Zhang 2023-12-08 15:10:57 +08:00
parent d2d158c747
commit 2b8a07baa0
10 changed files with 99 additions and 10 deletions

View File

@ -20,6 +20,9 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
responses(
(status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"),
(status = 405, description = "When chat model is not specified, the endpoint will returns 405 Method Not Allowed"),
),
security(
("auth_token" = [])
)
)]
#[instrument(skip(state, request))]

View File

@ -15,6 +15,9 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet
responses(
(status = 200, description = "Success", body = CompletionResponse, content_type = "application/json"),
(status = 400, description = "Bad Request")
),
security(
("auth_token" = [])
)
)]
#[instrument(skip(state, request))]

View File

@ -16,6 +16,9 @@ use tabby_common::api::event::{Event, EventLogger, LogEventRequest, SelectKind};
responses(
(status = 200, description = "Success"),
(status = 400, description = "Bad Request")
),
security(
("auth_token" = [])
)
)]
pub async fn log_event(

View File

@ -10,6 +10,9 @@ use crate::services::health;
tag = "v1",
responses(
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
),
security(
("auth_token" = [])
)
)]
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {

View File

@ -32,8 +32,11 @@ pub struct SearchQuery {
responses(
(status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
(status = 501, description = "When code search is not enabled, the endpoint will returns 501 Not Implemented"),
)
)]
),
security(
("auth_token" = [])
)
)]
#[instrument(skip(state, query))]
pub async fn search(
State(state): State<Arc<dyn CodeSearch>>,

View File

@ -12,7 +12,10 @@ use tabby_common::{
use tokio::time::sleep;
use tower_http::timeout::TimeoutLayer;
use tracing::info;
use utoipa::OpenApi;
use utoipa::{
openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
Modify, OpenApi,
};
use utoipa_swagger_ui::SwaggerUi;
use crate::{
@ -63,7 +66,8 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
api::code::SearchResponse,
api::code::Hit,
api::code::HitDocument
))
)),
modifiers(&SecurityAddon),
)]
struct ApiDoc;
@ -245,3 +249,21 @@ fn start_heartbeat(args: &ServeArgs) {
}
});
}
struct SecurityAddon;
impl Modify for SecurityAddon {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
if let Some(components) = &mut openapi.components {
components.add_security_scheme(
"auth_token",
SecurityScheme::Http(
HttpBuilder::new()
.scheme(HttpAuthScheme::Bearer)
.bearer_format("uuid")
.build(),
),
)
}
}
}

View File

@ -16,6 +16,7 @@ use super::from_validation_errors;
lazy_static! {
static ref JWT_TOKEN_SECRET: String = jwt_token_secret();
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
JWT_TOKEN_SECRET.as_bytes()
);
@ -41,7 +42,9 @@ fn jwt_token_secret() -> String {
let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") {
Ok(x) => x,
Err(_) => {
warn!(r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes.");
warn!(
r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes."
);
Uuid::new_v4().to_string()
}
};
@ -277,7 +280,7 @@ pub trait AuthenticationService: Send + Sync {
&self,
refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>;
async fn create_invitation(&self, email: String) -> Result<i32>;

View File

@ -142,7 +142,7 @@ impl Mutation {
}
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
Ok(ctx.locator.auth().verify_token(token).await?)
Ok(ctx.locator.auth().verify_access_token(token).await?)
}
async fn refresh_token(

View File

@ -220,7 +220,7 @@ impl AuthenticationService for DbConn {
Ok(resp)
}
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
let claims = validate_jwt(&access_token)?;
let resp = VerifyTokenResponse::new(claims);
Ok(resp)

View File

@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc};
use anyhow::Result;
use async_trait::async_trait;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use axum::{
http::{HeaderValue, Request},
middleware::Next,
response::IntoResponse,
};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn};
@ -41,6 +45,44 @@ impl ServerContext {
code,
}
}
async fn authorize_request(&self, request: &Request<Body>) -> bool {
let path = request.uri().path();
if (path.starts_with("/v1/") || path.starts_with("/v1beta/"))
// Authorization is enabled
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
{
let auth_token = {
let authorization = request
.headers()
.get("authorization")
.map(HeaderValue::to_str)
.and_then(Result::ok);
if let Some(authorization) = authorization {
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some(("Bearer", contents)) => Some(contents),
_ => None,
}
} else {
None
}
};
if let Some(auth_token) = auth_token {
if !self.db_conn.verify_auth_token(auth_token).await {
return false;
}
} else {
// Admin system is initialized, but there's no valid token.
return false;
}
}
true
}
}
#[async_trait]
@ -95,7 +137,13 @@ impl WorkerService for ServerContext {
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
let path = request.uri().path();
if !self.authorize_request(&request).await {
return axum::response::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap()
.into_response();
}
let remote_addr = request
.extensions()
@ -103,6 +151,7 @@ impl WorkerService for ServerContext {
.map(|ci| ci.0)
.expect("Unable to extract remote addr");
let path = request.uri().path();
let worker = if path.starts_with("/v1/completions") {
self.completion.select().await
} else if path.starts_with("/v1beta/chat/completions") {