feat(ee): implement user authentication api (#912)

* feat: impl user authentication

* resolve comments

* fix validation code name

* resolve comment
add-signin-page
Eric 2023-12-01 10:22:53 +08:00 committed by GitHub
parent ffd5ef3449
commit 79e704458d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 785 additions and 37 deletions

201
Cargo.lock generated
View File

@ -23,7 +23,7 @@ version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd"
dependencies = [
"getrandom 0.2.9",
"getrandom 0.2.11",
"once_cell",
"version_check",
]
@ -180,6 +180,18 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]]
name = "argon2"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17ba4cac0a46bc1d2912652a751c47f2a9f3a7fe89bcae2275d418f5270402f9"
dependencies = [
"base64ct",
"blake2",
"cpufeatures",
"password-hash",
]
[[package]]
name = "arrayvec"
version = "0.7.4"
@ -572,6 +584,12 @@ version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
[[package]]
name = "base64ct"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "beef"
version = "0.5.2"
@ -608,6 +626,15 @@ dependencies = [
"crunchy",
]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -715,11 +742,12 @@ dependencies = [
[[package]]
name = "cc"
version = "1.0.79"
version = "1.0.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
dependencies = [
"jobserver",
"libc",
]
[[package]]
@ -926,9 +954,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "cpufeatures"
version = "0.2.7"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58"
checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0"
dependencies = [
"libc",
]
@ -1731,9 +1759,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.9"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
dependencies = [
"cfg-if",
"libc",
@ -2085,6 +2113,22 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "idna"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "if_chain"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed"
[[package]]
name = "ignore"
version = "0.4.20"
@ -2230,6 +2274,20 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "jsonwebtoken"
version = "9.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "155c4d7e39ad04c172c5e3a99c434ea3b4a7ba7960b38ecd562b270b097cce09"
dependencies = [
"base64 0.21.2",
"pem",
"ring 0.17.5",
"serde",
"serde_json",
"simple_asn1",
]
[[package]]
name = "juniper"
version = "0.15.11"
@ -2832,6 +2890,27 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.17"
@ -3132,6 +3211,17 @@ dependencies = [
"windows-targets 0.48.0",
]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core 0.6.4",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.12"
@ -3511,7 +3601,7 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom 0.2.9",
"getrandom 0.2.11",
]
[[package]]
@ -3590,21 +3680,21 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom 0.2.9",
"getrandom 0.2.11",
"redox_syscall 0.2.16",
"thiserror",
]
[[package]]
name = "regex"
version = "1.10.0"
version = "1.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87"
checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.1",
"regex-syntax 0.8.1",
"regex-automata 0.4.3",
"regex-syntax 0.8.2",
]
[[package]]
@ -3618,13 +3708,13 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.1"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b"
checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.8.1",
"regex-syntax 0.8.2",
]
[[package]]
@ -3635,9 +3725,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.1"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33"
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
[[package]]
name = "requirements"
@ -3709,12 +3799,12 @@ dependencies = [
[[package]]
name = "ring"
version = "0.17.3"
version = "0.17.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e"
checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b"
dependencies = [
"cc",
"getrandom 0.2.9",
"getrandom 0.2.11",
"libc",
"spin 0.9.8",
"untrusted 0.9.0",
@ -3938,7 +4028,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c"
dependencies = [
"log",
"ring 0.17.3",
"ring 0.17.5",
"rustls-webpki",
"sct",
]
@ -3958,7 +4048,7 @@ version = "0.101.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring 0.17.3",
"ring 0.17.5",
"untrusted 0.9.0",
]
@ -4033,7 +4123,7 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring 0.17.3",
"ring 0.17.5",
"untrusted 0.9.0",
]
@ -4292,6 +4382,18 @@ dependencies = [
"libc",
]
[[package]]
name = "simple_asn1"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085"
dependencies = [
"num-bigint",
"num-traits",
"thiserror",
"time",
]
[[package]]
name = "sketches-ddsketch"
version = "0.2.1"
@ -4694,12 +4796,14 @@ name = "tabby-webserver"
version = "0.7.0-dev"
dependencies = [
"anyhow",
"argon2",
"async-trait",
"axum",
"bincode",
"chrono",
"futures",
"hyper",
"jsonwebtoken",
"juniper",
"juniper-axum",
"lazy_static",
@ -4720,6 +4824,7 @@ dependencies = [
"tracing",
"unicase",
"uuid 1.4.1",
"validator",
]
[[package]]
@ -5735,7 +5840,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
dependencies = [
"form_urlencoded",
"idna",
"idna 0.3.0",
"percent-encoding",
]
@ -5818,7 +5923,7 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
dependencies = [
"getrandom 0.2.9",
"getrandom 0.2.11",
]
[[package]]
@ -5827,7 +5932,7 @@ version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d"
dependencies = [
"getrandom 0.2.9",
"getrandom 0.2.11",
"rand 0.8.5",
"serde",
"uuid-macro-internal",
@ -5844,6 +5949,48 @@ dependencies = [
"syn 2.0.28",
]
[[package]]
name = "validator"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd"
dependencies = [
"idna 0.4.0",
"lazy_static",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
"validator_derive",
]
[[package]]
name = "validator_derive"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af"
dependencies = [
"if_chain",
"lazy_static",
"proc-macro-error",
"proc-macro2",
"quote",
"regex",
"syn 1.0.109",
"validator_types",
]
[[package]]
name = "validator_types"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3"
dependencies = [
"proc-macro2",
"syn 1.0.109",
]
[[package]]
name = "valuable"
version = "0.1.0"

View File

@ -7,12 +7,14 @@ homepage.workspace = true
[dependencies]
anyhow.workspace = true
argon2 = "0.5.1"
async-trait.workspace = true
axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
chrono = "0.4"
futures.workspace = true
hyper = { workspace = true, features=["client"]}
jsonwebtoken = "9.1.0"
juniper.workspace = true
juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0"
@ -33,6 +35,7 @@ tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tracing.workspace = true
unicase = "2.7.0"
validator = { version = "0.16.1", features = ["derive"] }
[dependencies.uuid]
version = "1.3.3"

View File

@ -1,10 +1,40 @@
type RegisterResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
}
type AuthError {
message: String!
code: String!
}
enum WorkerKind {
COMPLETION
CHAT
}
type Mutation {
resetRegistrationToken: String!
resetRegistrationToken(token: String): String!
register(email: String!, password1: String!, password2: String!): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
}
type UserInfo {
email: String!
isAdmin: Boolean!
}
type VerifyTokenResponse {
errors: [AuthError!]!
claims: Claims!
}
type Claims {
exp: Float!
iat: Float!
user: UserInfo!
}
type Query {
@ -23,6 +53,12 @@ type Worker {
cudaDevices: [String!]!
}
type TokenAuthResponse {
accessToken: String!
refreshToken: String!
errors: [AuthError!]!
}
schema {
query: Query
mutation: Mutation

View File

@ -2,14 +2,15 @@ use std::{path::PathBuf, sync::Arc};
use anyhow::Result;
use lazy_static::lazy_static;
use rusqlite::params;
use rusqlite::{params, OptionalExtension};
use rusqlite_migration::{AsyncMigrations, M};
use tabby_common::path::tabby_root;
use tokio_rusqlite::Connection;
lazy_static! {
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![M::up(
r#"
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![
M::up(
r#"
CREATE TABLE IF NOT EXISTS registration_token (
id INTEGER PRIMARY KEY AUTOINCREMENT,
token VARCHAR(255) NOT NULL,
@ -18,7 +19,32 @@ lazy_static! {
CONSTRAINT `idx_token` UNIQUE (`token`)
);
"#
),]);
),
M::up(
r#"
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email VARCHAR(150) NOT NULL COLLATE NOCASE,
password_encrypted VARCHAR(128) NOT NULL,
is_admin BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
updated_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_email` UNIQUE (`email`)
);
"#
),
]);
}
#[allow(unused)]
pub struct User {
created_at: String,
updated_at: String,
pub id: u32,
pub email: String,
pub password_encrypted: String,
pub is_admin: bool,
}
async fn db_path() -> Result<PathBuf> {
@ -27,6 +53,7 @@ async fn db_path() -> Result<PathBuf> {
Ok(db_dir.join("db.sqlite"))
}
#[derive(Clone)]
pub struct DbConn {
conn: Arc<Connection>,
}
@ -55,7 +82,10 @@ impl DbConn {
conn: Arc::new(conn),
})
}
}
/// db read/write operations for `registration_token` table
impl DbConn {
/// Query token from database.
/// Since token is global unique for each tabby server, by right there's only one row in the table.
pub async fn read_registration_token(&self) -> Result<String> {
@ -96,6 +126,56 @@ impl DbConn {
}
}
/// db read/write operations for `users` table
impl DbConn {
pub async fn create_user(
&self,
email: String,
password_encrypted: String,
is_admin: bool,
) -> Result<()> {
let res = self
.conn
.call(move |c| {
c.execute(
r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#,
params![email, password_encrypted, is_admin],
)
})
.await?;
if res != 1 {
return Err(anyhow::anyhow!("failed to create user"));
}
Ok(())
}
pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
let email = email.to_string();
let user = self
.conn
.call(move |c| {
c.query_row(
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE email = ?"#,
params![email],
|row| {
Ok(User {
id: row.get(0)?,
email: row.get(1)?,
password_encrypted: row.get(2)?,
is_admin: row.get(3)?,
created_at: row.get(4)?,
updated_at: row.get(5)?,
})
},
).optional()
})
.await?;
Ok(user)
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -127,4 +207,29 @@ mod tests {
assert_eq!(new_token.len(), 36);
assert_ne!(old_token, new_token);
}
#[tokio::test]
async fn test_create_user() {
let conn = new_in_memory().await.unwrap();
let email = "test@example.com";
let passwd = "123456";
let is_admin = true;
conn.create_user(email.to_string(), passwd.to_string(), is_admin)
.await
.unwrap();
let user = conn.get_user_by_email(email).await.unwrap().unwrap();
assert_eq!(user.id, 1);
}
#[tokio::test]
async fn test_get_user_by_email() {
let conn = new_in_memory().await.unwrap();
let email = "hello@example.com";
let user = conn.get_user_by_email(email).await.unwrap();
assert!(user.is_none());
}
}

View File

@ -1,6 +1,17 @@
use juniper::{graphql_object, EmptySubscription, FieldResult, RootNode};
pub mod auth;
use crate::{api::Worker, server::ServerContext};
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
};
use crate::{
api::Worker,
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
server::{
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
ServerContext,
},
};
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for ServerContext {}
@ -25,9 +36,47 @@ pub struct Mutation;
#[graphql_object(context = ServerContext)]
impl Mutation {
async fn reset_registration_token(ctx: &ServerContext) -> FieldResult<String> {
let token = ctx.reset_registration_token().await?;
Ok(token)
async fn reset_registration_token(
ctx: &ServerContext,
token: Option<String>,
) -> FieldResult<String> {
if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) {
if claims.user_info().is_admin() {
let reg_token = ctx.reset_registration_token().await?;
return Ok(reg_token);
}
}
Err(FieldError::new(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}
async fn register(
ctx: &ServerContext,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
ctx.auth().register(input).await
}
async fn token_auth(
ctx: &ServerContext,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.auth().token_auth(input).await
}
async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.auth().verify_token(token).await
}
}

View File

@ -0,0 +1,129 @@
use std::fmt::Debug;
use jsonwebtoken as jwt;
use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value};
use serde::{Deserialize, Serialize};
use validator::ValidationError;
use crate::server::auth::JWT_DEFAULT_EXP;
#[derive(Debug)]
pub struct ValidationErrors {
pub errors: Vec<ValidationError>,
}
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
fn into_field_error(self) -> FieldError<S> {
let errors = self
.errors
.into_iter()
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));
FieldError::new("Invalid input parameters", ext.into())
}
}
#[derive(Debug, GraphQLObject)]
pub struct RegisterResponse {
access_token: String,
refresh_token: String,
}
impl RegisterResponse {
pub fn new(access_token: String, refresh_token: String) -> Self {
Self {
access_token,
refresh_token,
}
}
}
#[derive(Debug, GraphQLObject)]
pub struct TokenAuthResponse {
access_token: String,
refresh_token: String,
}
impl TokenAuthResponse {
pub fn new(access_token: String, refresh_token: String) -> Self {
Self {
access_token,
refresh_token,
}
}
}
#[derive(Debug, Default, GraphQLObject)]
pub struct RefreshTokenResponse {
access_token: String,
refresh_token: String,
refresh_expires_in: i32,
}
#[derive(Debug, GraphQLObject)]
pub struct VerifyTokenResponse {
claims: Claims,
}
impl VerifyTokenResponse {
pub fn new(claims: Claims) -> Self {
Self { claims }
}
}
#[derive(Debug, Default, PartialEq, Serialize, Deserialize, GraphQLObject)]
pub struct UserInfo {
email: String,
is_admin: bool,
}
impl UserInfo {
pub fn new(email: String, is_admin: bool) -> Self {
Self { email, is_admin }
}
pub fn is_admin(&self) -> bool {
self.is_admin
}
pub fn email(&self) -> &str {
&self.email
}
}
#[derive(Debug, Default, Serialize, Deserialize, GraphQLObject)]
pub struct Claims {
// Required. Expiration time (as UTC timestamp)
exp: f64,
// Optional. Issued at (as UTC timestamp)
iat: f64,
// Customized. user info
user: UserInfo,
}
impl Claims {
pub fn new(user: UserInfo) -> Self {
let now = jwt::get_current_timestamp();
Self {
iat: now as f64,
exp: (now + *JWT_DEFAULT_EXP) as f64,
user,
}
}
pub fn user_info(self) -> UserInfo {
self.user
}
}

View File

@ -1,3 +1,4 @@
pub mod auth;
mod proxy;
mod worker;
@ -12,6 +13,7 @@ use tracing::{info, warn};
use crate::{
api::{RegisterWorkerError, Worker, WorkerKind},
db::DbConn,
server::auth::AuthenticationService,
};
pub struct ServerContext {
@ -40,6 +42,10 @@ impl ServerContext {
}
}
pub fn auth(&self) -> impl AuthenticationService {
self.db_conn.clone()
}
/// Query current token from the database.
pub async fn read_registration_token(&self) -> Result<String> {
self.db_conn.read_registration_token().await

View File

@ -0,0 +1,273 @@
use std::env;
use argon2::{
password_hash,
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher, PasswordVerifier,
};
use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldResult, IntoFieldError};
use lazy_static::lazy_static;
use validator::Validate;
use crate::{
db::DbConn,
schema::auth::{
Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
ValidationErrors, VerifyTokenResponse,
},
};
lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
}
/// Input parameters for register mutation
/// `validate` attribute is used to validate the input parameters
/// - `code` argument specifies which parameter causes the failure
/// - `message` argument provides client friendly error message
///
#[derive(Validate)]
pub struct RegisterInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
#[validate(length(
min = 8,
code = "password1",
message = "Password must be at least 8 characters"
))]
#[validate(length(
max = 20,
code = "password1",
message = "Password must be at most 20 characters"
))]
#[validate(must_match(
code = "password1",
message = "Passwords do not match",
other = "password2"
))]
pub password1: String,
#[validate(length(
min = 8,
code = "password2",
message = "Password must be at least 8 characters"
))]
#[validate(length(
max = 20,
code = "password2",
message = "Password must be at most 20 characters"
))]
pub password2: String,
}
impl std::fmt::Debug for RegisterInput {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RegisterInput")
.field("email", &self.email)
.field("password1", &"********")
.field("password2", &"********")
.finish()
}
}
/// Input parameters for token_auth mutation
/// See `RegisterInput` for `validate` attribute usage
#[derive(Validate)]
pub struct TokenAuthInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
#[validate(length(
min = 8,
code = "password",
message = "Password must be at least 8 characters"
))]
#[validate(length(
max = 20,
code = "password",
message = "Password must be at most 20 characters"
))]
pub password: String,
}
impl std::fmt::Debug for TokenAuthInput {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokenAuthInput")
.field("email", &self.email)
.field("password", &"********")
.finish()
}
}
#[async_trait]
pub trait AuthenticationService {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}
#[async_trait]
impl AuthenticationService for DbConn {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse> {
input.validate().map_err(|err| {
let errors = err
.field_errors()
.into_iter()
.flat_map(|(_, errs)| errs)
.cloned()
.collect();
ValidationErrors { errors }.into_field_error()
})?;
// check if email exists
if let Some(_) = self.get_user_by_email(&input.email).await? {
return Err("Email already exists".into());
}
let pwd_hash = password_hash(&input.password1)?;
self.create_user(input.email.clone(), pwd_hash, false)
.await?;
let user = self.get_user_by_email(&input.email).await?.unwrap();
let access_token = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(),
user.is_admin,
)))?;
let resp = RegisterResponse::new(access_token, "".to_string());
Ok(resp)
}
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse> {
input.validate().map_err(|err| {
let errors = err
.field_errors()
.into_iter()
.flat_map(|(_, errs)| errs)
.cloned()
.collect();
ValidationErrors { errors }.into_field_error()
})?;
let user = self.get_user_by_email(&input.email).await?;
let user = match user {
Some(user) => user,
None => return Err("User not found".into()),
};
if !password_verify(&input.password, &user.password_encrypted) {
return Err("Password incorrect".into());
}
let access_token = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(),
user.is_admin,
)))?;
let resp = TokenAuthResponse::new(access_token, "".to_string());
Ok(resp)
}
async fn refresh_token(&self, _refresh_token: String) -> FieldResult<RefreshTokenResponse> {
Ok(RefreshTokenResponse::default())
}
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse> {
let claims = validate_jwt(&access_token)?;
let resp = VerifyTokenResponse::new(claims);
Ok(resp)
}
}
fn password_hash(raw: &str) -> password_hash::Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let hash = argon2.hash_password(raw.as_bytes(), &salt)?.to_string();
Ok(hash)
}
fn password_verify(raw: &str, hash: &str) -> bool {
if let Ok(parsed_hash) = argon2::PasswordHash::new(hash) {
let argon2 = Argon2::default();
argon2.verify_password(raw.as_bytes(), &parsed_hash).is_ok()
} else {
false
}
}
fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
let header = jwt::Header::default();
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
Ok(token)
}
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}
fn jwt_token_secret() -> String {
env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_password_hash() {
let raw = "12345678";
let hash = password_hash(raw).unwrap();
assert_eq!(hash.len(), 97);
assert!(hash.starts_with("$argon2id$v=19$m=19456,t=2,p=1$"));
}
#[test]
fn test_password_verify() {
let raw = "12345678";
let hash = password_hash(raw).unwrap();
assert!(password_verify(raw, &hash));
assert!(!password_verify(raw, "invalid hash"));
}
#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
assert!(!token.is_empty())
}
#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(claims.user_info(), UserInfo::new("test".to_string(), false));
}
}