refactor: Use DateTime<Utc> for sqlite datetime fields (#946)

* refactor: use DateTime<Utc> for RefreshToken.expires_at

* refactor: set other date time fields to be DateTime<Utc>

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
add-signin-page
Meng Zhang 2023-12-05 23:47:10 +08:00 committed by GitHub
parent 73442c33a7
commit 74f81cb02a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 36 deletions

1
Cargo.lock generated
View File

@ -3868,6 +3868,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
dependencies = [ dependencies = [
"bitflags 2.4.0", "bitflags 2.4.0",
"chrono",
"fallible-iterator", "fallible-iterator",
"fallible-streaming-iterator", "fallible-streaming-iterator",
"hashlink", "hashlink",

View File

@ -20,7 +20,7 @@ juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0" lazy_static = "1.4.0"
mime_guess = "2.0.4" mime_guess = "2.0.4"
pin-project = "1.1.3" pin-project = "1.1.3"
rusqlite = { version = "0.29.0", features = ["bundled"] } rusqlite = { version = "0.29.0", features = ["bundled", "chrono"] }
# `async-tokio-rusqlite` is only available from 1.1.0-alpha.2, will bump up version when it's stable # `async-tokio-rusqlite` is only available from 1.1.0-alpha.2, will bump up version when it's stable
rusqlite_migration = { version = "1.1.0-alpha.2", features = ["async-tokio-rusqlite"] } rusqlite_migration = { version = "1.1.0-alpha.2", features = ["async-tokio-rusqlite"] }
rust-embed = "8.0.0" rust-embed = "8.0.0"

View File

@ -2,6 +2,7 @@ use std::fmt::Debug;
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc};
use jsonwebtoken as jwt; use jsonwebtoken as jwt;
use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue}; use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue};
use lazy_static::lazy_static; use lazy_static::lazy_static;
@ -20,7 +21,6 @@ lazy_static! {
jwt_token_secret().as_bytes() jwt_token_secret().as_bytes()
); );
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
static ref JWT_REFRESH_PERIOD: i64 = 7 * 24 * 60 * 60; // 7 days
} }
pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> { pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
@ -39,9 +39,8 @@ fn jwt_token_secret() -> String {
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string()) std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
} }
pub fn generate_refresh_token(utc_ts: i64) -> (String, i64) { pub fn generate_refresh_token() -> String {
let token = Uuid::new_v4().to_string().replace('-', ""); Uuid::new_v4().to_string().replace('-', "")
(token, utc_ts + *JWT_REFRESH_PERIOD)
} }
#[derive(Debug, GraphQLObject)] #[derive(Debug, GraphQLObject)]
@ -162,11 +161,15 @@ impl<S: ScalarValue> IntoFieldError<S> for RefreshTokenError {
pub struct RefreshTokenResponse { pub struct RefreshTokenResponse {
pub access_token: String, pub access_token: String,
pub refresh_token: String, pub refresh_token: String,
pub refresh_expires_at: f64, pub refresh_expires_at: DateTime<Utc>,
} }
impl RefreshTokenResponse { impl RefreshTokenResponse {
pub fn new(access_token: String, refresh_token: String, refresh_expires_at: f64) -> Self { pub fn new(
access_token: String,
refresh_token: String,
refresh_expires_at: DateTime<Utc>,
) -> Self {
Self { Self {
access_token, access_token,
refresh_token, refresh_token,
@ -292,8 +295,7 @@ mod tests {
#[test] #[test]
fn test_generate_refresh_token() { fn test_generate_refresh_token() {
let (token, exp) = generate_refresh_token(100); let token = generate_refresh_token();
assert_eq!(token.len(), 32); assert_eq!(token.len(), 32);
assert_eq!(exp, 100 + *JWT_REFRESH_PERIOD);
} }
} }

View File

@ -146,9 +146,8 @@ impl AuthenticationService for DbConn {
.await?; .await?;
let user = self.get_user(id).await?.unwrap(); let user = self.get_user(id).await?.unwrap();
let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp()); let refresh_token = generate_refresh_token();
self.create_refresh_token(id, &refresh_token, expires_at) self.create_refresh_token(id, &refresh_token).await?;
.await?;
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(), user.email.clone(),
@ -177,9 +176,8 @@ impl AuthenticationService for DbConn {
return Err(TokenAuthError::InvalidPassword); return Err(TokenAuthError::InvalidPassword);
} }
let (refresh_token, expires_at) = generate_refresh_token(chrono::Utc::now().timestamp()); let refresh_token = generate_refresh_token();
self.create_refresh_token(user.id, &refresh_token, expires_at) self.create_refresh_token(user.id, &refresh_token).await?;
.await?;
let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new( let Ok(access_token) = generate_jwt(Claims::new(UserInfo::new(
user.email.clone(), user.email.clone(),
@ -206,7 +204,7 @@ impl AuthenticationService for DbConn {
return Err(RefreshTokenError::UserNotFound); return Err(RefreshTokenError::UserNotFound);
}; };
let (new_token, _) = generate_refresh_token(chrono::Utc::now().timestamp()); let new_token = generate_refresh_token();
self.replace_refresh_token(&token, &new_token).await?; self.replace_refresh_token(&token, &new_token).await?;
// refresh token update is done, generate new access token based on user info // refresh token update is done, generate new access token based on user info
@ -217,8 +215,7 @@ impl AuthenticationService for DbConn {
return Err(RefreshTokenError::Unknown); return Err(RefreshTokenError::Unknown);
}; };
let resp = let resp = RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at);
RefreshTokenResponse::new(access_token, new_token, refresh_token.expires_at as f64);
Ok(resp) Ok(resp)
} }

View File

@ -1,6 +1,7 @@
use std::{path::PathBuf, sync::Arc}; use std::{path::PathBuf, sync::Arc};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use rusqlite::{params, OptionalExtension, Row}; use rusqlite::{params, OptionalExtension, Row};
use rusqlite_migration::{AsyncMigrations, M}; use rusqlite_migration::{AsyncMigrations, M};
@ -57,7 +58,7 @@ lazy_static! {
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
token VARCHAR(255) NOT NULL COLLATE NOCASE, token VARCHAR(255) NOT NULL COLLATE NOCASE,
expires_at INTEGER NOT NULL, expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT (DATETIME('now')), created_at TIMESTAMP DEFAULT (DATETIME('now')),
CONSTRAINT `idx_token` UNIQUE (`token`) CONSTRAINT `idx_token` UNIQUE (`token`)
); );
@ -69,8 +70,8 @@ lazy_static! {
#[allow(unused)] #[allow(unused)]
pub struct User { pub struct User {
created_at: String, created_at: DateTime<Utc>,
updated_at: String, updated_at: DateTime<Utc>,
pub id: i32, pub id: i32,
pub email: String, pub email: String,
@ -328,11 +329,11 @@ impl DbConn {
#[allow(unused)] #[allow(unused)]
pub struct RefreshToken { pub struct RefreshToken {
id: u32, id: u32,
created_at: String, created_at: DateTime<Utc>,
pub user_id: i32, pub user_id: i32,
pub token: String, pub token: String,
pub expires_at: i64, pub expires_at: DateTime<Utc>,
} }
impl RefreshToken { impl RefreshToken {
@ -352,26 +353,21 @@ impl RefreshToken {
} }
pub fn is_expired(&self) -> bool { pub fn is_expired(&self) -> bool {
let now = chrono::Utc::now().timestamp(); let now = chrono::Utc::now();
self.expires_at < now self.expires_at < now
} }
} }
/// db read/write operations for `refresh_tokens` table /// db read/write operations for `refresh_tokens` table
impl DbConn { impl DbConn {
pub async fn create_refresh_token( pub async fn create_refresh_token(&self, user_id: i32, token: &str) -> Result<()> {
&self,
user_id: i32,
token: &str,
expires_at: i64,
) -> Result<()> {
let token = token.to_string(); let token = token.to_string();
let res = self let res = self
.conn .conn
.call(move |c| { .call(move |c| {
c.execute( c.execute(
r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)"#, r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, datetime('now', '+7 days'))"#,
params![user_id, token, expires_at], params![user_id, token],
) )
}) })
.await?; .await?;
@ -436,6 +432,8 @@ impl DbConn {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::ops::Add;
use super::*; use super::*;
use crate::schema::auth::AuthenticationService; use crate::schema::auth::AuthenticationService;
@ -527,20 +525,21 @@ mod tests {
async fn test_create_refresh_token() { async fn test_create_refresh_token() {
let conn = DbConn::new_in_memory().await.unwrap(); let conn = DbConn::new_in_memory().await.unwrap();
conn.create_refresh_token(1, "test", 100).await.unwrap(); conn.create_refresh_token(1, "test").await.unwrap();
let token = conn.get_refresh_token("test").await.unwrap().unwrap(); let token = conn.get_refresh_token("test").await.unwrap().unwrap();
assert_eq!(token.user_id, 1); assert_eq!(token.user_id, 1);
assert_eq!(token.token, "test"); assert_eq!(token.token, "test");
assert_eq!(token.expires_at, 100); assert!(token.expires_at > Utc::now().add(chrono::Duration::days(6)));
assert!(token.expires_at < Utc::now().add(chrono::Duration::days(7)));
} }
#[tokio::test] #[tokio::test]
async fn test_replace_refresh_token() { async fn test_replace_refresh_token() {
let conn = DbConn::new_in_memory().await.unwrap(); let conn = DbConn::new_in_memory().await.unwrap();
conn.create_refresh_token(1, "test", 100).await.unwrap(); conn.create_refresh_token(1, "test").await.unwrap();
conn.replace_refresh_token("test", "test2").await.unwrap(); conn.replace_refresh_token("test", "test2").await.unwrap();
let token = conn.get_refresh_token("test").await.unwrap(); let token = conn.get_refresh_token("test").await.unwrap();
@ -549,6 +548,5 @@ mod tests {
let token = conn.get_refresh_token("test2").await.unwrap().unwrap(); let token = conn.get_refresh_token("test2").await.unwrap().unwrap();
assert_eq!(token.user_id, 1); assert_eq!(token.user_id, 1);
assert_eq!(token.token, "test2"); assert_eq!(token.token, "test2");
assert_eq!(token.expires_at, 100);
} }
} }