feat: when is_admin_initialized, implement strict api access check in graphql / v1beta / v1 (#987)
parent
f4224f0417
commit
d060888b5c
|
|
@ -285,7 +285,7 @@ pub trait AuthenticationService: Send + Sync {
|
|||
&self,
|
||||
refresh_token: String,
|
||||
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
|
||||
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
|
||||
async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse>;
|
||||
async fn is_admin_initialized(&self) -> Result<bool>;
|
||||
|
||||
async fn create_invitation(&self, email: String) -> Result<i32>;
|
||||
|
|
|
|||
|
|
@ -71,13 +71,37 @@ pub struct Query;
|
|||
|
||||
#[graphql_object(context = Context)]
|
||||
impl Query {
|
||||
async fn workers(ctx: &Context) -> Vec<Worker> {
|
||||
ctx.locator.worker().list_workers().await
|
||||
async fn workers(ctx: &Context) -> Result<Vec<Worker>> {
|
||||
if ctx.locator.auth().is_admin_initialized().await? {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.user_info().is_admin() {
|
||||
let workers = ctx.locator.worker().list_workers().await;
|
||||
return Ok(workers);
|
||||
}
|
||||
}
|
||||
Err(CoreError::Unauthorized(
|
||||
"Only admin is able to read workers",
|
||||
))
|
||||
} else {
|
||||
Ok(ctx.locator.worker().list_workers().await)
|
||||
}
|
||||
}
|
||||
|
||||
async fn registration_token(ctx: &Context) -> Result<String> {
|
||||
let token = ctx.locator.worker().read_registration_token().await?;
|
||||
Ok(token)
|
||||
if ctx.locator.auth().is_admin_initialized().await? {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.user_info().is_admin() {
|
||||
let token = ctx.locator.worker().read_registration_token().await?;
|
||||
return Ok(token);
|
||||
}
|
||||
}
|
||||
Err(CoreError::Unauthorized(
|
||||
"Only admin is able to read registeration_token",
|
||||
))
|
||||
} else {
|
||||
let token = ctx.locator.worker().read_registration_token().await?;
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
|
||||
|
|
@ -142,7 +166,7 @@ impl Mutation {
|
|||
}
|
||||
|
||||
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
|
||||
Ok(ctx.locator.auth().verify_token(token).await?)
|
||||
Ok(ctx.locator.auth().verify_access_token(&token).await?)
|
||||
}
|
||||
|
||||
async fn refresh_token(
|
||||
|
|
|
|||
|
|
@ -220,8 +220,8 @@ impl AuthenticationService for DbConn {
|
|||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
||||
let claims = validate_jwt(&access_token)?;
|
||||
async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse> {
|
||||
let claims = validate_jwt(access_token)?;
|
||||
let resp = VerifyTokenResponse::new(claims);
|
||||
Ok(resp)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ lazy_static! {
|
|||
"#
|
||||
)
|
||||
.down("DROP TABLE registration_token"),
|
||||
// ==== Above migrations released in 0.6.0 ====
|
||||
M::up(
|
||||
r#"
|
||||
CREATE TABLE users (
|
||||
|
|
@ -36,7 +37,10 @@ lazy_static! {
|
|||
is_admin BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP DEFAULT (DATETIME('now')),
|
||||
updated_at TIMESTAMP DEFAULT (DATETIME('now')),
|
||||
CONSTRAINT `idx_email` UNIQUE (`email`)
|
||||
auth_token VARCHAR(128) NOT NULL,
|
||||
|
||||
CONSTRAINT `idx_email` UNIQUE (`email`)
|
||||
CONSTRAINT `idx_auth_token` UNIQUE (`auth_token`)
|
||||
);
|
||||
"#
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use rusqlite::{params, OptionalExtension, Row};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::DbConn;
|
||||
|
||||
|
|
@ -15,11 +16,14 @@ pub struct User {
|
|||
pub email: String,
|
||||
pub password_encrypted: String,
|
||||
pub is_admin: bool,
|
||||
|
||||
/// To authenticate IDE extensions / plugins to access code completion / chat api endpoints.
|
||||
pub auth_token: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn select(clause: &str) -> String {
|
||||
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "#
|
||||
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at, auth_token FROM users WHERE "#
|
||||
.to_owned()
|
||||
+ clause
|
||||
}
|
||||
|
|
@ -32,6 +36,7 @@ impl User {
|
|||
is_admin: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
updated_at: row.get(5)?,
|
||||
auth_token: row.get(6)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -47,9 +52,9 @@ impl DbConn {
|
|||
.conn
|
||||
.call(move |c| {
|
||||
let mut stmt = c.prepare(
|
||||
r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#,
|
||||
r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#,
|
||||
)?;
|
||||
let id = stmt.insert((email, password_encrypted, is_admin))?;
|
||||
let id = stmt.insert((email, password_encrypted, is_admin, generate_auth_token()))?;
|
||||
Ok(id)
|
||||
})
|
||||
.await?;
|
||||
|
|
@ -98,6 +103,38 @@ impl DbConn {
|
|||
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub async fn verify_auth_token(&self, token: &str) -> bool {
|
||||
let token = token.to_owned();
|
||||
let id: Result<i32, _> = self
|
||||
.conn
|
||||
.call(move |c| {
|
||||
c.query_row(
|
||||
r#"SELECT id FROM users WHERE auth_token = ?"#,
|
||||
params![token],
|
||||
|row| row.get(0),
|
||||
)
|
||||
})
|
||||
.await;
|
||||
id.is_ok()
|
||||
}
|
||||
|
||||
pub async fn reset_auth_token(&self, id: i32) -> Result<i32> {
|
||||
self.conn
|
||||
.call(move |c| {
|
||||
let mut stmt = c.prepare(r#"UPDATE users SET auth_token = ? WHERE id = ?"#)?;
|
||||
stmt.execute((Uuid::new_v4().to_string(), id))?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_auth_token() -> String {
|
||||
let uuid = Uuid::new_v4().to_string().replace('-', "");
|
||||
format!("auth_{}", uuid)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -123,4 +160,21 @@ mod tests {
|
|||
|
||||
assert!(user.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_token() {
|
||||
let conn = DbConn::new_in_memory().await.unwrap();
|
||||
let id = create_user(&conn).await;
|
||||
|
||||
let user = conn.get_user(id).await.unwrap().unwrap();
|
||||
|
||||
assert!(!conn.verify_auth_token("abcd").await);
|
||||
|
||||
assert!(conn.verify_auth_token(&user.auth_token).await);
|
||||
|
||||
conn.reset_auth_token(id).await.unwrap();
|
||||
let new_user = conn.get_user(id).await.unwrap().unwrap();
|
||||
assert_eq!(user.email, new_user.email);
|
||||
assert_ne!(user.auth_token, new_user.auth_token);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc};
|
|||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
||||
use axum::{
|
||||
http::{HeaderValue, Request},
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||
use tracing::{info, warn};
|
||||
|
|
@ -41,6 +45,46 @@ impl ServerContext {
|
|||
code,
|
||||
}
|
||||
}
|
||||
|
||||
async fn authorize_request(&self, request: &Request<Body>) -> bool {
|
||||
let path = request.uri().path();
|
||||
if (path.starts_with("/v1/") || path.starts_with("/v1beta/"))
|
||||
// Authorization is enabled
|
||||
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
|
||||
{
|
||||
let token = {
|
||||
let authorization = request
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.map(HeaderValue::to_str)
|
||||
.and_then(Result::ok);
|
||||
|
||||
if let Some(authorization) = authorization {
|
||||
let split = authorization.split_once(' ');
|
||||
match split {
|
||||
// Found proper bearer
|
||||
Some(("Bearer", contents)) => Some(contents),
|
||||
_ => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(token) = token {
|
||||
if self.db_conn.verify_access_token(token).await.is_err()
|
||||
&& !self.db_conn.verify_auth_token(token).await
|
||||
{
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Admin system is initialized, but there's no valid token.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -95,7 +139,13 @@ impl WorkerService for ServerContext {
|
|||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> axum::response::Response {
|
||||
let path = request.uri().path();
|
||||
if !self.authorize_request(&request).await {
|
||||
return axum::response::Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let remote_addr = request
|
||||
.extensions()
|
||||
|
|
@ -103,6 +153,7 @@ impl WorkerService for ServerContext {
|
|||
.map(|ci| ci.0)
|
||||
.expect("Unable to extract remote addr");
|
||||
|
||||
let path = request.uri().path();
|
||||
let worker = if path.starts_with("/v1/completions") {
|
||||
self.completion.select().await
|
||||
} else if path.starts_with("/v1beta/chat/completions") {
|
||||
|
|
|
|||
Loading…
Reference in New Issue