diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 51105c3..a6dc98e 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -285,7 +285,7 @@ pub trait AuthenticationService: Send + Sync { &self, refresh_token: String, ) -> std::result::Result; - async fn verify_token(&self, access_token: String) -> Result; + async fn verify_access_token(&self, access_token: &str) -> Result; async fn is_admin_initialized(&self) -> Result; async fn create_invitation(&self, email: String) -> Result; diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index 798a56f..70b286a 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -71,13 +71,37 @@ pub struct Query; #[graphql_object(context = Context)] impl Query { - async fn workers(ctx: &Context) -> Vec { - ctx.locator.worker().list_workers().await + async fn workers(ctx: &Context) -> Result> { + if ctx.locator.auth().is_admin_initialized().await? { + if let Some(claims) = &ctx.claims { + if claims.user_info().is_admin() { + let workers = ctx.locator.worker().list_workers().await; + return Ok(workers); + } + } + Err(CoreError::Unauthorized( + "Only admin is able to read workers", + )) + } else { + Ok(ctx.locator.worker().list_workers().await) + } } async fn registration_token(ctx: &Context) -> Result { - let token = ctx.locator.worker().read_registration_token().await?; - Ok(token) + if ctx.locator.auth().is_admin_initialized().await? { + if let Some(claims) = &ctx.claims { + if claims.user_info().is_admin() { + let token = ctx.locator.worker().read_registration_token().await?; + return Ok(token); + } + } + Err(CoreError::Unauthorized( + "Only admin is able to read registeration_token", + )) + } else { + let token = ctx.locator.worker().read_registration_token().await?; + Ok(token) + } } async fn is_admin_initialized(ctx: &Context) -> Result { @@ -142,7 +166,7 @@ impl Mutation { } async fn verify_token(ctx: &Context, token: String) -> Result { - Ok(ctx.locator.auth().verify_token(token).await?) + Ok(ctx.locator.auth().verify_access_token(&token).await?) } async fn refresh_token( diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index 6cccb87..0b5345c 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -220,8 +220,8 @@ impl AuthenticationService for DbConn { Ok(resp) } - async fn verify_token(&self, access_token: String) -> Result { - let claims = validate_jwt(&access_token)?; + async fn verify_access_token(&self, access_token: &str) -> Result { + let claims = validate_jwt(access_token)?; let resp = VerifyTokenResponse::new(claims); Ok(resp) } diff --git a/ee/tabby-webserver/src/service/db/mod.rs b/ee/tabby-webserver/src/service/db/mod.rs index 1ac9453..ee56978 100644 --- a/ee/tabby-webserver/src/service/db/mod.rs +++ b/ee/tabby-webserver/src/service/db/mod.rs @@ -27,6 +27,7 @@ lazy_static! { "# ) .down("DROP TABLE registration_token"), + // ==== Above migrations released in 0.6.0 ==== M::up( r#" CREATE TABLE users ( @@ -36,7 +37,10 @@ lazy_static! { is_admin BOOLEAN NOT NULL DEFAULT 0, created_at TIMESTAMP DEFAULT (DATETIME('now')), updated_at TIMESTAMP DEFAULT (DATETIME('now')), - CONSTRAINT `idx_email` UNIQUE (`email`) + auth_token VARCHAR(128) NOT NULL, + + CONSTRAINT `idx_email` UNIQUE (`email`) + CONSTRAINT `idx_auth_token` UNIQUE (`auth_token`) ); "# ) diff --git a/ee/tabby-webserver/src/service/db/users.rs b/ee/tabby-webserver/src/service/db/users.rs index 7403afd..245c7af 100644 --- a/ee/tabby-webserver/src/service/db/users.rs +++ b/ee/tabby-webserver/src/service/db/users.rs @@ -3,6 +3,7 @@ use anyhow::Result; use chrono::{DateTime, Utc}; use rusqlite::{params, OptionalExtension, Row}; +use uuid::Uuid; use super::DbConn; @@ -15,11 +16,14 @@ pub struct User { pub email: String, pub password_encrypted: String, pub is_admin: bool, + + /// To authenticate IDE extensions / plugins to access code completion / chat api endpoints. + pub auth_token: String, } impl User { fn select(clause: &str) -> String { - r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "# + r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at, auth_token FROM users WHERE "# .to_owned() + clause } @@ -32,6 +36,7 @@ impl User { is_admin: row.get(3)?, created_at: row.get(4)?, updated_at: row.get(5)?, + auth_token: row.get(6)?, }) } } @@ -47,9 +52,9 @@ impl DbConn { .conn .call(move |c| { let mut stmt = c.prepare( - r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#, + r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#, )?; - let id = stmt.insert((email, password_encrypted, is_admin))?; + let id = stmt.insert((email, password_encrypted, is_admin, generate_auth_token()))?; Ok(id) }) .await?; @@ -98,6 +103,38 @@ impl DbConn { Ok(users) } + + pub async fn verify_auth_token(&self, token: &str) -> bool { + let token = token.to_owned(); + let id: Result = self + .conn + .call(move |c| { + c.query_row( + r#"SELECT id FROM users WHERE auth_token = ?"#, + params![token], + |row| row.get(0), + ) + }) + .await; + id.is_ok() + } + + pub async fn reset_auth_token(&self, id: i32) -> Result { + self.conn + .call(move |c| { + let mut stmt = c.prepare(r#"UPDATE users SET auth_token = ? WHERE id = ?"#)?; + stmt.execute((Uuid::new_v4().to_string(), id))?; + Ok(()) + }) + .await?; + + Ok(id) + } +} + +fn generate_auth_token() -> String { + let uuid = Uuid::new_v4().to_string().replace('-', ""); + format!("auth_{}", uuid) } #[cfg(test)] @@ -123,4 +160,21 @@ mod tests { assert!(user.is_none()); } + + #[tokio::test] + async fn test_auth_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + let id = create_user(&conn).await; + + let user = conn.get_user(id).await.unwrap().unwrap(); + + assert!(!conn.verify_auth_token("abcd").await); + + assert!(conn.verify_auth_token(&user.auth_token).await); + + conn.reset_auth_token(id).await.unwrap(); + let new_user = conn.get_user(id).await.unwrap().unwrap(); + assert_eq!(user.email, new_user.email); + assert_ne!(user.auth_token, new_user.auth_token); + } } diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index c25795d..5e3f94d 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -8,7 +8,11 @@ use std::{net::SocketAddr, sync::Arc}; use anyhow::Result; use async_trait::async_trait; -use axum::{http::Request, middleware::Next, response::IntoResponse}; +use axum::{ + http::{HeaderValue, Request}, + middleware::Next, + response::IntoResponse, +}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tracing::{info, warn}; @@ -41,6 +45,46 @@ impl ServerContext { code, } } + + async fn authorize_request(&self, request: &Request) -> bool { + let path = request.uri().path(); + if (path.starts_with("/v1/") || path.starts_with("/v1beta/")) + // Authorization is enabled + && self.db_conn.is_admin_initialized().await.unwrap_or(false) + { + let token = { + let authorization = request + .headers() + .get("authorization") + .map(HeaderValue::to_str) + .and_then(Result::ok); + + if let Some(authorization) = authorization { + let split = authorization.split_once(' '); + match split { + // Found proper bearer + Some(("Bearer", contents)) => Some(contents), + _ => None, + } + } else { + None + } + }; + + if let Some(token) = token { + if self.db_conn.verify_access_token(token).await.is_err() + && !self.db_conn.verify_auth_token(token).await + { + return false; + } + } else { + // Admin system is initialized, but there's no valid token. + return false; + } + } + + true + } } #[async_trait] @@ -95,7 +139,13 @@ impl WorkerService for ServerContext { request: Request, next: Next, ) -> axum::response::Response { - let path = request.uri().path(); + if !self.authorize_request(&request).await { + return axum::response::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap() + .into_response(); + } let remote_addr = request .extensions() @@ -103,6 +153,7 @@ impl WorkerService for ServerContext { .map(|ci| ci.0) .expect("Unable to extract remote addr"); + let path = request.uri().path(); let worker = if path.starts_with("/v1/completions") { self.completion.select().await } else if path.starts_with("/v1beta/chat/completions") {