feat: implement refresh_token API (#938)

* feat: impl refresh token api

* resolve comment
add-signin-page
Eric 2023-12-04 13:26:24 +08:00 committed by GitHub
parent 870638cbbf
commit 73442c33a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 409 additions and 27 deletions

40
Cargo.lock generated
View File

@ -987,6 +987,17 @@ dependencies = [
"nom 4.1.1",
]
[[package]]
name = "cron"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ff76b51e4c068c52bfd2866e1567bee7c567ae8f24ada09fd4307019e25eab7"
dependencies = [
"chrono",
"nom 7.1.3",
"once_cell",
]
[[package]]
name = "crossbeam"
version = "0.8.2"
@ -2258,7 +2269,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51f368c9c76dde2282714ae32dc274b79c27527a0c06c816f6dda048904d0d7c"
dependencies = [
"chrono",
"cron",
"cron 0.6.1",
"uuid 0.8.2",
]
@ -2907,6 +2918,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-derive"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "num-integer"
version = "0.1.45"
@ -4824,6 +4846,7 @@ dependencies = [
"tarpc",
"thiserror",
"tokio",
"tokio-cron-scheduler",
"tokio-rusqlite",
"tokio-tungstenite",
"tower",
@ -5193,6 +5216,21 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "tokio-cron-scheduler"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de2c1fd54a857b29c6cd1846f31903d0ae8e28175615c14a277aed45c58d8e27"
dependencies = [
"chrono",
"cron 0.12.0",
"num-derive",
"num-traits",
"tokio",
"tracing",
"uuid 1.4.1",
]
[[package]]
name = "tokio-io-timeout"
version = "1.2.0"

View File

@ -29,6 +29,7 @@ tabby-common = { path = "../../crates/tabby-common" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
tokio = { workspace = true, features = ["fs"] }
tokio-cron-scheduler = "0.9.4"
tokio-rusqlite = "0.4.0"
tokio-tungstenite = "0.20.1"
tower = { version = "0.4", features = ["util"] }

View File

@ -13,6 +13,7 @@ type Mutation {
register(email: String!, password1: String!, password2: String!, invitationCode: String): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse!
refreshToken(refreshToken: String!): RefreshTokenResponse!
createInvitation(email: String!): Int!
deleteInvitation(id: Int!): Int!
}
@ -63,6 +64,12 @@ type TokenAuthResponse {
refreshToken: String!
}
type RefreshTokenResponse {
accessToken: String!
refreshToken: String!
refreshExpiresAt: Float!
}
schema {
query: Query
mutation: Mutation

View File

@ -7,6 +7,7 @@ use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
use validator::ValidationErrors;
use super::from_validation_errors;
@ -19,6 +20,7 @@ lazy_static! {
jwt_token_secret().as_bytes()
);
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
static ref JWT_REFRESH_PERIOD: i64 = 7 * 24 * 60 * 60; // 7 days
}
pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
@ -37,10 +39,15 @@ fn jwt_token_secret() -> String {
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}
pub fn generate_refresh_token(utc_ts: i64) -> (String, i64) {
let token = Uuid::new_v4().to_string().replace('-', "");
(token, utc_ts + *JWT_REFRESH_PERIOD)
}
#[derive(Debug, GraphQLObject)]
pub struct RegisterResponse {
access_token: String,
refresh_token: String,
pub refresh_token: String,
}
impl RegisterResponse {
@ -82,7 +89,7 @@ impl<S: ScalarValue> IntoFieldError<S> for RegisterError {
#[derive(Debug, GraphQLObject)]
pub struct TokenAuthResponse {
access_token: String,
refresh_token: String,
pub refresh_token: String,
}
impl TokenAuthResponse {
@ -127,11 +134,45 @@ impl<S: ScalarValue> IntoFieldError<S> for TokenAuthError {
}
}
#[derive(Debug, Default, GraphQLObject)]
#[derive(Error, Debug)]
pub enum RefreshTokenError {
#[error("Invalid refresh token")]
InvalidRefreshToken,
#[error("Expired refresh token")]
ExpiredRefreshToken,
#[error("User not found")]
UserNotFound,
#[error(transparent)]
Other(#[from] anyhow::Error),
#[error("Unknown error")]
Unknown,
}
impl<S: ScalarValue> IntoFieldError<S> for RefreshTokenError {
fn into_field_error(self) -> FieldError<S> {
self.into()
}
}
#[derive(Debug, GraphQLObject)]
pub struct RefreshTokenResponse {
access_token: String,
refresh_token: String,
refresh_expires_in: i32,
pub access_token: String,
pub refresh_token: String,
pub refresh_expires_at: f64,
}
impl RefreshTokenResponse {
pub fn new(access_token: String, refresh_token: String, refresh_expires_at: f64) -> Self {
Self {
access_token,
refresh_token,
refresh_expires_at,
}
}
}
#[derive(Debug, GraphQLObject)]
@ -215,7 +256,10 @@ pub trait AuthenticationService: Send + Sync {
password: String,
) -> std::result::Result<TokenAuthResponse, TokenAuthError>;
async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>;
async fn refresh_token(
&self,
refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>;
@ -245,4 +289,11 @@ mod tests {
&UserInfo::new("test".to_string(), false)
);
}
#[test]
fn test_generate_refresh_token() {
let (token, exp) = generate_refresh_token(100);
assert_eq!(token.len(), 32);
assert_eq!(exp, 100 + *JWT_REFRESH_PERIOD);
}
}

View File

@ -17,7 +17,10 @@ use self::{
worker::WorkerService,
};
use crate::schema::{
auth::{RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse},
auth::{
RefreshTokenError, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
VerifyTokenResponse,
},
worker::Worker,
};
@ -142,6 +145,13 @@ impl Mutation {
Ok(ctx.locator.auth().verify_token(token).await?)
}
async fn refresh_token(
ctx: &Context,
refresh_token: String,
) -> Result<RefreshTokenResponse, RefreshTokenError> {
ctx.locator.auth().refresh_token(refresh_token).await
}
async fn create_invitation(ctx: &Context, email: String) -> Result<i32> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {

View File

@ -9,9 +9,9 @@ use validator::Validate;
use super::db::DbConn;
use crate::schema::auth::{
generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, RefreshTokenResponse,
RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, UserInfo,
VerifyTokenResponse,
generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Claims, Invitation,
RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError,
TokenAuthResponse, UserInfo, VerifyTokenResponse,
};
/// Input parameters for register mutation
@ -146,6 +146,10 @@ impl AuthenticationService for DbConn {
.await?;
let user = self.get_user(id).await?.unwrap();
let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp());
self.create_refresh_token(id, &refresh_token, expires_at)
.await?;
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(),
user.is_admin,
@ -153,7 +157,7 @@ impl AuthenticationService for DbConn {
return Err(RegisterError::Unknown);
};
let resp = RegisterResponse::new(access_token, "".to_string());
let resp = RegisterResponse::new(access_token, refresh_token);
Ok(resp)
}
@ -173,6 +177,10 @@ impl AuthenticationService for DbConn {
return Err(TokenAuthError::InvalidPassword);
}
let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp());
self.create_refresh_token(user.id, &refresh_token, expires_at)
.await?;
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(),
user.is_admin,
@ -180,12 +188,39 @@ impl AuthenticationService for DbConn {
return Err(TokenAuthError::Unknown);
};
let resp = TokenAuthResponse::new(access_token, "".to_string());
let resp = TokenAuthResponse::new(access_token, refresh_token);
Ok(resp)
}
async fn refresh_token(&self, _refresh_token: String) -> Result<RefreshTokenResponse> {
Ok(RefreshTokenResponse::default())
async fn refresh_token(
&self,
token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError> {
let Some(refresh_token) = self.get_refresh_token(&token).await? else {
return Err(RefreshTokenError::InvalidRefreshToken);
};
if refresh_token.is_expired() {
return Err(RefreshTokenError::ExpiredRefreshToken);
}
let Some(user) = self.get_user(refresh_token.user_id).await? else {
return Err(RefreshTokenError::UserNotFound);
};
let (new_token, _) = generate_refresh_token(chrono::Utc::now().timestamp());
self.replace_refresh_token(&token, &new_token).await?;
// refresh token update is done, generate new access token based on user info
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(),
user.is_admin,
))) else {
return Err(RefreshTokenError::Unknown);
};
let resp =
RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at as f64);
Ok(resp)
}
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
@ -256,7 +291,7 @@ mod tests {
static ADMIN_EMAIL: &str = "test@example.com";
static ADMIN_PASSWORD: &str = "123456789";
async fn create_admin_user(conn: &DbConn) -> i32 {
async fn register_admin_user(conn: &DbConn) -> RegisterResponse {
conn.register(
ADMIN_EMAIL.to_owned(),
ADMIN_PASSWORD.to_owned(),
@ -264,8 +299,7 @@ mod tests {
None,
)
.await
.unwrap();
1
.unwrap()
}
#[tokio::test]
@ -277,7 +311,7 @@ mod tests {
Err(TokenAuthError::UserNotFound)
);
create_admin_user(&conn).await;
register_admin_user(&conn).await;
assert_matches!(
conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned())
@ -285,10 +319,16 @@ mod tests {
Err(TokenAuthError::InvalidPassword)
);
assert!(conn
let resp1 = conn
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
.await
.is_ok());
.unwrap();
let resp2 = conn
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
.await
.unwrap();
// each auth should generate a new refresh token
assert_ne!(resp1.refresh_token, resp2.refresh_token);
}
#[tokio::test]
@ -296,7 +336,7 @@ mod tests {
let conn = DbConn::new_in_memory().await.unwrap();
assert!(!conn.is_admin_initialized().await.unwrap());
create_admin_user(&conn).await;
register_admin_user(&conn).await;
let email = "user@user.com";
let password = "12345678";
@ -351,4 +391,23 @@ mod tests {
Err(RegisterError::DuplicateEmail)
);
}
#[tokio::test]
async fn test_refresh_token() {
let conn = DbConn::new_in_memory().await.unwrap();
let reg = register_admin_user(&conn).await;
let resp1 = conn.refresh_token(reg.refresh_token.clone()).await.unwrap();
// new access token should be valid
assert!(validate_jwt(&resp1.access_token).is_ok());
// refresh token should be renewed
assert_ne!(reg.refresh_token, resp1.refresh_token);
let resp2 = conn
.refresh_token(resp1.refresh_token.clone())
.await
.unwrap();
// expire time should be no change
assert_eq!(resp1.refresh_expires_at, resp2.refresh_expires_at);
}
}

View File

@ -0,0 +1,62 @@
use std::time::Duration;
use anyhow::Result;
use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{error, warn};
use crate::service::db::DbConn;
async fn new_job_scheduler(jobs: Vec<Job>) -> Result<JobScheduler> {
let scheduler = JobScheduler::new().await?;
for job in jobs {
scheduler.add(job).await?;
}
scheduler.start().await?;
Ok(scheduler)
}
async fn new_refresh_token_job(db_conn: DbConn) -> Result<Job> {
// job is run every 2 hours
let job = Job::new_async("0 0 1/2 * * * *", move |_, _| {
let utc_ts = chrono::Utc::now().timestamp();
let db_conn = db_conn.clone();
Box::pin(async move {
let res = db_conn.delete_expired_token(utc_ts).await;
if let Err(e) = res {
error!("failed to delete expired token: {}", e);
}
})
})?;
Ok(job)
}
pub fn run_offline_job(db_conn: DbConn) {
tokio::spawn(async move {
let Ok(job) = new_refresh_token_job(db_conn.clone()).await else {
error!("failed to create db job");
return;
};
let Ok(mut scheduler) = new_job_scheduler(vec![job]).await else {
error!("failed to start job scheduler");
return;
};
loop {
match scheduler.time_till_next_job().await {
Ok(Some(duration)) => {
tokio::time::sleep(duration).await;
}
Ok(None) => {
warn!("no job available, exit scheduler");
return;
}
Err(e) => {
error!("failed to get job sleep time: {}, re-try in 1 second", e);
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
}

View File

@ -8,7 +8,7 @@ use tabby_common::path::tabby_root;
use tokio_rusqlite::Connection;
use uuid::Uuid;
use crate::schema::auth::Invitation;
use crate::{schema::auth::Invitation, service::cron::run_offline_job};
lazy_static! {
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![
@ -51,6 +51,19 @@ lazy_static! {
"#
)
.down("DROP TABLE invitations"),
M::up(
r#"
CREATE TABLE refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token VARCHAR(255) NOT NULL COLLATE NOCASE,
expires_at INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_token` UNIQUE (`token`)
);
"#
)
.down("DROP TABLE refresh_tokens"),
]);
}
@ -59,7 +72,7 @@ pub struct User {
created_at: String,
updated_at: String,
pub id: u32,
pub id: i32,
pub email: String,
pub password_encrypted: String,
pub is_admin: bool,
@ -121,9 +134,12 @@ impl DbConn {
})
.await?;
Ok(Self {
let res = Self {
conn: Arc::new(conn),
})
};
run_offline_job(res.clone());
Ok(res)
}
}
@ -309,6 +325,114 @@ impl DbConn {
}
}
#[allow(unused)]
pub struct RefreshToken {
id: u32,
created_at: String,
pub user_id: i32,
pub token: String,
pub expires_at: i64,
}
impl RefreshToken {
fn select(clause: &str) -> String {
r#"SELECT id, user_id, token, expires_at, created_at FROM refresh_tokens WHERE "#.to_owned()
+ clause
}
fn from_row(row: &Row<'_>) -> std::result::Result<RefreshToken, rusqlite::Error> {
Ok(RefreshToken {
id: row.get(0)?,
user_id: row.get(1)?,
token: row.get(2)?,
expires_at: row.get(3)?,
created_at: row.get(4)?,
})
}
pub fn is_expired(&self) -> bool {
let now = chrono::Utc::now().timestamp();
self.expires_at < now
}
}
/// db read/write operations for `refresh_tokens` table
impl DbConn {
pub async fn create_refresh_token(
&self,
user_id: i32,
token: &str,
expires_at: i64,
) -> Result<()> {
let token = token.to_string();
let res = self
.conn
.call(move |c| {
c.execute(
r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)"#,
params![user_id, token, expires_at],
)
})
.await?;
if res != 1 {
return Err(anyhow::anyhow!("failed to create refresh token"));
}
Ok(())
}
pub async fn replace_refresh_token(&self, old: &str, new: &str) -> Result<()> {
let old = old.to_string();
let new = new.to_string();
let res = self
.conn
.call(move |c| {
c.execute(
r#"UPDATE refresh_tokens SET token = ? WHERE token = ?"#,
params![new, old],
)
})
.await?;
if res != 1 {
return Err(anyhow::anyhow!("failed to replace refresh token"));
}
Ok(())
}
pub async fn delete_expired_token(&self, utc_ts: i64) -> Result<i32> {
let res = self
.conn
.call(move |c| {
c.execute(
r#"DELETE FROM refresh_tokens WHERE expires_at < ?"#,
params![utc_ts],
)
})
.await?;
Ok(res as i32)
}
pub async fn get_refresh_token(&self, token: &str) -> Result<Option<RefreshToken>> {
let token = token.to_string();
let token = self
.conn
.call(move |c| {
c.query_row(
RefreshToken::select("token = ?").as_str(),
params![token],
RefreshToken::from_row,
)
.optional()
})
.await?;
Ok(token)
}
}
#[cfg(test)]
mod tests {
@ -398,4 +522,33 @@ mod tests {
let invitations = conn.list_invitations().await.unwrap();
assert!(invitations.is_empty());
}
#[tokio::test]
async fn test_create_refresh_token() {
let conn = DbConn::new_in_memory().await.unwrap();
conn.create_refresh_token(1, "test", 100).await.unwrap();
let token = conn.get_refresh_token("test").await.unwrap().unwrap();
assert_eq!(token.user_id, 1);
assert_eq!(token.token, "test");
assert_eq!(token.expires_at, 100);
}
#[tokio::test]
async fn test_replace_refresh_token() {
let conn = DbConn::new_in_memory().await.unwrap();
conn.create_refresh_token(1, "test", 100).await.unwrap();
conn.replace_refresh_token("test", "test2").await.unwrap();
let token = conn.get_refresh_token("test").await.unwrap();
assert!(token.is_none());
let token = conn.get_refresh_token("test2").await.unwrap().unwrap();
assert_eq!(token.user_id, 1);
assert_eq!(token.token, "test2");
assert_eq!(token.expires_at, 100);
}
}

View File

@ -1,4 +1,5 @@
mod auth;
mod cron;
mod db;
mod proxy;
mod worker;