feat: when is_admin_initialized, implement strict api access check in graphql / v1beta / v1 (#987)

r0.7
Meng Zhang 2023-12-08 19:59:56 +08:00 committed by GitHub
parent f4224f0417
commit d060888b5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 147 additions and 14 deletions

View File

@ -285,7 +285,7 @@ pub trait AuthenticationService: Send + Sync {
&self,
refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>;
async fn create_invitation(&self, email: String) -> Result<i32>;

View File

@ -71,13 +71,37 @@ pub struct Query;
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.locator.worker().list_workers().await
async fn workers(ctx: &Context) -> Result<Vec<Worker>> {
if ctx.locator.auth().is_admin_initialized().await? {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let workers = ctx.locator.worker().list_workers().await;
return Ok(workers);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to read workers",
))
} else {
Ok(ctx.locator.worker().list_workers().await)
}
}
async fn registration_token(ctx: &Context) -> Result<String> {
let token = ctx.locator.worker().read_registration_token().await?;
Ok(token)
if ctx.locator.auth().is_admin_initialized().await? {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let token = ctx.locator.worker().read_registration_token().await?;
return Ok(token);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to read registeration_token",
))
} else {
let token = ctx.locator.worker().read_registration_token().await?;
Ok(token)
}
}
async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
@ -142,7 +166,7 @@ impl Mutation {
}
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
Ok(ctx.locator.auth().verify_token(token).await?)
Ok(ctx.locator.auth().verify_access_token(&token).await?)
}
async fn refresh_token(

View File

@ -220,8 +220,8 @@ impl AuthenticationService for DbConn {
Ok(resp)
}
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
let claims = validate_jwt(&access_token)?;
async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse> {
let claims = validate_jwt(access_token)?;
let resp = VerifyTokenResponse::new(claims);
Ok(resp)
}

View File

@ -27,6 +27,7 @@ lazy_static! {
"#
)
.down("DROP TABLE registration_token"),
// ==== Above migrations released in 0.6.0 ====
M::up(
r#"
CREATE TABLE users (
@ -36,7 +37,10 @@ lazy_static! {
is_admin BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
updated_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_email` UNIQUE (`email`)
auth_token VARCHAR(128) NOT NULL,
CONSTRAINT `idx_email` UNIQUE (`email`)
CONSTRAINT `idx_auth_token` UNIQUE (`auth_token`)
);
"#
)

View File

@ -3,6 +3,7 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use rusqlite::{params, OptionalExtension, Row};
use uuid::Uuid;
use super::DbConn;
@ -15,11 +16,14 @@ pub struct User {
pub email: String,
pub password_encrypted: String,
pub is_admin: bool,
/// To authenticate IDE extensions / plugins to access code completion / chat api endpoints.
pub auth_token: String,
}
impl User {
fn select(clause: &str) -> String {
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "#
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at, auth_token FROM users WHERE "#
.to_owned()
+ clause
}
@ -32,6 +36,7 @@ impl User {
is_admin: row.get(3)?,
created_at: row.get(4)?,
updated_at: row.get(5)?,
auth_token: row.get(6)?,
})
}
}
@ -47,9 +52,9 @@ impl DbConn {
.conn
.call(move |c| {
let mut stmt = c.prepare(
r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#,
r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#,
)?;
let id = stmt.insert((email, password_encrypted, is_admin))?;
let id = stmt.insert((email, password_encrypted, is_admin, generate_auth_token()))?;
Ok(id)
})
.await?;
@ -98,6 +103,38 @@ impl DbConn {
Ok(users)
}
pub async fn verify_auth_token(&self, token: &str) -> bool {
let token = token.to_owned();
let id: Result<i32, _> = self
.conn
.call(move |c| {
c.query_row(
r#"SELECT id FROM users WHERE auth_token = ?"#,
params![token],
|row| row.get(0),
)
})
.await;
id.is_ok()
}
pub async fn reset_auth_token(&self, id: i32) -> Result<i32> {
self.conn
.call(move |c| {
let mut stmt = c.prepare(r#"UPDATE users SET auth_token = ? WHERE id = ?"#)?;
stmt.execute((Uuid::new_v4().to_string(), id))?;
Ok(())
})
.await?;
Ok(id)
}
}
fn generate_auth_token() -> String {
let uuid = Uuid::new_v4().to_string().replace('-', "");
format!("auth_{}", uuid)
}
#[cfg(test)]
@ -123,4 +160,21 @@ mod tests {
assert!(user.is_none());
}
#[tokio::test]
async fn test_auth_token() {
let conn = DbConn::new_in_memory().await.unwrap();
let id = create_user(&conn).await;
let user = conn.get_user(id).await.unwrap().unwrap();
assert!(!conn.verify_auth_token("abcd").await);
assert!(conn.verify_auth_token(&user.auth_token).await);
conn.reset_auth_token(id).await.unwrap();
let new_user = conn.get_user(id).await.unwrap().unwrap();
assert_eq!(user.email, new_user.email);
assert_ne!(user.auth_token, new_user.auth_token);
}
}

View File

@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc};
use anyhow::Result;
use async_trait::async_trait;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use axum::{
http::{HeaderValue, Request},
middleware::Next,
response::IntoResponse,
};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn};
@ -41,6 +45,46 @@ impl ServerContext {
code,
}
}
async fn authorize_request(&self, request: &Request<Body>) -> bool {
let path = request.uri().path();
if (path.starts_with("/v1/") || path.starts_with("/v1beta/"))
// Authorization is enabled
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
{
let token = {
let authorization = request
.headers()
.get("authorization")
.map(HeaderValue::to_str)
.and_then(Result::ok);
if let Some(authorization) = authorization {
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some(("Bearer", contents)) => Some(contents),
_ => None,
}
} else {
None
}
};
if let Some(token) = token {
if self.db_conn.verify_access_token(token).await.is_err()
&& !self.db_conn.verify_auth_token(token).await
{
return false;
}
} else {
// Admin system is initialized, but there's no valid token.
return false;
}
}
true
}
}
#[async_trait]
@ -95,7 +139,13 @@ impl WorkerService for ServerContext {
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
let path = request.uri().path();
if !self.authorize_request(&request).await {
return axum::response::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap()
.into_response();
}
let remote_addr = request
.extensions()
@ -103,6 +153,7 @@ impl WorkerService for ServerContext {
.map(|ci| ci.0)
.expect("Unable to extract remote addr");
let path = request.uri().path();
let worker = if path.starts_with("/v1/completions") {
self.completion.select().await
} else if path.starts_with("/v1beta/chat/completions") {