refactor(webserver): extract RegisterError and TokenAuthError and add unit test (#936)
* refactor: extract RegisterError and TokenAuthError * update * update test * fix token auth test * cleanup * fix * add down operations * cleanup error type * [autofix.ci] apply automated fixes * update * cleanup * Process InvalidationErrors directly * update error handling --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>add-signin-page
parent
258322ede4
commit
f3a31082ef
|
|
@ -214,6 +214,12 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "assert_matches"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-channel"
|
name = "async-channel"
|
||||||
version = "1.9.0"
|
version = "1.9.0"
|
||||||
|
|
@ -4797,6 +4803,7 @@ version = "0.7.0-dev"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"argon2",
|
"argon2",
|
||||||
|
"assert_matches",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"bincode",
|
"bincode",
|
||||||
|
|
|
||||||
|
|
@ -46,4 +46,5 @@ features = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
assert_matches = "1.5.0"
|
||||||
tokio = { workspace = true, features = ["macros"] }
|
tokio = { workspace = true, features = ["macros"] }
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,13 @@ use std::fmt::Debug;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use jsonwebtoken as jwt;
|
use jsonwebtoken as jwt;
|
||||||
use juniper::{FieldResult, GraphQLObject};
|
use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use thiserror::Error;
|
||||||
|
use validator::ValidationErrors;
|
||||||
|
|
||||||
|
use super::from_validation_errors;
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
||||||
|
|
@ -48,6 +52,33 @@ impl RegisterResponse {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum RegisterError {
|
||||||
|
#[error("Invalid input parameters")]
|
||||||
|
InvalidInput(#[from] ValidationErrors),
|
||||||
|
|
||||||
|
#[error("Invitation code is not valid")]
|
||||||
|
InvalidInvitationCode,
|
||||||
|
|
||||||
|
#[error("Email is already registered")]
|
||||||
|
DuplicateEmail,
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
|
||||||
|
#[error("Unknown error")]
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: ScalarValue> IntoFieldError<S> for RegisterError {
|
||||||
|
fn into_field_error(self) -> FieldError<S> {
|
||||||
|
match self {
|
||||||
|
Self::InvalidInput(errors) => from_validation_errors(errors),
|
||||||
|
_ => self.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, GraphQLObject)]
|
#[derive(Debug, GraphQLObject)]
|
||||||
pub struct TokenAuthResponse {
|
pub struct TokenAuthResponse {
|
||||||
access_token: String,
|
access_token: String,
|
||||||
|
|
@ -63,6 +94,39 @@ impl TokenAuthResponse {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum CoreError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum TokenAuthError {
|
||||||
|
#[error("Invalid input parameters")]
|
||||||
|
InvalidInput(#[from] ValidationErrors),
|
||||||
|
|
||||||
|
#[error("User not found")]
|
||||||
|
UserNotFound,
|
||||||
|
|
||||||
|
#[error("Password is not valid")]
|
||||||
|
InvalidPassword,
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
|
||||||
|
#[error("Unknown error")]
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: ScalarValue> IntoFieldError<S> for TokenAuthError {
|
||||||
|
fn into_field_error(self) -> FieldError<S> {
|
||||||
|
match self {
|
||||||
|
Self::InvalidInput(errors) => from_validation_errors(errors),
|
||||||
|
_ => self.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, GraphQLObject)]
|
#[derive(Debug, Default, GraphQLObject)]
|
||||||
pub struct RefreshTokenResponse {
|
pub struct RefreshTokenResponse {
|
||||||
access_token: String,
|
access_token: String,
|
||||||
|
|
@ -143,11 +207,17 @@ pub trait AuthenticationService: Send + Sync {
|
||||||
password1: String,
|
password1: String,
|
||||||
password2: String,
|
password2: String,
|
||||||
invitation_code: Option<String>,
|
invitation_code: Option<String>,
|
||||||
) -> FieldResult<RegisterResponse>;
|
) -> std::result::Result<RegisterResponse, RegisterError>;
|
||||||
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
|
|
||||||
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
|
async fn token_auth(
|
||||||
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
|
&self,
|
||||||
async fn is_admin_initialized(&self) -> FieldResult<bool>;
|
email: String,
|
||||||
|
password: String,
|
||||||
|
) -> std::result::Result<TokenAuthResponse, TokenAuthError>;
|
||||||
|
|
||||||
|
async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>;
|
||||||
|
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
|
||||||
|
async fn is_admin_initialized(&self) -> Result<bool>;
|
||||||
|
|
||||||
async fn create_invitation(&self, email: String) -> Result<i32>;
|
async fn create_invitation(&self, email: String) -> Result<i32>;
|
||||||
async fn list_invitations(&self) -> Result<Vec<Invitation>>;
|
async fn list_invitations(&self) -> Result<Vec<Invitation>>;
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ use juniper::{
|
||||||
};
|
};
|
||||||
use juniper_axum::FromAuth;
|
use juniper_axum::FromAuth;
|
||||||
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||||
use validator::ValidationError;
|
use validator::ValidationErrors;
|
||||||
|
|
||||||
use self::{
|
use self::{
|
||||||
auth::{validate_jwt, Invitation},
|
auth::{validate_jwt, Invitation, RegisterError, TokenAuthError},
|
||||||
worker::WorkerService,
|
worker::WorkerService,
|
||||||
};
|
};
|
||||||
use crate::schema::{
|
use crate::schema::{
|
||||||
|
|
@ -40,6 +40,26 @@ impl FromAuth<Arc<dyn ServiceLocator>> for Context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Result<T, E = CoreError> = std::result::Result<T, E>;
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum CoreError {
|
||||||
|
#[error("{0}")]
|
||||||
|
Unauthorized(&'static str),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: ScalarValue> IntoFieldError<S> for CoreError {
|
||||||
|
fn into_field_error(self) -> FieldError<S> {
|
||||||
|
match self {
|
||||||
|
Self::Unauthorized(msg) => FieldError::new(msg, graphql_value!("Unauthorized")),
|
||||||
|
_ => self.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// To make our context usable by Juniper, we have to implement a marker trait.
|
// To make our context usable by Juniper, we have to implement a marker trait.
|
||||||
impl juniper::Context for Context {}
|
impl juniper::Context for Context {}
|
||||||
|
|
||||||
|
|
@ -52,22 +72,24 @@ impl Query {
|
||||||
ctx.locator.worker().list_workers().await
|
ctx.locator.worker().list_workers().await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn registration_token(ctx: &Context) -> FieldResult<String> {
|
async fn registration_token(ctx: &Context) -> Result<String> {
|
||||||
let token = ctx.locator.worker().read_registration_token().await?;
|
let token = ctx.locator.worker().read_registration_token().await?;
|
||||||
Ok(token)
|
Ok(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_admin_initialized(ctx: &Context) -> FieldResult<bool> {
|
async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
|
||||||
ctx.locator.auth().is_admin_initialized().await
|
Ok(ctx.locator.auth().is_admin_initialized().await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn invitations(ctx: &Context) -> FieldResult<Vec<Invitation>> {
|
async fn invitations(ctx: &Context) -> Result<Vec<Invitation>> {
|
||||||
if let Some(claims) = &ctx.claims {
|
if let Some(claims) = &ctx.claims {
|
||||||
if claims.user_info().is_admin() {
|
if claims.user_info().is_admin() {
|
||||||
return Ok(ctx.locator.auth().list_invitations().await?);
|
return Ok(ctx.locator.auth().list_invitations().await?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(unauthorized("Only admin is able to query invitations"))
|
Err(CoreError::Unauthorized(
|
||||||
|
"Only admin is able to query invitations",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,14 +98,14 @@ pub struct Mutation;
|
||||||
|
|
||||||
#[graphql_object(context = Context)]
|
#[graphql_object(context = Context)]
|
||||||
impl Mutation {
|
impl Mutation {
|
||||||
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
|
async fn reset_registration_token(ctx: &Context) -> Result<String> {
|
||||||
if let Some(claims) = &ctx.claims {
|
if let Some(claims) = &ctx.claims {
|
||||||
if claims.user_info().is_admin() {
|
if claims.user_info().is_admin() {
|
||||||
let reg_token = ctx.locator.worker().reset_registration_token().await?;
|
let reg_token = ctx.locator.worker().reset_registration_token().await?;
|
||||||
return Ok(reg_token);
|
return Ok(reg_token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(unauthorized(
|
Err(CoreError::Unauthorized(
|
||||||
"Only admin is able to reset registration token",
|
"Only admin is able to reset registration token",
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
@ -94,7 +116,7 @@ impl Mutation {
|
||||||
password1: String,
|
password1: String,
|
||||||
password2: String,
|
password2: String,
|
||||||
invitation_code: Option<String>,
|
invitation_code: Option<String>,
|
||||||
) -> FieldResult<RegisterResponse> {
|
) -> Result<RegisterResponse, RegisterError> {
|
||||||
ctx.locator
|
ctx.locator
|
||||||
.auth()
|
.auth()
|
||||||
.register(email, password1, password2, invitation_code)
|
.register(email, password1, password2, invitation_code)
|
||||||
|
|
@ -105,43 +127,43 @@ impl Mutation {
|
||||||
ctx: &Context,
|
ctx: &Context,
|
||||||
email: String,
|
email: String,
|
||||||
password: String,
|
password: String,
|
||||||
) -> FieldResult<TokenAuthResponse> {
|
) -> Result<TokenAuthResponse, TokenAuthError> {
|
||||||
ctx.locator.auth().token_auth(email, password).await
|
ctx.locator.auth().token_auth(email, password).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
|
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
|
||||||
ctx.locator.auth().verify_token(token).await
|
Ok(ctx.locator.auth().verify_token(token).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_invitation(ctx: &Context, email: String) -> FieldResult<i32> {
|
async fn create_invitation(ctx: &Context, email: String) -> Result<i32> {
|
||||||
if let Some(claims) = &ctx.claims {
|
if let Some(claims) = &ctx.claims {
|
||||||
if claims.user_info().is_admin() {
|
if claims.user_info().is_admin() {
|
||||||
return Ok(ctx.locator.auth().create_invitation(email).await?);
|
return Ok(ctx.locator.auth().create_invitation(email).await?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(unauthorized("Only admin is able to create invitation"))
|
Err(CoreError::Unauthorized(
|
||||||
|
"Only admin is able to create invitation",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn delete_invitation(ctx: &Context, id: i32) -> FieldResult<i32> {
|
async fn delete_invitation(ctx: &Context, id: i32) -> Result<i32> {
|
||||||
if let Some(claims) = &ctx.claims {
|
if let Some(claims) = &ctx.claims {
|
||||||
if claims.user_info().is_admin() {
|
if claims.user_info().is_admin() {
|
||||||
return Ok(ctx.locator.auth().delete_invitation(id).await?);
|
return Ok(ctx.locator.auth().delete_invitation(id).await?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(unauthorized("Only admin is able to delete invitation"))
|
Err(CoreError::Unauthorized(
|
||||||
|
"Only admin is able to delete invitation",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError<S> {
|
||||||
pub struct ValidationErrors {
|
let errors = error
|
||||||
pub errors: Vec<ValidationError>,
|
.field_errors()
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
|
|
||||||
fn into_field_error(self) -> FieldError<S> {
|
|
||||||
let errors = self
|
|
||||||
.errors
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
.flat_map(|(_, errs)| errs)
|
||||||
|
.cloned()
|
||||||
.map(|err| {
|
.map(|err| {
|
||||||
let mut obj = Object::with_capacity(2);
|
let mut obj = Object::with_capacity(2);
|
||||||
obj.add_field("path", Value::scalar(err.code.to_string()));
|
obj.add_field("path", Value::scalar(err.code.to_string()));
|
||||||
|
|
@ -158,14 +180,9 @@ impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
|
||||||
|
|
||||||
FieldError::new("Invalid input parameters", ext.into())
|
FieldError::new("Invalid input parameters", ext.into())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
|
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
|
||||||
|
|
||||||
pub fn create_schema() -> Schema {
|
pub fn create_schema() -> Schema {
|
||||||
Schema::new(Query, Mutation, EmptySubscription::new())
|
Schema::new(Query, Mutation, EmptySubscription::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unauthorized(msg: &str) -> FieldError {
|
|
||||||
FieldError::new(msg, graphql_value!("Unauthorized"))
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -5,16 +5,13 @@ use argon2::{
|
||||||
Argon2, PasswordHasher, PasswordVerifier,
|
Argon2, PasswordHasher, PasswordVerifier,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use juniper::{FieldResult, IntoFieldError};
|
|
||||||
use validator::Validate;
|
use validator::Validate;
|
||||||
|
|
||||||
use super::db::DbConn;
|
use super::db::DbConn;
|
||||||
use crate::schema::{
|
use crate::schema::auth::{
|
||||||
auth::{
|
generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, RefreshTokenResponse,
|
||||||
generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation,
|
RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, UserInfo,
|
||||||
RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
|
VerifyTokenResponse,
|
||||||
},
|
|
||||||
ValidationErrors,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Input parameters for register mutation
|
/// Input parameters for register mutation
|
||||||
|
|
@ -111,26 +108,17 @@ impl AuthenticationService for DbConn {
|
||||||
password1: String,
|
password1: String,
|
||||||
password2: String,
|
password2: String,
|
||||||
invitation_code: Option<String>,
|
invitation_code: Option<String>,
|
||||||
) -> FieldResult<RegisterResponse> {
|
) -> std::result::Result<RegisterResponse, RegisterError> {
|
||||||
let input = RegisterInput {
|
let input = RegisterInput {
|
||||||
email,
|
email,
|
||||||
password1,
|
password1,
|
||||||
password2,
|
password2,
|
||||||
};
|
};
|
||||||
input.validate().map_err(|err| {
|
input.validate()?;
|
||||||
let errors = err
|
|
||||||
.field_errors()
|
|
||||||
.into_iter()
|
|
||||||
.flat_map(|(_, errs)| errs)
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
ValidationErrors { errors }.into_field_error()
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let is_admin_initialized = self.is_admin_initialized().await?;
|
let is_admin_initialized = self.is_admin_initialized().await?;
|
||||||
if is_admin_initialized {
|
if is_admin_initialized {
|
||||||
let err = Err("Invitation code is not valid".into());
|
let err = Err(RegisterError::InvalidInvitationCode);
|
||||||
let Some(invitation_code) = invitation_code else {
|
let Some(invitation_code) = invitation_code else {
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
|
|
@ -146,68 +134,67 @@ impl AuthenticationService for DbConn {
|
||||||
|
|
||||||
// check if email exists
|
// check if email exists
|
||||||
if self.get_user_by_email(&input.email).await?.is_some() {
|
if self.get_user_by_email(&input.email).await?.is_some() {
|
||||||
return Err("Email already exists".into());
|
return Err(RegisterError::DuplicateEmail);
|
||||||
}
|
}
|
||||||
|
|
||||||
let pwd_hash = password_hash(&input.password1)?;
|
let Ok(pwd_hash) = password_hash(&input.password1) else {
|
||||||
|
return Err(RegisterError::Unknown);
|
||||||
|
};
|
||||||
|
|
||||||
let id = self.create_user(input.email.clone(), pwd_hash, !is_admin_initialized)
|
let id = self
|
||||||
|
.create_user(input.email.clone(), pwd_hash, !is_admin_initialized)
|
||||||
.await?;
|
.await?;
|
||||||
let user = self.get_user(id).await?.unwrap();
|
let user = self.get_user(id).await?.unwrap();
|
||||||
|
|
||||||
let access_token = generate_jwt(Claims::new(UserInfo::new(
|
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
||||||
user.email.clone(),
|
user.email.clone(),
|
||||||
user.is_admin,
|
user.is_admin,
|
||||||
)))?;
|
))) else {
|
||||||
|
return Err(RegisterError::Unknown);
|
||||||
|
};
|
||||||
|
|
||||||
let resp = RegisterResponse::new(access_token, "".to_string());
|
let resp = RegisterResponse::new(access_token, "".to_string());
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse> {
|
async fn token_auth(
|
||||||
|
&self,
|
||||||
|
email: String,
|
||||||
|
password: String,
|
||||||
|
) -> std::result::Result<TokenAuthResponse, TokenAuthError> {
|
||||||
let input = TokenAuthInput { email, password };
|
let input = TokenAuthInput { email, password };
|
||||||
input.validate().map_err(|err| {
|
input.validate()?;
|
||||||
let errors = err
|
|
||||||
.field_errors()
|
|
||||||
.into_iter()
|
|
||||||
.flat_map(|(_, errs)| errs)
|
|
||||||
.cloned()
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
ValidationErrors { errors }.into_field_error()
|
let Some(user) = self.get_user_by_email(&input.email).await? else {
|
||||||
})?;
|
return Err(TokenAuthError::UserNotFound);
|
||||||
|
|
||||||
let user = self.get_user_by_email(&input.email).await?;
|
|
||||||
|
|
||||||
let user = match user {
|
|
||||||
Some(user) => user,
|
|
||||||
None => return Err("User not found".into()),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if !password_verify(&input.password, &user.password_encrypted) {
|
if !password_verify(&input.password, &user.password_encrypted) {
|
||||||
return Err("Password incorrect".into());
|
return Err(TokenAuthError::InvalidPassword);
|
||||||
}
|
}
|
||||||
|
|
||||||
let access_token = generate_jwt(Claims::new(UserInfo::new(
|
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
||||||
user.email.clone(),
|
user.email.clone(),
|
||||||
user.is_admin,
|
user.is_admin,
|
||||||
)))?;
|
))) else {
|
||||||
|
return Err(TokenAuthError::Unknown);
|
||||||
|
};
|
||||||
|
|
||||||
let resp = TokenAuthResponse::new(access_token, "".to_string());
|
let resp = TokenAuthResponse::new(access_token, "".to_string());
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn refresh_token(&self, _refresh_token: String) -> FieldResult<RefreshTokenResponse> {
|
async fn refresh_token(&self, _refresh_token: String) -> Result<RefreshTokenResponse> {
|
||||||
Ok(RefreshTokenResponse::default())
|
Ok(RefreshTokenResponse::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse> {
|
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
||||||
let claims = validate_jwt(&access_token)?;
|
let claims = validate_jwt(&access_token)?;
|
||||||
let resp = VerifyTokenResponse::new(claims);
|
let resp = VerifyTokenResponse::new(claims);
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_admin_initialized(&self) -> FieldResult<bool> {
|
async fn is_admin_initialized(&self) -> Result<bool> {
|
||||||
let admin = self.list_admin_users().await?;
|
let admin = self.list_admin_users().await?;
|
||||||
Ok(!admin.is_empty())
|
Ok(!admin.is_empty())
|
||||||
}
|
}
|
||||||
|
|
@ -244,6 +231,8 @@ fn password_verify(raw: &str, hash: &str) -> bool {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use assert_matches::assert_matches;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -263,4 +252,103 @@ mod tests {
|
||||||
assert!(password_verify(raw, &hash));
|
assert!(password_verify(raw, &hash));
|
||||||
assert!(!password_verify(raw, "invalid hash"));
|
assert!(!password_verify(raw, "invalid hash"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ADMIN_EMAIL: &str = "test@example.com";
|
||||||
|
static ADMIN_PASSWORD: &str = "123456789";
|
||||||
|
|
||||||
|
async fn create_admin_user(conn: &DbConn) -> i32 {
|
||||||
|
conn.register(
|
||||||
|
ADMIN_EMAIL.to_owned(),
|
||||||
|
ADMIN_PASSWORD.to_owned(),
|
||||||
|
ADMIN_PASSWORD.to_owned(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_auth_token() {
|
||||||
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
assert_matches!(
|
||||||
|
conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned())
|
||||||
|
.await,
|
||||||
|
Err(TokenAuthError::UserNotFound)
|
||||||
|
);
|
||||||
|
|
||||||
|
create_admin_user(&conn).await;
|
||||||
|
|
||||||
|
assert_matches!(
|
||||||
|
conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned())
|
||||||
|
.await,
|
||||||
|
Err(TokenAuthError::InvalidPassword)
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(conn
|
||||||
|
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_invitation_flow() {
|
||||||
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
|
||||||
|
assert!(!conn.is_admin_initialized().await.unwrap());
|
||||||
|
create_admin_user(&conn).await;
|
||||||
|
|
||||||
|
let email = "user@user.com";
|
||||||
|
let password = "12345678";
|
||||||
|
|
||||||
|
conn.create_invitation(email.to_owned()).await.unwrap();
|
||||||
|
let invitation = &conn.list_invitations().await.unwrap()[0];
|
||||||
|
|
||||||
|
// Admin initialized, registeration requires a invitation code;
|
||||||
|
assert_matches!(
|
||||||
|
conn.register(
|
||||||
|
email.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
.await,
|
||||||
|
Err(RegisterError::InvalidInvitationCode)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Invalid invitation code won't work.
|
||||||
|
assert_matches!(
|
||||||
|
conn.register(
|
||||||
|
email.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
Some("abc".to_owned())
|
||||||
|
)
|
||||||
|
.await,
|
||||||
|
Err(RegisterError::InvalidInvitationCode)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Register success.
|
||||||
|
assert!(conn
|
||||||
|
.register(
|
||||||
|
email.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
Some(invitation.code.clone())
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
|
||||||
|
// Try register again with same email failed.
|
||||||
|
assert_matches!(
|
||||||
|
conn.register(
|
||||||
|
email.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
password.to_owned(),
|
||||||
|
Some(invitation.code.clone())
|
||||||
|
)
|
||||||
|
.await,
|
||||||
|
Err(RegisterError::DuplicateEmail)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,8 @@ lazy_static! {
|
||||||
CONSTRAINT `idx_token` UNIQUE (`token`)
|
CONSTRAINT `idx_token` UNIQUE (`token`)
|
||||||
);
|
);
|
||||||
"#
|
"#
|
||||||
),
|
)
|
||||||
|
.down("DROP TABLE registeration_token"),
|
||||||
M::up(
|
M::up(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE users (
|
CREATE TABLE users (
|
||||||
|
|
@ -35,7 +36,8 @@ lazy_static! {
|
||||||
CONSTRAINT `idx_email` UNIQUE (`email`)
|
CONSTRAINT `idx_email` UNIQUE (`email`)
|
||||||
);
|
);
|
||||||
"#
|
"#
|
||||||
),
|
)
|
||||||
|
.down("DROP TABLE users"),
|
||||||
M::up(
|
M::up(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE invitations (
|
CREATE TABLE invitations (
|
||||||
|
|
@ -47,7 +49,8 @@ lazy_static! {
|
||||||
CONSTRAINT `idx_code` UNIQUE (`code`)
|
CONSTRAINT `idx_code` UNIQUE (`code`)
|
||||||
);
|
);
|
||||||
"#
|
"#
|
||||||
),
|
)
|
||||||
|
.down("DROP TABLE invitations"),
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -93,6 +96,12 @@ pub struct DbConn {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DbConn {
|
impl DbConn {
|
||||||
|
#[cfg(test)]
|
||||||
|
pub async fn new_in_memory() -> Result<Self> {
|
||||||
|
let conn = Connection::open_in_memory().await?;
|
||||||
|
DbConn::init_db(conn).await
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn new() -> Result<Self> {
|
pub async fn new() -> Result<Self> {
|
||||||
let db_path = db_path().await?;
|
let db_path = db_path().await?;
|
||||||
let conn = Connection::open(db_path).await?;
|
let conn = Connection::open(db_path).await?;
|
||||||
|
|
@ -302,21 +311,14 @@ impl DbConn {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use juniper::FieldResult;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::schema::auth::AuthenticationService;
|
use crate::schema::auth::AuthenticationService;
|
||||||
|
|
||||||
async fn new_in_memory() -> Result<DbConn> {
|
async fn create_user(conn: &DbConn) -> i32 {
|
||||||
let conn = Connection::open_in_memory().await?;
|
let email: &str = "test@example.com";
|
||||||
DbConn::init_db(conn).await
|
let password: &str = "123456789";
|
||||||
}
|
conn.create_user(email.to_string(), password.to_string(), true)
|
||||||
|
|
||||||
async fn create_admin_user(conn: &DbConn) -> i32 {
|
|
||||||
let email = "test@example.com";
|
|
||||||
let passwd = "123456";
|
|
||||||
let is_admin = true;
|
|
||||||
conn.create_user(email.to_string(), passwd.to_string(), is_admin)
|
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
@ -328,14 +330,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_token() {
|
async fn test_token() {
|
||||||
let conn = new_in_memory().await.unwrap();
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
let token = conn.read_registration_token().await.unwrap();
|
let token = conn.read_registration_token().await.unwrap();
|
||||||
assert_eq!(token.len(), 36);
|
assert_eq!(token.len(), 36);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_update_token() {
|
async fn test_update_token() {
|
||||||
let conn = new_in_memory().await.unwrap();
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
|
||||||
let old_token = conn.read_registration_token().await.unwrap();
|
let old_token = conn.read_registration_token().await.unwrap();
|
||||||
conn.reset_registration_token().await.unwrap();
|
conn.reset_registration_token().await.unwrap();
|
||||||
|
|
@ -346,16 +348,16 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_create_user() {
|
async fn test_create_user() {
|
||||||
let conn = new_in_memory().await.unwrap();
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
|
||||||
let id = create_admin_user(&conn).await;
|
let id = create_user(&conn).await;
|
||||||
let user = conn.get_user(id).await.unwrap().unwrap();
|
let user = conn.get_user(id).await.unwrap().unwrap();
|
||||||
assert_eq!(user.id, 1);
|
assert_eq!(user.id, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_get_user_by_email() {
|
async fn test_get_user_by_email() {
|
||||||
let conn = new_in_memory().await.unwrap();
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
|
||||||
let email = "hello@example.com";
|
let email = "hello@example.com";
|
||||||
let user = conn.get_user_by_email(email).await.unwrap();
|
let user = conn.get_user_by_email(email).await.unwrap();
|
||||||
|
|
@ -365,16 +367,16 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_is_admin_initialized() {
|
async fn test_is_admin_initialized() {
|
||||||
let conn = new_in_memory().await.unwrap();
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
|
||||||
assert!(!conn.is_admin_initialized().await.unwrap());
|
assert!(!conn.is_admin_initialized().await.unwrap());
|
||||||
create_admin_user(&conn).await;
|
create_user(&conn).await;
|
||||||
assert!(conn.is_admin_initialized().await.unwrap());
|
assert!(conn.is_admin_initialized().await.unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_invitations() {
|
async fn test_invitations() {
|
||||||
let conn = new_in_memory().await.unwrap();
|
let conn = DbConn::new_in_memory().await.unwrap();
|
||||||
|
|
||||||
let email = "hello@example.com".to_owned();
|
let email = "hello@example.com".to_owned();
|
||||||
conn.create_invitation(email).await.unwrap();
|
conn.create_invitation(email).await.unwrap();
|
||||||
|
|
@ -396,62 +398,4 @@ mod tests {
|
||||||
let invitations = conn.list_invitations().await.unwrap();
|
let invitations = conn.list_invitations().await.unwrap();
|
||||||
assert!(invitations.is_empty());
|
assert!(invitations.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_invitation_flow() {
|
|
||||||
let conn = new_in_memory().await.unwrap();
|
|
||||||
|
|
||||||
assert!(!conn.is_admin_initialized().await.unwrap());
|
|
||||||
create_admin_user(&conn).await;
|
|
||||||
|
|
||||||
let email = "user@user.com";
|
|
||||||
let password = "12345678";
|
|
||||||
|
|
||||||
conn.create_invitation(email.to_owned()).await.unwrap();
|
|
||||||
let invitation = &conn.list_invitations().await.unwrap()[0];
|
|
||||||
|
|
||||||
// Admin initialized, registeration requires a invitation code;
|
|
||||||
assert!(
|
|
||||||
conn.register(
|
|
||||||
email.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
None
|
|
||||||
)
|
|
||||||
.await.is_err()
|
|
||||||
);
|
|
||||||
|
|
||||||
// Invalid invitation code won't work.
|
|
||||||
assert!(conn
|
|
||||||
.register(
|
|
||||||
email.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
Some("abc".to_owned())
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.is_err());
|
|
||||||
|
|
||||||
// Register success.
|
|
||||||
assert!(conn
|
|
||||||
.register(
|
|
||||||
email.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
Some(invitation.code.clone())
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.is_ok());
|
|
||||||
|
|
||||||
// Try register again with same email failed.
|
|
||||||
assert!(conn
|
|
||||||
.register(
|
|
||||||
email.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
password.to_owned(),
|
|
||||||
Some(invitation.code.clone())
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.is_err());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue