feat: when is_admin_initialized, implement strict api access check in graphql / v1beta / v1 (#987)

r0.7
Meng Zhang 2023-12-08 19:59:56 +08:00 committed by GitHub
parent f4224f0417
commit d060888b5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 147 additions and 14 deletions

View File

@ -285,7 +285,7 @@ pub trait AuthenticationService: Send + Sync {
&self, &self,
refresh_token: String, refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>; ) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>; async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>; async fn is_admin_initialized(&self) -> Result<bool>;
async fn create_invitation(&self, email: String) -> Result<i32>; async fn create_invitation(&self, email: String) -> Result<i32>;

View File

@ -71,14 +71,38 @@ pub struct Query;
#[graphql_object(context = Context)] #[graphql_object(context = Context)]
impl Query { impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> { async fn workers(ctx: &Context) -> Result<Vec<Worker>> {
ctx.locator.worker().list_workers().await if ctx.locator.auth().is_admin_initialized().await? {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let workers = ctx.locator.worker().list_workers().await;
return Ok(workers);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to read workers",
))
} else {
Ok(ctx.locator.worker().list_workers().await)
}
} }
async fn registration_token(ctx: &Context) -> Result<String> { async fn registration_token(ctx: &Context) -> Result<String> {
if ctx.locator.auth().is_admin_initialized().await? {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let token = ctx.locator.worker().read_registration_token().await?;
return Ok(token);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to read registeration_token",
))
} else {
let token = ctx.locator.worker().read_registration_token().await?; let token = ctx.locator.worker().read_registration_token().await?;
Ok(token) Ok(token)
} }
}
async fn is_admin_initialized(ctx: &Context) -> Result<bool> { async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
Ok(ctx.locator.auth().is_admin_initialized().await?) Ok(ctx.locator.auth().is_admin_initialized().await?)
@ -142,7 +166,7 @@ impl Mutation {
} }
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> { async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
Ok(ctx.locator.auth().verify_token(token).await?) Ok(ctx.locator.auth().verify_access_token(&token).await?)
} }
async fn refresh_token( async fn refresh_token(

View File

@ -220,8 +220,8 @@ impl AuthenticationService for DbConn {
Ok(resp) Ok(resp)
} }
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> { async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse> {
let claims = validate_jwt(&access_token)?; let claims = validate_jwt(access_token)?;
let resp = VerifyTokenResponse::new(claims); let resp = VerifyTokenResponse::new(claims);
Ok(resp) Ok(resp)
} }

View File

@ -27,6 +27,7 @@ lazy_static! {
"# "#
) )
.down("DROP TABLE registration_token"), .down("DROP TABLE registration_token"),
// ==== Above migrations released in 0.6.0 ====
M::up( M::up(
r#" r#"
CREATE TABLE users ( CREATE TABLE users (
@ -36,7 +37,10 @@ lazy_static! {
is_admin BOOLEAN NOT NULL DEFAULT 0, is_admin BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT (DATETIME('now')), created_at TIMESTAMP DEFAULT (DATETIME('now')),
updated_at TIMESTAMP DEFAULT (DATETIME('now')), updated_at TIMESTAMP DEFAULT (DATETIME('now')),
auth_token VARCHAR(128) NOT NULL,
CONSTRAINT `idx_email` UNIQUE (`email`) CONSTRAINT `idx_email` UNIQUE (`email`)
CONSTRAINT `idx_auth_token` UNIQUE (`auth_token`)
); );
"# "#
) )

View File

@ -3,6 +3,7 @@
use anyhow::Result; use anyhow::Result;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rusqlite::{params, OptionalExtension, Row}; use rusqlite::{params, OptionalExtension, Row};
use uuid::Uuid;
use super::DbConn; use super::DbConn;
@ -15,11 +16,14 @@ pub struct User {
pub email: String, pub email: String,
pub password_encrypted: String, pub password_encrypted: String,
pub is_admin: bool, pub is_admin: bool,
/// To authenticate IDE extensions / plugins to access code completion / chat api endpoints.
pub auth_token: String,
} }
impl User { impl User {
fn select(clause: &str) -> String { fn select(clause: &str) -> String {
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "# r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at, auth_token FROM users WHERE "#
.to_owned() .to_owned()
+ clause + clause
} }
@ -32,6 +36,7 @@ impl User {
is_admin: row.get(3)?, is_admin: row.get(3)?,
created_at: row.get(4)?, created_at: row.get(4)?,
updated_at: row.get(5)?, updated_at: row.get(5)?,
auth_token: row.get(6)?,
}) })
} }
} }
@ -47,9 +52,9 @@ impl DbConn {
.conn .conn
.call(move |c| { .call(move |c| {
let mut stmt = c.prepare( let mut stmt = c.prepare(
r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#, r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#,
)?; )?;
let id = stmt.insert((email, password_encrypted, is_admin))?; let id = stmt.insert((email, password_encrypted, is_admin, generate_auth_token()))?;
Ok(id) Ok(id)
}) })
.await?; .await?;
@ -98,6 +103,38 @@ impl DbConn {
Ok(users) Ok(users)
} }
pub async fn verify_auth_token(&self, token: &str) -> bool {
let token = token.to_owned();
let id: Result<i32, _> = self
.conn
.call(move |c| {
c.query_row(
r#"SELECT id FROM users WHERE auth_token = ?"#,
params![token],
|row| row.get(0),
)
})
.await;
id.is_ok()
}
pub async fn reset_auth_token(&self, id: i32) -> Result<i32> {
self.conn
.call(move |c| {
let mut stmt = c.prepare(r#"UPDATE users SET auth_token = ? WHERE id = ?"#)?;
stmt.execute((Uuid::new_v4().to_string(), id))?;
Ok(())
})
.await?;
Ok(id)
}
}
fn generate_auth_token() -> String {
let uuid = Uuid::new_v4().to_string().replace('-', "");
format!("auth_{}", uuid)
} }
#[cfg(test)] #[cfg(test)]
@ -123,4 +160,21 @@ mod tests {
assert!(user.is_none()); assert!(user.is_none());
} }
#[tokio::test]
async fn test_auth_token() {
let conn = DbConn::new_in_memory().await.unwrap();
let id = create_user(&conn).await;
let user = conn.get_user(id).await.unwrap().unwrap();
assert!(!conn.verify_auth_token("abcd").await);
assert!(conn.verify_auth_token(&user.auth_token).await);
conn.reset_auth_token(id).await.unwrap();
let new_user = conn.get_user(id).await.unwrap().unwrap();
assert_eq!(user.email, new_user.email);
assert_ne!(user.auth_token, new_user.auth_token);
}
} }

