diff --git a/Cargo.lock b/Cargo.lock index 9cc4989..f11d315 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -214,6 +214,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "assert_matches" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" + [[package]] name = "async-channel" version = "1.9.0" @@ -4797,6 +4803,7 @@ version = "0.7.0-dev" dependencies = [ "anyhow", "argon2", + "assert_matches", "async-trait", "axum", "bincode", diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 9a37571..288e915 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -46,4 +46,5 @@ features = [ ] [dev-dependencies] +assert_matches = "1.5.0" tokio = { workspace = true, features = ["macros"] } diff --git a/ee/tabby-webserver/src/repositories.rs b/ee/tabby-webserver/src/repositories/mod.rs similarity index 100% rename from ee/tabby-webserver/src/repositories.rs rename to ee/tabby-webserver/src/repositories/mod.rs diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index f3ea36c..a8767ba 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -3,9 +3,13 @@ use std::fmt::Debug; use anyhow::Result; use async_trait::async_trait; use jsonwebtoken as jwt; -use juniper::{FieldResult, GraphQLObject}; +use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; +use thiserror::Error; +use validator::ValidationErrors; + +use super::from_validation_errors; lazy_static! { static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret( @@ -48,6 +52,33 @@ impl RegisterResponse { } } +#[derive(Error, Debug)] +pub enum RegisterError { + #[error("Invalid input parameters")] + InvalidInput(#[from] ValidationErrors), + + #[error("Invitation code is not valid")] + InvalidInvitationCode, + + #[error("Email is already registered")] + DuplicateEmail, + + #[error(transparent)] + Other(#[from] anyhow::Error), + + #[error("Unknown error")] + Unknown, +} + +impl IntoFieldError for RegisterError { + fn into_field_error(self) -> FieldError { + match self { + Self::InvalidInput(errors) => from_validation_errors(errors), + _ => self.into(), + } + } +} + #[derive(Debug, GraphQLObject)] pub struct TokenAuthResponse { access_token: String, @@ -63,6 +94,39 @@ impl TokenAuthResponse { } } +#[derive(Error, Debug)] +pub enum CoreError { + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +#[derive(Error, Debug)] +pub enum TokenAuthError { + #[error("Invalid input parameters")] + InvalidInput(#[from] ValidationErrors), + + #[error("User not found")] + UserNotFound, + + #[error("Password is not valid")] + InvalidPassword, + + #[error(transparent)] + Other(#[from] anyhow::Error), + + #[error("Unknown error")] + Unknown, +} + +impl IntoFieldError for TokenAuthError { + fn into_field_error(self) -> FieldError { + match self { + Self::InvalidInput(errors) => from_validation_errors(errors), + _ => self.into(), + } + } +} + #[derive(Debug, Default, GraphQLObject)] pub struct RefreshTokenResponse { access_token: String, @@ -143,11 +207,17 @@ pub trait AuthenticationService: Send + Sync { password1: String, password2: String, invitation_code: Option, - ) -> FieldResult; - async fn token_auth(&self, email: String, password: String) -> FieldResult; - async fn refresh_token(&self, refresh_token: String) -> FieldResult; - async fn verify_token(&self, access_token: String) -> FieldResult; - async fn is_admin_initialized(&self) -> FieldResult; + ) -> std::result::Result; + + async fn token_auth( + &self, + email: String, + password: String, + ) -> std::result::Result; + + async fn refresh_token(&self, refresh_token: String) -> Result; + async fn verify_token(&self, access_token: String) -> Result; + async fn is_admin_initialized(&self) -> Result; async fn create_invitation(&self, email: String) -> Result; async fn list_invitations(&self) -> Result>; diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs index aed2b0c..90c60bf 100644 --- a/ee/tabby-webserver/src/schema/mod.rs +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -10,10 +10,10 @@ use juniper::{ }; use juniper_axum::FromAuth; use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; -use validator::ValidationError; +use validator::ValidationErrors; use self::{ - auth::{validate_jwt, Invitation}, + auth::{validate_jwt, Invitation, RegisterError, TokenAuthError}, worker::WorkerService, }; use crate::schema::{ @@ -40,6 +40,26 @@ impl FromAuth> for Context { } } +type Result = std::result::Result; + +#[derive(thiserror::Error, Debug)] +pub enum CoreError { + #[error("{0}")] + Unauthorized(&'static str), + + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl IntoFieldError for CoreError { + fn into_field_error(self) -> FieldError { + match self { + Self::Unauthorized(msg) => FieldError::new(msg, graphql_value!("Unauthorized")), + _ => self.into(), + } + } +} + // To make our context usable by Juniper, we have to implement a marker trait. impl juniper::Context for Context {} @@ -52,22 +72,24 @@ impl Query { ctx.locator.worker().list_workers().await } - async fn registration_token(ctx: &Context) -> FieldResult { + async fn registration_token(ctx: &Context) -> Result { let token = ctx.locator.worker().read_registration_token().await?; Ok(token) } - async fn is_admin_initialized(ctx: &Context) -> FieldResult { - ctx.locator.auth().is_admin_initialized().await + async fn is_admin_initialized(ctx: &Context) -> Result { + Ok(ctx.locator.auth().is_admin_initialized().await?) } - async fn invitations(ctx: &Context) -> FieldResult> { + async fn invitations(ctx: &Context) -> Result> { if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() { return Ok(ctx.locator.auth().list_invitations().await?); } } - Err(unauthorized("Only admin is able to query invitations")) + Err(CoreError::Unauthorized( + "Only admin is able to query invitations", + )) } } @@ -76,14 +98,14 @@ pub struct Mutation; #[graphql_object(context = Context)] impl Mutation { - async fn reset_registration_token(ctx: &Context) -> FieldResult { + async fn reset_registration_token(ctx: &Context) -> Result { if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() { let reg_token = ctx.locator.worker().reset_registration_token().await?; return Ok(reg_token); } } - Err(unauthorized( + Err(CoreError::Unauthorized( "Only admin is able to reset registration token", )) } @@ -94,7 +116,7 @@ impl Mutation { password1: String, password2: String, invitation_code: Option, - ) -> FieldResult { + ) -> Result { ctx.locator .auth() .register(email, password1, password2, invitation_code) @@ -105,59 +127,58 @@ impl Mutation { ctx: &Context, email: String, password: String, - ) -> FieldResult { + ) -> Result { ctx.locator.auth().token_auth(email, password).await } - async fn verify_token(ctx: &Context, token: String) -> FieldResult { - ctx.locator.auth().verify_token(token).await + async fn verify_token(ctx: &Context, token: String) -> Result { + Ok(ctx.locator.auth().verify_token(token).await?) } - async fn create_invitation(ctx: &Context, email: String) -> FieldResult { + async fn create_invitation(ctx: &Context, email: String) -> Result { if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() { return Ok(ctx.locator.auth().create_invitation(email).await?); } } - Err(unauthorized("Only admin is able to create invitation")) + Err(CoreError::Unauthorized( + "Only admin is able to create invitation", + )) } - async fn delete_invitation(ctx: &Context, id: i32) -> FieldResult { + async fn delete_invitation(ctx: &Context, id: i32) -> Result { if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() { return Ok(ctx.locator.auth().delete_invitation(id).await?); } } - Err(unauthorized("Only admin is able to delete invitation")) + Err(CoreError::Unauthorized( + "Only admin is able to delete invitation", + )) } } -#[derive(Debug)] -pub struct ValidationErrors { - pub errors: Vec, -} +fn from_validation_errors(error: ValidationErrors) -> FieldError { + let errors = error + .field_errors() + .into_iter() + .flat_map(|(_, errs)| errs) + .cloned() + .map(|err| { + let mut obj = Object::with_capacity(2); + obj.add_field("path", Value::scalar(err.code.to_string())); + obj.add_field( + "message", + Value::scalar(err.message.unwrap_or_default().to_string()), + ); + obj.into() + }) + .collect::>(); + let mut ext = Object::with_capacity(2); + ext.add_field("code", Value::scalar("validation-error".to_string())); + ext.add_field("errors", Value::list(errors)); -impl IntoFieldError for ValidationErrors { - fn into_field_error(self) -> FieldError { - let errors = self - .errors - .into_iter() - .map(|err| { - let mut obj = Object::with_capacity(2); - obj.add_field("path", Value::scalar(err.code.to_string())); - obj.add_field( - "message", - Value::scalar(err.message.unwrap_or_default().to_string()), - ); - obj.into() - }) - .collect::>(); - let mut ext = Object::with_capacity(2); - ext.add_field("code", Value::scalar("validation-error".to_string())); - ext.add_field("errors", Value::list(errors)); - - FieldError::new("Invalid input parameters", ext.into()) - } + FieldError::new("Invalid input parameters", ext.into()) } pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; @@ -165,7 +186,3 @@ pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription> pub fn create_schema() -> Schema { Schema::new(Query, Mutation, EmptySubscription::new()) } - -fn unauthorized(msg: &str) -> FieldError { - FieldError::new(msg, graphql_value!("Unauthorized")) -} diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index e6d3e53..bd14f3d 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -5,16 +5,13 @@ use argon2::{ Argon2, PasswordHasher, PasswordVerifier, }; use async_trait::async_trait; -use juniper::{FieldResult, IntoFieldError}; use validator::Validate; use super::db::DbConn; -use crate::schema::{ - auth::{ - generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, - RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse, - }, - ValidationErrors, +use crate::schema::auth::{ + generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, RefreshTokenResponse, + RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, UserInfo, + VerifyTokenResponse, }; /// Input parameters for register mutation @@ -111,26 +108,17 @@ impl AuthenticationService for DbConn { password1: String, password2: String, invitation_code: Option, - ) -> FieldResult { + ) -> std::result::Result { let input = RegisterInput { email, password1, password2, }; - input.validate().map_err(|err| { - let errors = err - .field_errors() - .into_iter() - .flat_map(|(_, errs)| errs) - .cloned() - .collect(); - - ValidationErrors { errors }.into_field_error() - })?; + input.validate()?; let is_admin_initialized = self.is_admin_initialized().await?; if is_admin_initialized { - let err = Err("Invitation code is not valid".into()); + let err = Err(RegisterError::InvalidInvitationCode); let Some(invitation_code) = invitation_code else { return err; }; @@ -146,68 +134,67 @@ impl AuthenticationService for DbConn { // check if email exists if self.get_user_by_email(&input.email).await?.is_some() { - return Err("Email already exists".into()); + return Err(RegisterError::DuplicateEmail); } - let pwd_hash = password_hash(&input.password1)?; + let Ok(pwd_hash) = password_hash(&input.password1) else { + return Err(RegisterError::Unknown); + }; - let id = self.create_user(input.email.clone(), pwd_hash, !is_admin_initialized) + let id = self + .create_user(input.email.clone(), pwd_hash, !is_admin_initialized) .await?; let user = self.get_user(id).await?.unwrap(); - let access_token = generate_jwt(Claims::new(UserInfo::new( + let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( user.email.clone(), user.is_admin, - )))?; + ))) else { + return Err(RegisterError::Unknown); + }; let resp = RegisterResponse::new(access_token, "".to_string()); Ok(resp) } - async fn token_auth(&self, email: String, password: String) -> FieldResult { + async fn token_auth( + &self, + email: String, + password: String, + ) -> std::result::Result { let input = TokenAuthInput { email, password }; - input.validate().map_err(|err| { - let errors = err - .field_errors() - .into_iter() - .flat_map(|(_, errs)| errs) - .cloned() - .collect(); + input.validate()?; - ValidationErrors { errors }.into_field_error() - })?; - - let user = self.get_user_by_email(&input.email).await?; - - let user = match user { - Some(user) => user, - None => return Err("User not found".into()), + let Some(user) = self.get_user_by_email(&input.email).await? else { + return Err(TokenAuthError::UserNotFound); }; if !password_verify(&input.password, &user.password_encrypted) { - return Err("Password incorrect".into()); + return Err(TokenAuthError::InvalidPassword); } - let access_token = generate_jwt(Claims::new(UserInfo::new( + let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( user.email.clone(), user.is_admin, - )))?; + ))) else { + return Err(TokenAuthError::Unknown); + }; let resp = TokenAuthResponse::new(access_token, "".to_string()); Ok(resp) } - async fn refresh_token(&self, _refresh_token: String) -> FieldResult { + async fn refresh_token(&self, _refresh_token: String) -> Result { Ok(RefreshTokenResponse::default()) } - async fn verify_token(&self, access_token: String) -> FieldResult { + async fn verify_token(&self, access_token: String) -> Result { let claims = validate_jwt(&access_token)?; let resp = VerifyTokenResponse::new(claims); Ok(resp) } - async fn is_admin_initialized(&self) -> FieldResult { + async fn is_admin_initialized(&self) -> Result { let admin = self.list_admin_users().await?; Ok(!admin.is_empty()) } @@ -244,6 +231,8 @@ fn password_verify(raw: &str, hash: &str) -> bool { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use super::*; #[test] @@ -263,4 +252,103 @@ mod tests { assert!(password_verify(raw, &hash)); assert!(!password_verify(raw, "invalid hash")); } + + static ADMIN_EMAIL: &str = "test@example.com"; + static ADMIN_PASSWORD: &str = "123456789"; + + async fn create_admin_user(conn: &DbConn) -> i32 { + conn.register( + ADMIN_EMAIL.to_owned(), + ADMIN_PASSWORD.to_owned(), + ADMIN_PASSWORD.to_owned(), + None, + ) + .await + .unwrap(); + 1 + } + + #[tokio::test] + async fn test_auth_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + assert_matches!( + conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned()) + .await, + Err(TokenAuthError::UserNotFound) + ); + + create_admin_user(&conn).await; + + assert_matches!( + conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned()) + .await, + Err(TokenAuthError::InvalidPassword) + ); + + assert!(conn + .token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned()) + .await + .is_ok()); + } + + #[tokio::test] + async fn test_invitation_flow() { + let conn = DbConn::new_in_memory().await.unwrap(); + + assert!(!conn.is_admin_initialized().await.unwrap()); + create_admin_user(&conn).await; + + let email = "user@user.com"; + let password = "12345678"; + + conn.create_invitation(email.to_owned()).await.unwrap(); + let invitation = &conn.list_invitations().await.unwrap()[0]; + + // Admin initialized, registeration requires a invitation code; + assert_matches!( + conn.register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + None + ) + .await, + Err(RegisterError::InvalidInvitationCode) + ); + + // Invalid invitation code won't work. + assert_matches!( + conn.register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + Some("abc".to_owned()) + ) + .await, + Err(RegisterError::InvalidInvitationCode) + ); + + // Register success. + assert!(conn + .register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + Some(invitation.code.clone()) + ) + .await + .is_ok()); + + // Try register again with same email failed. + assert_matches!( + conn.register( + email.to_owned(), + password.to_owned(), + password.to_owned(), + Some(invitation.code.clone()) + ) + .await, + Err(RegisterError::DuplicateEmail) + ); + } } diff --git a/ee/tabby-webserver/src/service/db.rs b/ee/tabby-webserver/src/service/db.rs index 89747a9..80a1ec5 100644 --- a/ee/tabby-webserver/src/service/db.rs +++ b/ee/tabby-webserver/src/service/db.rs @@ -22,7 +22,8 @@ lazy_static! { CONSTRAINT `idx_token` UNIQUE (`token`) ); "# - ), + ) + .down("DROP TABLE registeration_token"), M::up( r#" CREATE TABLE users ( @@ -35,7 +36,8 @@ lazy_static! { CONSTRAINT `idx_email` UNIQUE (`email`) ); "# - ), + ) + .down("DROP TABLE users"), M::up( r#" CREATE TABLE invitations ( @@ -47,7 +49,8 @@ lazy_static! { CONSTRAINT `idx_code` UNIQUE (`code`) ); "# - ), + ) + .down("DROP TABLE invitations"), ]); } @@ -93,6 +96,12 @@ pub struct DbConn { } impl DbConn { + #[cfg(test)] + pub async fn new_in_memory() -> Result { + let conn = Connection::open_in_memory().await?; + DbConn::init_db(conn).await + } + pub async fn new() -> Result { let db_path = db_path().await?; let conn = Connection::open(db_path).await?; @@ -302,21 +311,14 @@ impl DbConn { #[cfg(test)] mod tests { - use juniper::FieldResult; use super::*; use crate::schema::auth::AuthenticationService; - async fn new_in_memory() -> Result { - let conn = Connection::open_in_memory().await?; - DbConn::init_db(conn).await - } - - async fn create_admin_user(conn: &DbConn) -> i32 { - let email = "test@example.com"; - let passwd = "123456"; - let is_admin = true; - conn.create_user(email.to_string(), passwd.to_string(), is_admin) + async fn create_user(conn: &DbConn) -> i32 { + let email: &str = "test@example.com"; + let password: &str = "123456789"; + conn.create_user(email.to_string(), password.to_string(), true) .await .unwrap() } @@ -328,14 +330,14 @@ mod tests { #[tokio::test] async fn test_token() { - let conn = new_in_memory().await.unwrap(); + let conn = DbConn::new_in_memory().await.unwrap(); let token = conn.read_registration_token().await.unwrap(); assert_eq!(token.len(), 36); } #[tokio::test] async fn test_update_token() { - let conn = new_in_memory().await.unwrap(); + let conn = DbConn::new_in_memory().await.unwrap(); let old_token = conn.read_registration_token().await.unwrap(); conn.reset_registration_token().await.unwrap(); @@ -346,16 +348,16 @@ mod tests { #[tokio::test] async fn test_create_user() { - let conn = new_in_memory().await.unwrap(); + let conn = DbConn::new_in_memory().await.unwrap(); - let id = create_admin_user(&conn).await; + let id = create_user(&conn).await; let user = conn.get_user(id).await.unwrap().unwrap(); assert_eq!(user.id, 1); } #[tokio::test] async fn test_get_user_by_email() { - let conn = new_in_memory().await.unwrap(); + let conn = DbConn::new_in_memory().await.unwrap(); let email = "hello@example.com"; let user = conn.get_user_by_email(email).await.unwrap(); @@ -365,16 +367,16 @@ mod tests { #[tokio::test] async fn test_is_admin_initialized() { - let conn = new_in_memory().await.unwrap(); + let conn = DbConn::new_in_memory().await.unwrap(); assert!(!conn.is_admin_initialized().await.unwrap()); - create_admin_user(&conn).await; + create_user(&conn).await; assert!(conn.is_admin_initialized().await.unwrap()); } #[tokio::test] async fn test_invitations() { - let conn = new_in_memory().await.unwrap(); + let conn = DbConn::new_in_memory().await.unwrap(); let email = "hello@example.com".to_owned(); conn.create_invitation(email).await.unwrap(); @@ -396,62 +398,4 @@ mod tests { let invitations = conn.list_invitations().await.unwrap(); assert!(invitations.is_empty()); } - - #[tokio::test] - async fn test_invitation_flow() { - let conn = new_in_memory().await.unwrap(); - - assert!(!conn.is_admin_initialized().await.unwrap()); - create_admin_user(&conn).await; - - let email = "user@user.com"; - let password = "12345678"; - - conn.create_invitation(email.to_owned()).await.unwrap(); - let invitation = &conn.list_invitations().await.unwrap()[0]; - - // Admin initialized, registeration requires a invitation code; - assert!( - conn.register( - email.to_owned(), - password.to_owned(), - password.to_owned(), - None - ) - .await.is_err() - ); - - // Invalid invitation code won't work. - assert!(conn - .register( - email.to_owned(), - password.to_owned(), - password.to_owned(), - Some("abc".to_owned()) - ) - .await - .is_err()); - - // Register success. - assert!(conn - .register( - email.to_owned(), - password.to_owned(), - password.to_owned(), - Some(invitation.code.clone()) - ) - .await - .is_ok()); - - // Try register again with same email failed. - assert!(conn - .register( - email.to_owned(), - password.to_owned(), - password.to_owned(), - Some(invitation.code.clone()) - ) - .await - .is_err()); - } }