411 lines
12 KiB
Rust
411 lines
12 KiB
Rust
use anyhow::Result;
|
|
use argon2::{
|
|
password_hash,
|
|
password_hash::{rand_core::OsRng, SaltString},
|
|
Argon2, PasswordHasher, PasswordVerifier,
|
|
};
|
|
use async_trait::async_trait;
|
|
use validator::Validate;
|
|
|
|
use super::db::DbConn;
|
|
use crate::schema::auth::{
|
|
generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Claims, Invitation,
|
|
RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError,
|
|
TokenAuthResponse, UserInfo, VerifyTokenResponse,
|
|
};
|
|
|
|
/// Input parameters for register mutation
|
|
/// `validate` attribute is used to validate the input parameters
|
|
/// - `code` argument specifies which parameter causes the failure
|
|
/// - `message` argument provides client friendly error message
|
|
///
|
|
#[derive(Validate)]
|
|
struct RegisterInput {
|
|
#[validate(email(code = "email", message = "Email is invalid"))]
|
|
#[validate(length(
|
|
max = 128,
|
|
code = "email",
|
|
message = "Email must be at most 128 characters"
|
|
))]
|
|
email: String,
|
|
#[validate(length(
|
|
min = 8,
|
|
code = "password1",
|
|
message = "Password must be at least 8 characters"
|
|
))]
|
|
#[validate(length(
|
|
max = 20,
|
|
code = "password1",
|
|
message = "Password must be at most 20 characters"
|
|
))]
|
|
password1: String,
|
|
#[validate(length(
|
|
min = 8,
|
|
code = "password2",
|
|
message = "Password must be at least 8 characters"
|
|
))]
|
|
#[validate(must_match(
|
|
code = "password2",
|
|
message = "Passwords do not match",
|
|
other = "password1"
|
|
))]
|
|
#[validate(length(
|
|
max = 20,
|
|
code = "password2",
|
|
message = "Password must be at most 20 characters"
|
|
))]
|
|
password2: String,
|
|
}
|
|
|
|
impl std::fmt::Debug for RegisterInput {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("RegisterInput")
|
|
.field("email", &self.email)
|
|
.field("password1", &"********")
|
|
.field("password2", &"********")
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
/// Input parameters for token_auth mutation
|
|
/// See `RegisterInput` for `validate` attribute usage
|
|
#[derive(Validate)]
|
|
struct TokenAuthInput {
|
|
#[validate(email(code = "email", message = "Email is invalid"))]
|
|
#[validate(length(
|
|
max = 128,
|
|
code = "email",
|
|
message = "Email must be at most 128 characters"
|
|
))]
|
|
email: String,
|
|
#[validate(length(
|
|
min = 8,
|
|
code = "password",
|
|
message = "Password must be at least 8 characters"
|
|
))]
|
|
#[validate(length(
|
|
max = 20,
|
|
code = "password",
|
|
message = "Password must be at most 20 characters"
|
|
))]
|
|
password: String,
|
|
}
|
|
|
|
impl std::fmt::Debug for TokenAuthInput {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("TokenAuthInput")
|
|
.field("email", &self.email)
|
|
.field("password", &"********")
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl AuthenticationService for DbConn {
|
|
async fn register(
|
|
&self,
|
|
email: String,
|
|
password1: String,
|
|
password2: String,
|
|
invitation_code: Option<String>,
|
|
) -> std::result::Result<RegisterResponse, RegisterError> {
|
|
let input = RegisterInput {
|
|
email,
|
|
password1,
|
|
password2,
|
|
};
|
|
input.validate()?;
|
|
|
|
let is_admin_initialized = self.is_admin_initialized().await?;
|
|
if is_admin_initialized {
|
|
let err = Err(RegisterError::InvalidInvitationCode);
|
|
let Some(invitation_code) = invitation_code else {
|
|
return err;
|
|
};
|
|
|
|
let Some(invitation) = self.get_invitation_by_code(&invitation_code).await? else {
|
|
return err;
|
|
};
|
|
|
|
if invitation.email != input.email {
|
|
return err;
|
|
}
|
|
};
|
|
|
|
// check if email exists
|
|
if self.get_user_by_email(&input.email).await?.is_some() {
|
|
return Err(RegisterError::DuplicateEmail);
|
|
}
|
|
|
|
let Ok(pwd_hash) = password_hash(&input.password1) else {
|
|
return Err(RegisterError::Unknown);
|
|
};
|
|
|
|
let id = self
|
|
.create_user(input.email.clone(), pwd_hash, !is_admin_initialized)
|
|
.await?;
|
|
let user = self.get_user(id).await?.unwrap();
|
|
|
|
let refresh_token = generate_refresh_token();
|
|
self.create_refresh_token(id, &refresh_token).await?;
|
|
|
|
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
|
user.email.clone(),
|
|
user.is_admin,
|
|
))) else {
|
|
return Err(RegisterError::Unknown);
|
|
};
|
|
|
|
let resp = RegisterResponse::new(access_token, refresh_token);
|
|
Ok(resp)
|
|
}
|
|
|
|
async fn token_auth(
|
|
&self,
|
|
email: String,
|
|
password: String,
|
|
) -> std::result::Result<TokenAuthResponse, TokenAuthError> {
|
|
let input = TokenAuthInput { email, password };
|
|
input.validate()?;
|
|
|
|
let Some(user) = self.get_user_by_email(&input.email).await? else {
|
|
return Err(TokenAuthError::UserNotFound);
|
|
};
|
|
|
|
if !password_verify(&input.password, &user.password_encrypted) {
|
|
return Err(TokenAuthError::InvalidPassword);
|
|
}
|
|
|
|
let refresh_token = generate_refresh_token();
|
|
self.create_refresh_token(user.id, &refresh_token).await?;
|
|
|
|
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
|
user.email.clone(),
|
|
user.is_admin,
|
|
))) else {
|
|
return Err(TokenAuthError::Unknown);
|
|
};
|
|
|
|
let resp = TokenAuthResponse::new(access_token, refresh_token);
|
|
Ok(resp)
|
|
}
|
|
|
|
async fn refresh_token(
|
|
&self,
|
|
token: String,
|
|
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError> {
|
|
let Some(refresh_token) = self.get_refresh_token(&token).await? else {
|
|
return Err(RefreshTokenError::InvalidRefreshToken);
|
|
};
|
|
if refresh_token.is_expired() {
|
|
return Err(RefreshTokenError::ExpiredRefreshToken);
|
|
}
|
|
let Some(user) = self.get_user(refresh_token.user_id).await? else {
|
|
return Err(RefreshTokenError::UserNotFound);
|
|
};
|
|
|
|
let new_token = generate_refresh_token();
|
|
self.replace_refresh_token(&token, &new_token).await?;
|
|
|
|
// refresh token update is done, generate new access token based on user info
|
|
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
|
user.email.clone(),
|
|
user.is_admin,
|
|
))) else {
|
|
return Err(RefreshTokenError::Unknown);
|
|
};
|
|
|
|
let resp = RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at);
|
|
|
|
Ok(resp)
|
|
}
|
|
|
|
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
|
let claims = validate_jwt(&access_token)?;
|
|
let resp = VerifyTokenResponse::new(claims);
|
|
Ok(resp)
|
|
}
|
|
|
|
async fn is_admin_initialized(&self) -> Result<bool> {
|
|
let admin = self.list_admin_users().await?;
|
|
Ok(!admin.is_empty())
|
|
}
|
|
|
|
async fn create_invitation(&self, email: String) -> Result<i32> {
|
|
self.create_invitation(email).await
|
|
}
|
|
|
|
async fn list_invitations(&self) -> Result<Vec<Invitation>> {
|
|
self.list_invitations().await
|
|
}
|
|
|
|
async fn delete_invitation(&self, id: i32) -> Result<i32> {
|
|
self.delete_invitation(id).await
|
|
}
|
|
}
|
|
|
|
fn password_hash(raw: &str) -> password_hash::Result<String> {
|
|
let salt = SaltString::generate(&mut OsRng);
|
|
let argon2 = Argon2::default();
|
|
let hash = argon2.hash_password(raw.as_bytes(), &salt)?.to_string();
|
|
|
|
Ok(hash)
|
|
}
|
|
|
|
fn password_verify(raw: &str, hash: &str) -> bool {
|
|
if let Ok(parsed_hash) = argon2::PasswordHash::new(hash) {
|
|
let argon2 = Argon2::default();
|
|
argon2.verify_password(raw.as_bytes(), &parsed_hash).is_ok()
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use assert_matches::assert_matches;
|
|
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_password_hash() {
|
|
let raw = "12345678";
|
|
let hash = password_hash(raw).unwrap();
|
|
|
|
assert_eq!(hash.len(), 97);
|
|
assert!(hash.starts_with("$argon2id$v=19$m=19456,t=2,p=1$"));
|
|
}
|
|
|
|
#[test]
|
|
fn test_password_verify() {
|
|
let raw = "12345678";
|
|
let hash = password_hash(raw).unwrap();
|
|
|
|
assert!(password_verify(raw, &hash));
|
|
assert!(!password_verify(raw, "invalid hash"));
|
|
}
|
|
|
|
static ADMIN_EMAIL: &str = "test@example.com";
|
|
static ADMIN_PASSWORD: &str = "123456789";
|
|
|
|
async fn register_admin_user(conn: &DbConn) -> RegisterResponse {
|
|
conn.register(
|
|
ADMIN_EMAIL.to_owned(),
|
|
ADMIN_PASSWORD.to_owned(),
|
|
ADMIN_PASSWORD.to_owned(),
|
|
None,
|
|
)
|
|
.await
|
|
.unwrap()
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_auth_token() {
|
|
let conn = DbConn::new_in_memory().await.unwrap();
|
|
assert_matches!(
|
|
conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned())
|
|
.await,
|
|
Err(TokenAuthError::UserNotFound)
|
|
);
|
|
|
|
register_admin_user(&conn).await;
|
|
|
|
assert_matches!(
|
|
conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned())
|
|
.await,
|
|
Err(TokenAuthError::InvalidPassword)
|
|
);
|
|
|
|
let resp1 = conn
|
|
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
|
|
.await
|
|
.unwrap();
|
|
let resp2 = conn
|
|
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
|
|
.await
|
|
.unwrap();
|
|
// each auth should generate a new refresh token
|
|
assert_ne!(resp1.refresh_token, resp2.refresh_token);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_invitation_flow() {
|
|
let conn = DbConn::new_in_memory().await.unwrap();
|
|
|
|
assert!(!conn.is_admin_initialized().await.unwrap());
|
|
register_admin_user(&conn).await;
|
|
|
|
let email = "user@user.com";
|
|
let password = "12345678";
|
|
|
|
conn.create_invitation(email.to_owned()).await.unwrap();
|
|
let invitation = &conn.list_invitations().await.unwrap()[0];
|
|
|
|
// Admin initialized, registeration requires a invitation code;
|
|
assert_matches!(
|
|
conn.register(
|
|
email.to_owned(),
|
|
password.to_owned(),
|
|
password.to_owned(),
|
|
None
|
|
)
|
|
.await,
|
|
Err(RegisterError::InvalidInvitationCode)
|
|
);
|
|
|
|
// Invalid invitation code won't work.
|
|
assert_matches!(
|
|
conn.register(
|
|
email.to_owned(),
|
|
password.to_owned(),
|
|
password.to_owned(),
|
|
Some("abc".to_owned())
|
|
)
|
|
.await,
|
|
Err(RegisterError::InvalidInvitationCode)
|
|
);
|
|
|
|
// Register success.
|
|
assert!(conn
|
|
.register(
|
|
email.to_owned(),
|
|
password.to_owned(),
|
|
password.to_owned(),
|
|
Some(invitation.code.clone())
|
|
)
|
|
.await
|
|
.is_ok());
|
|
|
|
// Try register again with same email failed.
|
|
assert_matches!(
|
|
conn.register(
|
|
email.to_owned(),
|
|
password.to_owned(),
|
|
password.to_owned(),
|
|
Some(invitation.code.clone())
|
|
)
|
|
.await,
|
|
Err(RegisterError::DuplicateEmail)
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_refresh_token() {
|
|
let conn = DbConn::new_in_memory().await.unwrap();
|
|
let reg = register_admin_user(&conn).await;
|
|
|
|
let resp1 = conn.refresh_token(reg.refresh_token.clone()).await.unwrap();
|
|
// new access token should be valid
|
|
assert!(validate_jwt(&resp1.access_token).is_ok());
|
|
// refresh token should be renewed
|
|
assert_ne!(reg.refresh_token, resp1.refresh_token);
|
|
|
|
let resp2 = conn
|
|
.refresh_token(resp1.refresh_token.clone())
|
|
.await
|
|
.unwrap();
|
|
// expire time should be no change
|
|
assert_eq!(resp1.refresh_expires_at, resp2.refresh_expires_at);
|
|
}
|
|
}
|