View File

@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{http::Request, middleware::Next, response::IntoResponse}; use axum::{
http::{HeaderValue, Request},
middleware::Next,
response::IntoResponse,
};
use hyper::{client::HttpConnector, Body, Client, StatusCode}; use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn}; use tracing::{info, warn};
@ -41,6 +45,46 @@ impl ServerContext {
code, code,
} }
} }
async fn authorize_request(&self, request: &Request<Body>) -> bool {
let path = request.uri().path();
if (path.starts_with("/v1/") || path.starts_with("/v1beta/"))
// Authorization is enabled
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
{
let token = {
let authorization = request
.headers()
.get("authorization")
.map(HeaderValue::to_str)
.and_then(Result::ok);
if let Some(authorization) = authorization {
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some(("Bearer", contents)) => Some(contents),
_ => None,
}
} else {
None
}
};
if let Some(token) = token {
if self.db_conn.verify_access_token(token).await.is_err()
&& !self.db_conn.verify_auth_token(token).await
{
return false;
}
} else {
// Admin system is initialized, but there's no valid token.
return false;
}
}
true
}
} }
#[async_trait] #[async_trait]
@ -95,7 +139,13 @@ impl WorkerService for ServerContext {
request: Request<Body>, request: Request<Body>,
next: Next<Body>, next: Next<Body>,
) -> axum::response::Response { ) -> axum::response::Response {
let path = request.uri().path(); if !self.authorize_request(&request).await {
return axum::response::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap()
.into_response();
}
let remote_addr = request let remote_addr = request
.extensions() .extensions()
@ -103,6 +153,7 @@ impl WorkerService for ServerContext {
.map(|ci| ci.0) .map(|ci| ci.0)
.expect("Unable to extract remote addr"); .expect("Unable to extract remote addr");
let path = request.uri().path();
let worker = if path.starts_with("/v1/completions") { let worker = if path.starts_with("/v1/completions") {
self.completion.select().await self.completion.select().await
} else if path.starts_with("/v1beta/chat/completions") { } else if path.starts_with("/v1beta/chat/completions") {