feat: implement refresh_token API (#938)
* feat: impl refresh token api * resolve commentadd-signin-page
parent
870638cbbf
commit
73442c33a7
|
|
@ -987,6 +987,17 @@ dependencies = [
|
|||
"nom 4.1.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cron"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ff76b51e4c068c52bfd2866e1567bee7c567ae8f24ada09fd4307019e25eab7"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"nom 7.1.3",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam"
|
||||
version = "0.8.2"
|
||||
|
|
@ -2258,7 +2269,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "51f368c9c76dde2282714ae32dc274b79c27527a0c06c816f6dda048904d0d7c"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"cron",
|
||||
"cron 0.6.1",
|
||||
"uuid 0.8.2",
|
||||
]
|
||||
|
||||
|
|
@ -2907,6 +2918,17 @@ dependencies = [
|
|||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
|
|
@ -4824,6 +4846,7 @@ dependencies = [
|
|||
"tarpc",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-cron-scheduler",
|
||||
"tokio-rusqlite",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
|
|
@ -5193,6 +5216,21 @@ dependencies = [
|
|||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-cron-scheduler"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de2c1fd54a857b29c6cd1846f31903d0ae8e28175615c14a277aed45c58d8e27"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"cron 0.12.0",
|
||||
"num-derive",
|
||||
"num-traits",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid 1.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-io-timeout"
|
||||
version = "1.2.0"
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ tabby-common = { path = "../../crates/tabby-common" }
|
|||
tarpc = { version = "0.33.0", features = ["serde-transport"] }
|
||||
thiserror.workspace = true
|
||||
tokio = { workspace = true, features = ["fs"] }
|
||||
tokio-cron-scheduler = "0.9.4"
|
||||
tokio-rusqlite = "0.4.0"
|
||||
tokio-tungstenite = "0.20.1"
|
||||
tower = { version = "0.4", features = ["util"] }
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ type Mutation {
|
|||
register(email: String!, password1: String!, password2: String!, invitationCode: String): RegisterResponse!
|
||||
tokenAuth(email: String!, password: String!): TokenAuthResponse!
|
||||
verifyToken(token: String!): VerifyTokenResponse!
|
||||
refreshToken(refreshToken: String!): RefreshTokenResponse!
|
||||
createInvitation(email: String!): Int!
|
||||
deleteInvitation(id: Int!): Int!
|
||||
}
|
||||
|
|
@ -63,6 +64,12 @@ type TokenAuthResponse {
|
|||
refreshToken: String!
|
||||
}
|
||||
|
||||
type RefreshTokenResponse {
|
||||
accessToken: String!
|
||||
refreshToken: String!
|
||||
refreshExpiresAt: Float!
|
||||
}
|
||||
|
||||
schema {
|
||||
query: Query
|
||||
mutation: Mutation
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue};
|
|||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
use validator::ValidationErrors;
|
||||
|
||||
use super::from_validation_errors;
|
||||
|
|
@ -19,6 +20,7 @@ lazy_static! {
|
|||
jwt_token_secret().as_bytes()
|
||||
);
|
||||
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
|
||||
static ref JWT_REFRESH_PERIOD: i64 = 7 * 24 * 60 * 60; // 7 days
|
||||
}
|
||||
|
||||
pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
|
||||
|
|
@ -37,10 +39,15 @@ fn jwt_token_secret() -> String {
|
|||
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
|
||||
}
|
||||
|
||||
pub fn generate_refresh_token(utc_ts: i64) -> (String, i64) {
|
||||
let token = Uuid::new_v4().to_string().replace('-', "");
|
||||
(token, utc_ts + *JWT_REFRESH_PERIOD)
|
||||
}
|
||||
|
||||
#[derive(Debug, GraphQLObject)]
|
||||
pub struct RegisterResponse {
|
||||
access_token: String,
|
||||
refresh_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
impl RegisterResponse {
|
||||
|
|
@ -82,7 +89,7 @@ impl<S: ScalarValue> IntoFieldError<S> for RegisterError {
|
|||
#[derive(Debug, GraphQLObject)]
|
||||
pub struct TokenAuthResponse {
|
||||
access_token: String,
|
||||
refresh_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
impl TokenAuthResponse {
|
||||
|
|
@ -127,11 +134,45 @@ impl<S: ScalarValue> IntoFieldError<S> for TokenAuthError {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, GraphQLObject)]
|
||||
#[derive(Error, Debug)]
|
||||
pub enum RefreshTokenError {
|
||||
#[error("Invalid refresh token")]
|
||||
InvalidRefreshToken,
|
||||
|
||||
#[error("Expired refresh token")]
|
||||
ExpiredRefreshToken,
|
||||
|
||||
#[error("User not found")]
|
||||
UserNotFound,
|
||||
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
|
||||
#[error("Unknown error")]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl<S: ScalarValue> IntoFieldError<S> for RefreshTokenError {
|
||||
fn into_field_error(self) -> FieldError<S> {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, GraphQLObject)]
|
||||
pub struct RefreshTokenResponse {
|
||||
access_token: String,
|
||||
refresh_token: String,
|
||||
refresh_expires_in: i32,
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
pub refresh_expires_at: f64,
|
||||
}
|
||||
|
||||
impl RefreshTokenResponse {
|
||||
pub fn new(access_token: String, refresh_token: String, refresh_expires_at: f64) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
refresh_token,
|
||||
refresh_expires_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, GraphQLObject)]
|
||||
|
|
@ -215,7 +256,10 @@ pub trait AuthenticationService: Send + Sync {
|
|||
password: String,
|
||||
) -> std::result::Result<TokenAuthResponse, TokenAuthError>;
|
||||
|
||||
async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>;
|
||||
async fn refresh_token(
|
||||
&self,
|
||||
refresh_token: String,
|
||||
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
|
||||
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
|
||||
async fn is_admin_initialized(&self) -> Result<bool>;
|
||||
|
||||
|
|
@ -245,4 +289,11 @@ mod tests {
|
|||
&UserInfo::new("test".to_string(), false)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_refresh_token() {
|
||||
let (token, exp) = generate_refresh_token(100);
|
||||
assert_eq!(token.len(), 32);
|
||||
assert_eq!(exp, 100 + *JWT_REFRESH_PERIOD);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,10 @@ use self::{
|
|||
worker::WorkerService,
|
||||
};
|
||||
use crate::schema::{
|
||||
auth::{RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse},
|
||||
auth::{
|
||||
RefreshTokenError, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
|
||||
VerifyTokenResponse,
|
||||
},
|
||||
worker::Worker,
|
||||
};
|
||||
|
||||
|
|
@ -142,6 +145,13 @@ impl Mutation {
|
|||
Ok(ctx.locator.auth().verify_token(token).await?)
|
||||
}
|
||||
|
||||
async fn refresh_token(
|
||||
ctx: &Context,
|
||||
refresh_token: String,
|
||||
) -> Result<RefreshTokenResponse, RefreshTokenError> {
|
||||
ctx.locator.auth().refresh_token(refresh_token).await
|
||||
}
|
||||
|
||||
async fn create_invitation(ctx: &Context, email: String) -> Result<i32> {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.user_info().is_admin() {
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ use validator::Validate;
|
|||
|
||||
use super::db::DbConn;
|
||||
use crate::schema::auth::{
|
||||
generate_jwt, validate_jwt, AuthenticationService, Claims, Invitation, RefreshTokenResponse,
|
||||
RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, UserInfo,
|
||||
VerifyTokenResponse,
|
||||
generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, Claims, Invitation,
|
||||
RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError,
|
||||
TokenAuthResponse, UserInfo, VerifyTokenResponse,
|
||||
};
|
||||
|
||||
/// Input parameters for register mutation
|
||||
|
|
@ -146,6 +146,10 @@ impl AuthenticationService for DbConn {
|
|||
.await?;
|
||||
let user = self.get_user(id).await?.unwrap();
|
||||
|
||||
let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp());
|
||||
self.create_refresh_token(id, &refresh_token, expires_at)
|
||||
.await?;
|
||||
|
||||
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
||||
user.email.clone(),
|
||||
user.is_admin,
|
||||
|
|
@ -153,7 +157,7 @@ impl AuthenticationService for DbConn {
|
|||
return Err(RegisterError::Unknown);
|
||||
};
|
||||
|
||||
let resp = RegisterResponse::new(access_token, "".to_string());
|
||||
let resp = RegisterResponse::new(access_token, refresh_token);
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
|
|
@ -173,6 +177,10 @@ impl AuthenticationService for DbConn {
|
|||
return Err(TokenAuthError::InvalidPassword);
|
||||
}
|
||||
|
||||
let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp());
|
||||
self.create_refresh_token(user.id, &refresh_token, expires_at)
|
||||
.await?;
|
||||
|
||||
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
||||
user.email.clone(),
|
||||
user.is_admin,
|
||||
|
|
@ -180,12 +188,39 @@ impl AuthenticationService for DbConn {
|
|||
return Err(TokenAuthError::Unknown);
|
||||
};
|
||||
|
||||
let resp = TokenAuthResponse::new(access_token, "".to_string());
|
||||
let resp = TokenAuthResponse::new(access_token, refresh_token);
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, _refresh_token: String) -> Result<RefreshTokenResponse> {
|
||||
Ok(RefreshTokenResponse::default())
|
||||
async fn refresh_token(
|
||||
&self,
|
||||
token: String,
|
||||
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError> {
|
||||
let Some(refresh_token) = self.get_refresh_token(&token).await? else {
|
||||
return Err(RefreshTokenError::InvalidRefreshToken);
|
||||
};
|
||||
if refresh_token.is_expired() {
|
||||
return Err(RefreshTokenError::ExpiredRefreshToken);
|
||||
}
|
||||
let Some(user) = self.get_user(refresh_token.user_id).await? else {
|
||||
return Err(RefreshTokenError::UserNotFound);
|
||||
};
|
||||
|
||||
let (new_token, _) = generate_refresh_token(chrono::Utc::now().timestamp());
|
||||
self.replace_refresh_token(&token, &new_token).await?;
|
||||
|
||||
// refresh token update is done, generate new access token based on user info
|
||||
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
|
||||
user.email.clone(),
|
||||
user.is_admin,
|
||||
))) else {
|
||||
return Err(RefreshTokenError::Unknown);
|
||||
};
|
||||
|
||||
let resp =
|
||||
RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at as f64);
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn verify_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
||||
|
|
@ -256,7 +291,7 @@ mod tests {
|
|||
static ADMIN_EMAIL: &str = "test@example.com";
|
||||
static ADMIN_PASSWORD: &str = "123456789";
|
||||
|
||||
async fn create_admin_user(conn: &DbConn) -> i32 {
|
||||
async fn register_admin_user(conn: &DbConn) -> RegisterResponse {
|
||||
conn.register(
|
||||
ADMIN_EMAIL.to_owned(),
|
||||
ADMIN_PASSWORD.to_owned(),
|
||||
|
|
@ -264,8 +299,7 @@ mod tests {
|
|||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
1
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -277,7 +311,7 @@ mod tests {
|
|||
Err(TokenAuthError::UserNotFound)
|
||||
);
|
||||
|
||||
create_admin_user(&conn).await;
|
||||
register_admin_user(&conn).await;
|
||||
|
||||
assert_matches!(
|
||||
conn.token_auth(ADMIN_EMAIL.to_owned(), "12345678".to_owned())
|
||||
|
|
@ -285,10 +319,16 @@ mod tests {
|
|||
Err(TokenAuthError::InvalidPassword)
|
||||
);
|
||||
|
||||
assert!(conn
|
||||
let resp1 = conn
|
||||
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
|
||||
.await
|
||||
.is_ok());
|
||||
.unwrap();
|
||||
let resp2 = conn
|
||||
.token_auth(ADMIN_EMAIL.to_owned(), ADMIN_PASSWORD.to_owned())
|
||||
.await
|
||||
.unwrap();
|
||||
// each auth should generate a new refresh token
|
||||
assert_ne!(resp1.refresh_token, resp2.refresh_token);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -296,7 +336,7 @@ mod tests {
|
|||
let conn = DbConn::new_in_memory().await.unwrap();
|
||||
|
||||
assert!(!conn.is_admin_initialized().await.unwrap());
|
||||
create_admin_user(&conn).await;
|
||||
register_admin_user(&conn).await;
|
||||
|
||||
let email = "user@user.com";
|
||||
let password = "12345678";
|
||||
|
|
@ -351,4 +391,23 @@ mod tests {
|
|||
Err(RegisterError::DuplicateEmail)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_token() {
|
||||
let conn = DbConn::new_in_memory().await.unwrap();
|
||||
let reg = register_admin_user(&conn).await;
|
||||
|
||||
let resp1 = conn.refresh_token(reg.refresh_token.clone()).await.unwrap();
|
||||
// new access token should be valid
|
||||
assert!(validate_jwt(&resp1.access_token).is_ok());
|
||||
// refresh token should be renewed
|
||||
assert_ne!(reg.refresh_token, resp1.refresh_token);
|
||||
|
||||
let resp2 = conn
|
||||
.refresh_token(resp1.refresh_token.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
// expire time should be no change
|
||||
assert_eq!(resp1.refresh_expires_at, resp2.refresh_expires_at);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use tokio_cron_scheduler::{Job, JobScheduler};
|
||||
use tracing::{error, warn};
|
||||
|
||||
use crate::service::db::DbConn;
|
||||
|
||||
async fn new_job_scheduler(jobs: Vec<Job>) -> Result<JobScheduler> {
|
||||
let scheduler = JobScheduler::new().await?;
|
||||
for job in jobs {
|
||||
scheduler.add(job).await?;
|
||||
}
|
||||
scheduler.start().await?;
|
||||
Ok(scheduler)
|
||||
}
|
||||
|
||||
async fn new_refresh_token_job(db_conn: DbConn) -> Result<Job> {
|
||||
// job is run every 2 hours
|
||||
let job = Job::new_async("0 0 1/2 * * * *", move |_, _| {
|
||||
let utc_ts = chrono::Utc::now().timestamp();
|
||||
let db_conn = db_conn.clone();
|
||||
Box::pin(async move {
|
||||
let res = db_conn.delete_expired_token(utc_ts).await;
|
||||
if let Err(e) = res {
|
||||
error!("failed to delete expired token: {}", e);
|
||||
}
|
||||
})
|
||||
})?;
|
||||
|
||||
Ok(job)
|
||||
}
|
||||
|
||||
pub fn run_offline_job(db_conn: DbConn) {
|
||||
tokio::spawn(async move {
|
||||
let Ok(job) = new_refresh_token_job(db_conn.clone()).await else {
|
||||
error!("failed to create db job");
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(mut scheduler) = new_job_scheduler(vec![job]).await else {
|
||||
error!("failed to start job scheduler");
|
||||
return;
|
||||
};
|
||||
|
||||
loop {
|
||||
match scheduler.time_till_next_job().await {
|
||||
Ok(Some(duration)) => {
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
Ok(None) => {
|
||||
warn!("no job available, exit scheduler");
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to get job sleep time: {}, re-try in 1 second", e);
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -8,7 +8,7 @@ use tabby_common::path::tabby_root;
|
|||
use tokio_rusqlite::Connection;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::schema::auth::Invitation;
|
||||
use crate::{schema::auth::Invitation, service::cron::run_offline_job};
|
||||
|
||||
lazy_static! {
|
||||
static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![
|
||||
|
|
@ -51,6 +51,19 @@ lazy_static! {
|
|||
"#
|
||||
)
|
||||
.down("DROP TABLE invitations"),
|
||||
M::up(
|
||||
r#"
|
||||
CREATE TABLE refresh_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
token VARCHAR(255) NOT NULL COLLATE NOCASE,
|
||||
expires_at INTEGER NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT (DATETIME('now')),
|
||||
CONSTRAINT `idx_token` UNIQUE (`token`)
|
||||
);
|
||||
"#
|
||||
)
|
||||
.down("DROP TABLE refresh_tokens"),
|
||||
]);
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +72,7 @@ pub struct User {
|
|||
created_at: String,
|
||||
updated_at: String,
|
||||
|
||||
pub id: u32,
|
||||
pub id: i32,
|
||||
pub email: String,
|
||||
pub password_encrypted: String,
|
||||
pub is_admin: bool,
|
||||
|
|
@ -121,9 +134,12 @@ impl DbConn {
|
|||
})
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
let res = Self {
|
||||
conn: Arc::new(conn),
|
||||
})
|
||||
};
|
||||
run_offline_job(res.clone());
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -309,6 +325,114 @@ impl DbConn {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub struct RefreshToken {
|
||||
id: u32,
|
||||
created_at: String,
|
||||
|
||||
pub user_id: i32,
|
||||
pub token: String,
|
||||
pub expires_at: i64,
|
||||
}
|
||||
|
||||
impl RefreshToken {
|
||||
fn select(clause: &str) -> String {
|
||||
r#"SELECT id, user_id, token, expires_at, created_at FROM refresh_tokens WHERE "#.to_owned()
|
||||
+ clause
|
||||
}
|
||||
|
||||
fn from_row(row: &Row<'_>) -> std::result::Result<RefreshToken, rusqlite::Error> {
|
||||
Ok(RefreshToken {
|
||||
id: row.get(0)?,
|
||||
user_id: row.get(1)?,
|
||||
token: row.get(2)?,
|
||||
expires_at: row.get(3)?,
|
||||
created_at: row.get(4)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_expired(&self) -> bool {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
self.expires_at < now
|
||||
}
|
||||
}
|
||||
|
||||
/// db read/write operations for `refresh_tokens` table
|
||||
impl DbConn {
|
||||
pub async fn create_refresh_token(
|
||||
&self,
|
||||
user_id: i32,
|
||||
token: &str,
|
||||
expires_at: i64,
|
||||
) -> Result<()> {
|
||||
let token = token.to_string();
|
||||
let res = self
|
||||
.conn
|
||||
.call(move |c| {
|
||||
c.execute(
|
||||
r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)"#,
|
||||
params![user_id, token, expires_at],
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
if res != 1 {
|
||||
return Err(anyhow::anyhow!("failed to create refresh token"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn replace_refresh_token(&self, old: &str, new: &str) -> Result<()> {
|
||||
let old = old.to_string();
|
||||
let new = new.to_string();
|
||||
let res = self
|
||||
.conn
|
||||
.call(move |c| {
|
||||
c.execute(
|
||||
r#"UPDATE refresh_tokens SET token = ? WHERE token = ?"#,
|
||||
params![new, old],
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
if res != 1 {
|
||||
return Err(anyhow::anyhow!("failed to replace refresh token"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_expired_token(&self, utc_ts: i64) -> Result<i32> {
|
||||
let res = self
|
||||
.conn
|
||||
.call(move |c| {
|
||||
c.execute(
|
||||
r#"DELETE FROM refresh_tokens WHERE expires_at < ?"#,
|
||||
params![utc_ts],
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(res as i32)
|
||||
}
|
||||
|
||||
pub async fn get_refresh_token(&self, token: &str) -> Result<Option<RefreshToken>> {
|
||||
let token = token.to_string();
|
||||
let token = self
|
||||
.conn
|
||||
.call(move |c| {
|
||||
c.query_row(
|
||||
RefreshToken::select("token = ?").as_str(),
|
||||
params![token],
|
||||
RefreshToken::from_row,
|
||||
)
|
||||
.optional()
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
|
@ -398,4 +522,33 @@ mod tests {
|
|||
let invitations = conn.list_invitations().await.unwrap();
|
||||
assert!(invitations.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_refresh_token() {
|
||||
let conn = DbConn::new_in_memory().await.unwrap();
|
||||
|
||||
conn.create_refresh_token(1, "test", 100).await.unwrap();
|
||||
|
||||
let token = conn.get_refresh_token("test").await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(token.user_id, 1);
|
||||
assert_eq!(token.token, "test");
|
||||
assert_eq!(token.expires_at, 100);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_replace_refresh_token() {
|
||||
let conn = DbConn::new_in_memory().await.unwrap();
|
||||
|
||||
conn.create_refresh_token(1, "test", 100).await.unwrap();
|
||||
conn.replace_refresh_token("test", "test2").await.unwrap();
|
||||
|
||||
let token = conn.get_refresh_token("test").await.unwrap();
|
||||
assert!(token.is_none());
|
||||
|
||||
let token = conn.get_refresh_token("test2").await.unwrap().unwrap();
|
||||
assert_eq!(token.user_id, 1);
|
||||
assert_eq!(token.token, "test2");
|
||||
assert_eq!(token.expires_at, 100);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
mod auth;
|
||||
mod cron;
|
||||
mod db;
|
||||
mod proxy;
|
||||
mod worker;
|
||||
|
|
|
|||
Loading…
Reference in New Issue