implement auth in swagger
parent
d2d158c747
commit
2b8a07baa0
|
|
@ -20,6 +20,9 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"),
|
(status = 200, description = "Success", body = ChatCompletionChunk, content_type = "application/jsonstream"),
|
||||||
(status = 405, description = "When chat model is not specified, the endpoint will returns 405 Method Not Allowed"),
|
(status = 405, description = "When chat model is not specified, the endpoint will returns 405 Method Not Allowed"),
|
||||||
|
),
|
||||||
|
security(
|
||||||
|
("auth_token" = [])
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, request))]
|
#[instrument(skip(state, request))]
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,9 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Success", body = CompletionResponse, content_type = "application/json"),
|
(status = 200, description = "Success", body = CompletionResponse, content_type = "application/json"),
|
||||||
(status = 400, description = "Bad Request")
|
(status = 400, description = "Bad Request")
|
||||||
|
),
|
||||||
|
security(
|
||||||
|
("auth_token" = [])
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, request))]
|
#[instrument(skip(state, request))]
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@ use tabby_common::api::event::{Event, EventLogger, LogEventRequest, SelectKind};
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Success"),
|
(status = 200, description = "Success"),
|
||||||
(status = 400, description = "Bad Request")
|
(status = 400, description = "Bad Request")
|
||||||
|
),
|
||||||
|
security(
|
||||||
|
("auth_token" = [])
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn log_event(
|
pub async fn log_event(
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,9 @@ use crate::services::health;
|
||||||
tag = "v1",
|
tag = "v1",
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
|
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
|
||||||
|
),
|
||||||
|
security(
|
||||||
|
("auth_token" = [])
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {
|
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,11 @@ pub struct SearchQuery {
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
|
(status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
|
||||||
(status = 501, description = "When code search is not enabled, the endpoint will returns 501 Not Implemented"),
|
(status = 501, description = "When code search is not enabled, the endpoint will returns 501 Not Implemented"),
|
||||||
)
|
),
|
||||||
)]
|
security(
|
||||||
|
("auth_token" = [])
|
||||||
|
)
|
||||||
|
)]
|
||||||
#[instrument(skip(state, query))]
|
#[instrument(skip(state, query))]
|
||||||
pub async fn search(
|
pub async fn search(
|
||||||
State(state): State<Arc<dyn CodeSearch>>,
|
State(state): State<Arc<dyn CodeSearch>>,
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,10 @@ use tabby_common::{
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tower_http::timeout::TimeoutLayer;
|
use tower_http::timeout::TimeoutLayer;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use utoipa::OpenApi;
|
use utoipa::{
|
||||||
|
openapi::security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
|
||||||
|
Modify, OpenApi,
|
||||||
|
};
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
|
@ -63,7 +66,8 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
api::code::SearchResponse,
|
api::code::SearchResponse,
|
||||||
api::code::Hit,
|
api::code::Hit,
|
||||||
api::code::HitDocument
|
api::code::HitDocument
|
||||||
))
|
)),
|
||||||
|
modifiers(&SecurityAddon),
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
||||||
|
|
@ -245,3 +249,21 @@ fn start_heartbeat(args: &ServeArgs) {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SecurityAddon;
|
||||||
|
|
||||||
|
impl Modify for SecurityAddon {
|
||||||
|
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
|
||||||
|
if let Some(components) = &mut openapi.components {
|
||||||
|
components.add_security_scheme(
|
||||||
|
"auth_token",
|
||||||
|
SecurityScheme::Http(
|
||||||
|
HttpBuilder::new()
|
||||||
|
.scheme(HttpAuthScheme::Bearer)
|
||||||
|
.bearer_format("uuid")
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ use super::from_validation_errors;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref JWT_TOKEN_SECRET: String = jwt_token_secret();
|
static ref JWT_TOKEN_SECRET: String = jwt_token_secret();
|
||||||
|
|
||||||
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
||||||
JWT_TOKEN_SECRET.as_bytes()
|
JWT_TOKEN_SECRET.as_bytes()
|
||||||
);
|
);
|
||||||
|
|
@ -41,7 +42,9 @@ fn jwt_token_secret() -> String {
|
||||||
let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") {
|
let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
warn!(r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes.");
|
warn!(
|
||||||
|
r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes."
|
||||||
|
);
|
||||||
Uuid::new_v4().to_string()
|
Uuid::new_v4().to_string()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -277,7 +280,7 @@ pub trait AuthenticationService: Send + Sync {
|
||||||
&self,
|
&self,
|
||||||
refresh_token: String,
|
refresh_token: String,
|
||||||
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
|
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
|
||||||
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
|
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
|
||||||
async fn is_admin_initialized(&self) -> Result<bool>;
|
async fn is_admin_initialized(&self) -> Result<bool>;
|
||||||
|
|
||||||
async fn create_invitation(&self, email: String) -> Result<i32>;
|
async fn create_invitation(&self, email: String) -> Result<i32>;
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,7 @@ impl Mutation {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
|
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
|
||||||
Ok(ctx.locator.auth().verify_token(token).await?)
|
Ok(ctx.locator.auth().verify_access_token(token).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn refresh_token(
|
async fn refresh_token(
|
||||||
|
|
|
||||||
|
|
@ -220,7 +220,7 @@ impl AuthenticationService for DbConn {
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
||||||
let claims = validate_jwt(&access_token)?;
|
let claims = validate_jwt(&access_token)?;
|
||||||
let resp = VerifyTokenResponse::new(claims);
|
let resp = VerifyTokenResponse::new(claims);
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
use axum::{
|
||||||
|
http::{HeaderValue, Request},
|
||||||
|
middleware::Next,
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||||
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
@ -41,6 +45,44 @@ impl ServerContext {
|
||||||
code,
|
code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn authorize_request(&self, request: &Request<Body>) -> bool {
|
||||||
|
let path = request.uri().path();
|
||||||
|
if (path.starts_with("/v1/") || path.starts_with("/v1beta/"))
|
||||||
|
// Authorization is enabled
|
||||||
|
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let auth_token = {
|
||||||
|
let authorization = request
|
||||||
|
.headers()
|
||||||
|
.get("authorization")
|
||||||
|
.map(HeaderValue::to_str)
|
||||||
|
.and_then(Result::ok);
|
||||||
|
|
||||||
|
if let Some(authorization) = authorization {
|
||||||
|
let split = authorization.split_once(' ');
|
||||||
|
match split {
|
||||||
|
// Found proper bearer
|
||||||
|
Some(("Bearer", contents)) => Some(contents),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(auth_token) = auth_token {
|
||||||
|
if !self.db_conn.verify_auth_token(auth_token).await {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Admin system is initialized, but there's no valid token.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -95,7 +137,13 @@ impl WorkerService for ServerContext {
|
||||||
request: Request<Body>,
|
request: Request<Body>,
|
||||||
next: Next<Body>,
|
next: Next<Body>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
let path = request.uri().path();
|
if !self.authorize_request(&request).await {
|
||||||
|
return axum::response::Response::builder()
|
||||||
|
.status(StatusCode::UNAUTHORIZED)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap()
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
let remote_addr = request
|
let remote_addr = request
|
||||||
.extensions()
|
.extensions()
|
||||||
|
|
@ -103,6 +151,7 @@ impl WorkerService for ServerContext {
|
||||||
.map(|ci| ci.0)
|
.map(|ci| ci.0)
|
||||||
.expect("Unable to extract remote addr");
|
.expect("Unable to extract remote addr");
|
||||||
|
|
||||||
|
let path = request.uri().path();
|
||||||
let worker = if path.starts_with("/v1/completions") {
|
let worker = if path.starts_with("/v1/completions") {
|
||||||
self.completion.select().await
|
self.completion.select().await
|
||||||
} else if path.starts_with("/v1beta/chat/completions") {
|
} else if path.starts_with("/v1beta/chat/completions") {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue