From 79e704458dbce17b95fb6df2f0f68edf4f207f38 Mon Sep 17 00:00:00 2001 From: Eric Date: Fri, 1 Dec 2023 10:22:53 +0800 Subject: [PATCH] feat(ee): implement user authentication api (#912) * feat: impl user authentication * resolve comments * fix validation code name * resolve comment --- Cargo.lock | 201 +++++++++++++--- ee/tabby-webserver/Cargo.toml | 3 + ee/tabby-webserver/graphql/schema.graphql | 38 ++- ee/tabby-webserver/src/db.rs | 113 ++++++++- ee/tabby-webserver/src/schema.rs | 59 ++++- ee/tabby-webserver/src/schema/auth.rs | 129 ++++++++++ ee/tabby-webserver/src/server.rs | 6 + ee/tabby-webserver/src/server/auth.rs | 273 ++++++++++++++++++++++ 8 files changed, 785 insertions(+), 37 deletions(-) create mode 100644 ee/tabby-webserver/src/schema/auth.rs create mode 100644 ee/tabby-webserver/src/server/auth.rs diff --git a/Cargo.lock b/Cargo.lock index 4ae0434..9cc4989 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,7 +23,7 @@ version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", "once_cell", "version_check", ] @@ -180,6 +180,18 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +[[package]] +name = "argon2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17ba4cac0a46bc1d2912652a751c47f2a9f3a7fe89bcae2275d418f5270402f9" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "arrayvec" version = "0.7.4" @@ -572,6 +584,12 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "beef" version = "0.5.2" @@ -608,6 +626,15 @@ dependencies = [ "crunchy", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -715,11 +742,12 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "jobserver", + "libc", ] [[package]] @@ -926,9 +954,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" dependencies = [ "libc", ] @@ -1731,9 +1759,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "libc", @@ -2085,6 +2113,22 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "idna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "if_chain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed" + [[package]] name = "ignore" version = "0.4.20" @@ -2230,6 +2274,20 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "9.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "155c4d7e39ad04c172c5e3a99c434ea3b4a7ba7960b38ecd562b270b097cce09" +dependencies = [ + "base64 0.21.2", + "pem", + "ring 0.17.5", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "juniper" version = "0.15.11" @@ -2832,6 +2890,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.17" @@ -3132,6 +3211,17 @@ dependencies = [ "windows-targets 0.48.0", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "paste" version = "1.0.12" @@ -3511,7 +3601,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", ] [[package]] @@ -3590,21 +3680,21 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", "redox_syscall 0.2.16", "thiserror", ] [[package]] name = "regex" -version = "1.10.0" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.1", - "regex-syntax 0.8.1", + "regex-automata 0.4.3", + "regex-syntax 0.8.2", ] [[package]] @@ -3618,13 +3708,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.1" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.1", + "regex-syntax 0.8.2", ] [[package]] @@ -3635,9 +3725,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "requirements" @@ -3709,12 +3799,12 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.3" +version = "0.17.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e" +checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" dependencies = [ "cc", - "getrandom 0.2.9", + "getrandom 0.2.11", "libc", "spin 0.9.8", "untrusted 0.9.0", @@ -3938,7 +4028,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "446e14c5cda4f3f30fe71863c34ec70f5ac79d6087097ad0bb433e1be5edf04c" dependencies = [ "log", - "ring 0.17.3", + "ring 0.17.5", "rustls-webpki", "sct", ] @@ -3958,7 +4048,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.3", + "ring 0.17.5", "untrusted 0.9.0", ] @@ -4033,7 +4123,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.3", + "ring 0.17.5", "untrusted 0.9.0", ] @@ -4292,6 +4382,18 @@ dependencies = [ "libc", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "sketches-ddsketch" version = "0.2.1" @@ -4694,12 +4796,14 @@ name = "tabby-webserver" version = "0.7.0-dev" dependencies = [ "anyhow", + "argon2", "async-trait", "axum", "bincode", "chrono", "futures", "hyper", + "jsonwebtoken", "juniper", "juniper-axum", "lazy_static", @@ -4720,6 +4824,7 @@ dependencies = [ "tracing", "unicase", "uuid 1.4.1", + "validator", ] [[package]] @@ -5735,7 +5840,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" dependencies = [ "form_urlencoded", - "idna", + "idna 0.3.0", "percent-encoding", ] @@ -5818,7 +5923,7 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", ] [[package]] @@ -5827,7 +5932,7 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ - "getrandom 0.2.9", + "getrandom 0.2.11", "rand 0.8.5", "serde", "uuid-macro-internal", @@ -5844,6 +5949,48 @@ dependencies = [ "syn 2.0.28", ] +[[package]] +name = "validator" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd" +dependencies = [ + "idna 0.4.0", + "lazy_static", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", + "validator_derive", +] + +[[package]] +name = "validator_derive" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc44ca3088bb3ba384d9aecf40c6a23a676ce23e09bdaca2073d99c207f864af" +dependencies = [ + "if_chain", + "lazy_static", + "proc-macro-error", + "proc-macro2", + "quote", + "regex", + "syn 1.0.109", + "validator_types", +] + +[[package]] +name = "validator_types" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "111abfe30072511849c5910134e8baf8dc05de4c0e5903d681cbd5c9c4d611e3" +dependencies = [ + "proc-macro2", + "syn 1.0.109", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 71a05e9..9a37571 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -7,12 +7,14 @@ homepage.workspace = true [dependencies] anyhow.workspace = true +argon2 = "0.5.1" async-trait.workspace = true axum = { workspace = true, features = ["ws"] } bincode = "1.3.3" chrono = "0.4" futures.workspace = true hyper = { workspace = true, features=["client"]} +jsonwebtoken = "9.1.0" juniper.workspace = true juniper-axum = { path = "../../crates/juniper-axum" } lazy_static = "1.4.0" @@ -33,6 +35,7 @@ tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.4.0", features = ["fs", "trace"] } tracing.workspace = true unicase = "2.7.0" +validator = { version = "0.16.1", features = ["derive"] } [dependencies.uuid] version = "1.3.3" diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 6b67428..3a95548 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -1,10 +1,40 @@ +type RegisterResponse { + accessToken: String! + refreshToken: String! + errors: [AuthError!]! +} + +type AuthError { + message: String! + code: String! +} + enum WorkerKind { COMPLETION CHAT } type Mutation { - resetRegistrationToken: String! + resetRegistrationToken(token: String): String! + register(email: String!, password1: String!, password2: String!): RegisterResponse! + tokenAuth(email: String!, password: String!): TokenAuthResponse! + verifyToken(token: String!): VerifyTokenResponse! +} + +type UserInfo { + email: String! + isAdmin: Boolean! +} + +type VerifyTokenResponse { + errors: [AuthError!]! + claims: Claims! +} + +type Claims { + exp: Float! + iat: Float! + user: UserInfo! } type Query { @@ -23,6 +53,12 @@ type Worker { cudaDevices: [String!]! } +type TokenAuthResponse { + accessToken: String! + refreshToken: String! + errors: [AuthError!]! +} + schema { query: Query mutation: Mutation diff --git a/ee/tabby-webserver/src/db.rs b/ee/tabby-webserver/src/db.rs index 57d820c..d3e1045 100644 --- a/ee/tabby-webserver/src/db.rs +++ b/ee/tabby-webserver/src/db.rs @@ -2,14 +2,15 @@ use std::{path::PathBuf, sync::Arc}; use anyhow::Result; use lazy_static::lazy_static; -use rusqlite::params; +use rusqlite::{params, OptionalExtension}; use rusqlite_migration::{AsyncMigrations, M}; use tabby_common::path::tabby_root; use tokio_rusqlite::Connection; lazy_static! { - static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![M::up( - r#" + static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![ + M::up( + r#" CREATE TABLE IF NOT EXISTS registration_token ( id INTEGER PRIMARY KEY AUTOINCREMENT, token VARCHAR(255) NOT NULL, @@ -18,7 +19,32 @@ lazy_static! { CONSTRAINT `idx_token` UNIQUE (`token`) ); "# - ),]); + ), + M::up( + r#" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email VARCHAR(150) NOT NULL COLLATE NOCASE, + password_encrypted VARCHAR(128) NOT NULL, + is_admin BOOLEAN NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + updated_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_email` UNIQUE (`email`) + ); + "# + ), + ]); +} + +#[allow(unused)] +pub struct User { + created_at: String, + updated_at: String, + + pub id: u32, + pub email: String, + pub password_encrypted: String, + pub is_admin: bool, } async fn db_path() -> Result { @@ -27,6 +53,7 @@ async fn db_path() -> Result { Ok(db_dir.join("db.sqlite")) } +#[derive(Clone)] pub struct DbConn { conn: Arc, } @@ -55,7 +82,10 @@ impl DbConn { conn: Arc::new(conn), }) } +} +/// db read/write operations for `registration_token` table +impl DbConn { /// Query token from database. /// Since token is global unique for each tabby server, by right there's only one row in the table. pub async fn read_registration_token(&self) -> Result { @@ -96,6 +126,56 @@ impl DbConn { } } +/// db read/write operations for `users` table +impl DbConn { + pub async fn create_user( + &self, + email: String, + password_encrypted: String, + is_admin: bool, + ) -> Result<()> { + let res = self + .conn + .call(move |c| { + c.execute( + r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#, + params![email, password_encrypted, is_admin], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to create user")); + } + + Ok(()) + } + + pub async fn get_user_by_email(&self, email: &str) -> Result> { + let email = email.to_string(); + let user = self + .conn + .call(move |c| { + c.query_row( + r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE email = ?"#, + params![email], + |row| { + Ok(User { + id: row.get(0)?, + email: row.get(1)?, + password_encrypted: row.get(2)?, + is_admin: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + }) + }, + ).optional() + }) + .await?; + + Ok(user) + } +} + #[cfg(test)] mod tests { use super::*; @@ -127,4 +207,29 @@ mod tests { assert_eq!(new_token.len(), 36); assert_ne!(old_token, new_token); } + + #[tokio::test] + async fn test_create_user() { + let conn = new_in_memory().await.unwrap(); + + let email = "test@example.com"; + let passwd = "123456"; + let is_admin = true; + conn.create_user(email.to_string(), passwd.to_string(), is_admin) + .await + .unwrap(); + + let user = conn.get_user_by_email(email).await.unwrap().unwrap(); + assert_eq!(user.id, 1); + } + + #[tokio::test] + async fn test_get_user_by_email() { + let conn = new_in_memory().await.unwrap(); + + let email = "hello@example.com"; + let user = conn.get_user_by_email(email).await.unwrap(); + + assert!(user.is_none()); + } } diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs index 44e0f04..c6088b4 100644 --- a/ee/tabby-webserver/src/schema.rs +++ b/ee/tabby-webserver/src/schema.rs @@ -1,6 +1,17 @@ -use juniper::{graphql_object, EmptySubscription, FieldResult, RootNode}; +pub mod auth; -use crate::{api::Worker, server::ServerContext}; +use juniper::{ + graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode, +}; + +use crate::{ + api::Worker, + schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse}, + server::{ + auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput}, + ServerContext, + }, +}; // To make our context usable by Juniper, we have to implement a marker trait. impl juniper::Context for ServerContext {} @@ -25,9 +36,47 @@ pub struct Mutation; #[graphql_object(context = ServerContext)] impl Mutation { - async fn reset_registration_token(ctx: &ServerContext) -> FieldResult { - let token = ctx.reset_registration_token().await?; - Ok(token) + async fn reset_registration_token( + ctx: &ServerContext, + token: Option, + ) -> FieldResult { + if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) { + if claims.user_info().is_admin() { + let reg_token = ctx.reset_registration_token().await?; + return Ok(reg_token); + } + } + Err(FieldError::new( + "Only admin is able to reset registration token", + graphql_value!("Unauthorized"), + )) + } + + async fn register( + ctx: &ServerContext, + email: String, + password1: String, + password2: String, + ) -> FieldResult { + let input = RegisterInput { + email, + password1, + password2, + }; + ctx.auth().register(input).await + } + + async fn token_auth( + ctx: &ServerContext, + email: String, + password: String, + ) -> FieldResult { + let input = TokenAuthInput { email, password }; + ctx.auth().token_auth(input).await + } + + async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult { + ctx.auth().verify_token(token).await } } diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs new file mode 100644 index 0000000..1f478f8 --- /dev/null +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -0,0 +1,129 @@ +use std::fmt::Debug; + +use jsonwebtoken as jwt; +use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value}; +use serde::{Deserialize, Serialize}; +use validator::ValidationError; + +use crate::server::auth::JWT_DEFAULT_EXP; + +#[derive(Debug)] +pub struct ValidationErrors { + pub errors: Vec, +} + +impl IntoFieldError for ValidationErrors { + fn into_field_error(self) -> FieldError { + let errors = self + .errors + .into_iter() + .map(|err| { + let mut obj = Object::with_capacity(2); + obj.add_field("path", Value::scalar(err.code.to_string())); + obj.add_field( + "message", + Value::scalar(err.message.unwrap_or_default().to_string()), + ); + obj.into() + }) + .collect::>(); + let mut ext = Object::with_capacity(2); + ext.add_field("code", Value::scalar("validation-error".to_string())); + ext.add_field("errors", Value::list(errors)); + + FieldError::new("Invalid input parameters", ext.into()) + } +} + +#[derive(Debug, GraphQLObject)] +pub struct RegisterResponse { + access_token: String, + refresh_token: String, +} + +impl RegisterResponse { + pub fn new(access_token: String, refresh_token: String) -> Self { + Self { + access_token, + refresh_token, + } + } +} + +#[derive(Debug, GraphQLObject)] +pub struct TokenAuthResponse { + access_token: String, + refresh_token: String, +} + +impl TokenAuthResponse { + pub fn new(access_token: String, refresh_token: String) -> Self { + Self { + access_token, + refresh_token, + } + } +} + +#[derive(Debug, Default, GraphQLObject)] +pub struct RefreshTokenResponse { + access_token: String, + refresh_token: String, + refresh_expires_in: i32, +} + +#[derive(Debug, GraphQLObject)] +pub struct VerifyTokenResponse { + claims: Claims, +} + +impl VerifyTokenResponse { + pub fn new(claims: Claims) -> Self { + Self { claims } + } +} + +#[derive(Debug, Default, PartialEq, Serialize, Deserialize, GraphQLObject)] +pub struct UserInfo { + email: String, + is_admin: bool, +} + +impl UserInfo { + pub fn new(email: String, is_admin: bool) -> Self { + Self { email, is_admin } + } + + pub fn is_admin(&self) -> bool { + self.is_admin + } + + pub fn email(&self) -> &str { + &self.email + } +} + +#[derive(Debug, Default, Serialize, Deserialize, GraphQLObject)] +pub struct Claims { + // Required. Expiration time (as UTC timestamp) + exp: f64, + // Optional. Issued at (as UTC timestamp) + iat: f64, + // Customized. user info + user: UserInfo, +} + +impl Claims { + pub fn new(user: UserInfo) -> Self { + let now = jwt::get_current_timestamp(); + Self { + iat: now as f64, + exp: (now + *JWT_DEFAULT_EXP) as f64, + user, + } + } + + pub fn user_info(self) -> UserInfo { + self.user + } +} diff --git a/ee/tabby-webserver/src/server.rs b/ee/tabby-webserver/src/server.rs index dc40534..07ef311 100644 --- a/ee/tabby-webserver/src/server.rs +++ b/ee/tabby-webserver/src/server.rs @@ -1,3 +1,4 @@ +pub mod auth; mod proxy; mod worker; @@ -12,6 +13,7 @@ use tracing::{info, warn}; use crate::{ api::{RegisterWorkerError, Worker, WorkerKind}, db::DbConn, + server::auth::AuthenticationService, }; pub struct ServerContext { @@ -40,6 +42,10 @@ impl ServerContext { } } + pub fn auth(&self) -> impl AuthenticationService { + self.db_conn.clone() + } + /// Query current token from the database. pub async fn read_registration_token(&self) -> Result { self.db_conn.read_registration_token().await diff --git a/ee/tabby-webserver/src/server/auth.rs b/ee/tabby-webserver/src/server/auth.rs new file mode 100644 index 0000000..8d09b50 --- /dev/null +++ b/ee/tabby-webserver/src/server/auth.rs @@ -0,0 +1,273 @@ +use std::env; + +use argon2::{ + password_hash, + password_hash::{rand_core::OsRng, SaltString}, + Argon2, PasswordHasher, PasswordVerifier, +}; +use async_trait::async_trait; +use jsonwebtoken as jwt; +use juniper::{FieldResult, IntoFieldError}; +use lazy_static::lazy_static; +use validator::Validate; + +use crate::{ + db::DbConn, + schema::auth::{ + Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, + ValidationErrors, VerifyTokenResponse, + }, +}; + +lazy_static! { + static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret( + jwt_token_secret().as_bytes() + ); + static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret( + jwt_token_secret().as_bytes() + ); + pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes +} + +/// Input parameters for register mutation +/// `validate` attribute is used to validate the input parameters +/// - `code` argument specifies which parameter causes the failure +/// - `message` argument provides client friendly error message +/// +#[derive(Validate)] +pub struct RegisterInput { + #[validate(email(code = "email", message = "Email is invalid"))] + #[validate(length( + max = 128, + code = "email", + message = "Email must be at most 128 characters" + ))] + pub email: String, + #[validate(length( + min = 8, + code = "password1", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "password1", + message = "Password must be at most 20 characters" + ))] + #[validate(must_match( + code = "password1", + message = "Passwords do not match", + other = "password2" + ))] + pub password1: String, + #[validate(length( + min = 8, + code = "password2", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "password2", + message = "Password must be at most 20 characters" + ))] + pub password2: String, +} + +impl std::fmt::Debug for RegisterInput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RegisterInput") + .field("email", &self.email) + .field("password1", &"********") + .field("password2", &"********") + .finish() + } +} + +/// Input parameters for token_auth mutation +/// See `RegisterInput` for `validate` attribute usage +#[derive(Validate)] +pub struct TokenAuthInput { + #[validate(email(code = "email", message = "Email is invalid"))] + #[validate(length( + max = 128, + code = "email", + message = "Email must be at most 128 characters" + ))] + pub email: String, + #[validate(length( + min = 8, + code = "password", + message = "Password must be at least 8 characters" + ))] + #[validate(length( + max = 20, + code = "password", + message = "Password must be at most 20 characters" + ))] + pub password: String, +} + +impl std::fmt::Debug for TokenAuthInput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TokenAuthInput") + .field("email", &self.email) + .field("password", &"********") + .finish() + } +} + +#[async_trait] +pub trait AuthenticationService { + async fn register(&self, input: RegisterInput) -> FieldResult; + async fn token_auth(&self, input: TokenAuthInput) -> FieldResult; + async fn refresh_token(&self, refresh_token: String) -> FieldResult; + async fn verify_token(&self, access_token: String) -> FieldResult; +} + +#[async_trait] +impl AuthenticationService for DbConn { + async fn register(&self, input: RegisterInput) -> FieldResult { + input.validate().map_err(|err| { + let errors = err + .field_errors() + .into_iter() + .flat_map(|(_, errs)| errs) + .cloned() + .collect(); + + ValidationErrors { errors }.into_field_error() + })?; + + // check if email exists + if let Some(_) = self.get_user_by_email(&input.email).await? { + return Err("Email already exists".into()); + } + + let pwd_hash = password_hash(&input.password1)?; + + self.create_user(input.email.clone(), pwd_hash, false) + .await?; + let user = self.get_user_by_email(&input.email).await?.unwrap(); + + let access_token = generate_jwt(Claims::new(UserInfo::new( + user.email.clone(), + user.is_admin, + )))?; + + let resp = RegisterResponse::new(access_token, "".to_string()); + Ok(resp) + } + + async fn token_auth(&self, input: TokenAuthInput) -> FieldResult { + input.validate().map_err(|err| { + let errors = err + .field_errors() + .into_iter() + .flat_map(|(_, errs)| errs) + .cloned() + .collect(); + + ValidationErrors { errors }.into_field_error() + })?; + + let user = self.get_user_by_email(&input.email).await?; + + let user = match user { + Some(user) => user, + None => return Err("User not found".into()), + }; + + if !password_verify(&input.password, &user.password_encrypted) { + return Err("Password incorrect".into()); + } + + let access_token = generate_jwt(Claims::new(UserInfo::new( + user.email.clone(), + user.is_admin, + )))?; + + let resp = TokenAuthResponse::new(access_token, "".to_string()); + Ok(resp) + } + + async fn refresh_token(&self, _refresh_token: String) -> FieldResult { + Ok(RefreshTokenResponse::default()) + } + + async fn verify_token(&self, access_token: String) -> FieldResult { + let claims = validate_jwt(&access_token)?; + let resp = VerifyTokenResponse::new(claims); + Ok(resp) + } +} + +fn password_hash(raw: &str) -> password_hash::Result { + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let hash = argon2.hash_password(raw.as_bytes(), &salt)?.to_string(); + + Ok(hash) +} + +fn password_verify(raw: &str, hash: &str) -> bool { + if let Ok(parsed_hash) = argon2::PasswordHash::new(hash) { + let argon2 = Argon2::default(); + argon2.verify_password(raw.as_bytes(), &parsed_hash).is_ok() + } else { + false + } +} + +fn generate_jwt(claims: Claims) -> jwt::errors::Result { + let header = jwt::Header::default(); + let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?; + Ok(token) +} + +pub fn validate_jwt(token: &str) -> jwt::errors::Result { + let validation = jwt::Validation::default(); + let data = jwt::decode::(token, &JWT_DECODING_KEY, &validation)?; + Ok(data.claims) +} + +fn jwt_token_secret() -> String { + env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_password_hash() { + let raw = "12345678"; + let hash = password_hash(raw).unwrap(); + + assert_eq!(hash.len(), 97); + assert!(hash.starts_with("$argon2id$v=19$m=19456,t=2,p=1$")); + } + + #[test] + fn test_password_verify() { + let raw = "12345678"; + let hash = password_hash(raw).unwrap(); + + assert!(password_verify(raw, &hash)); + assert!(!password_verify(raw, "invalid hash")); + } + + #[test] + fn test_generate_jwt() { + let claims = Claims::new(UserInfo::new("test".to_string(), false)); + let token = generate_jwt(claims).unwrap(); + + assert!(!token.is_empty()) + } + + #[test] + fn test_validate_jwt() { + let claims = Claims::new(UserInfo::new("test".to_string(), false)); + let token = generate_jwt(claims).unwrap(); + let claims = validate_jwt(&token).unwrap(); + assert_eq!(claims.user_info(), UserInfo::new("test".to_string(), false)); + } +}