From 2714b88878ea6b0992fc0645b126e20a320f430f Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 8 Dec 2023 13:25:24 +0800 Subject: [PATCH] refactor(webserver): split db.rs into sub implementations for each table (#985) --- ee/tabby-webserver/src/service/db.rs | 552 ------------------ .../src/service/db/invitations.rs | 111 ++++ ee/tabby-webserver/src/service/db/mod.rs | 219 +++++++ .../src/service/db/refresh_tokens.rs | 145 +++++ ee/tabby-webserver/src/service/db/users.rs | 126 ++++ 5 files changed, 601 insertions(+), 552 deletions(-) delete mode 100644 ee/tabby-webserver/src/service/db.rs create mode 100644 ee/tabby-webserver/src/service/db/invitations.rs create mode 100644 ee/tabby-webserver/src/service/db/mod.rs create mode 100644 ee/tabby-webserver/src/service/db/refresh_tokens.rs create mode 100644 ee/tabby-webserver/src/service/db/users.rs diff --git a/ee/tabby-webserver/src/service/db.rs b/ee/tabby-webserver/src/service/db.rs deleted file mode 100644 index 4250e63..0000000 --- a/ee/tabby-webserver/src/service/db.rs +++ /dev/null @@ -1,552 +0,0 @@ -use std::{path::PathBuf, sync::Arc}; - -use anyhow::{anyhow, Result}; -use chrono::{DateTime, Utc}; -use lazy_static::lazy_static; -use rusqlite::{params, OptionalExtension, Row}; -use rusqlite_migration::{AsyncMigrations, M}; -use tabby_common::path::tabby_root; -use tokio_rusqlite::Connection; -use uuid::Uuid; - -use crate::{schema::auth::Invitation, service::cron::run_offline_job}; - -lazy_static! { - static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![ - M::up( - r#" - CREATE TABLE registration_token ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - token VARCHAR(255) NOT NULL, - created_at TIMESTAMP DEFAULT (DATETIME('now')), - updated_at TIMESTAMP DEFAULT (DATETIME('now')), - CONSTRAINT `idx_token` UNIQUE (`token`) - ); - "# - ) - .down("DROP TABLE registration_token"), - M::up( - r#" - CREATE TABLE users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - email VARCHAR(150) NOT NULL COLLATE NOCASE, - password_encrypted VARCHAR(128) NOT NULL, - is_admin BOOLEAN NOT NULL DEFAULT 0, - created_at TIMESTAMP DEFAULT (DATETIME('now')), - updated_at TIMESTAMP DEFAULT (DATETIME('now')), - CONSTRAINT `idx_email` UNIQUE (`email`) - ); - "# - ) - .down("DROP TABLE users"), - M::up( - r#" - CREATE TABLE invitations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - email VARCHAR(150) NOT NULL COLLATE NOCASE, - code VARCHAR(36) NOT NULL, - created_at TIMESTAMP DEFAULT (DATETIME('now')), - CONSTRAINT `idx_email` UNIQUE (`email`) - CONSTRAINT `idx_code` UNIQUE (`code`) - ); - "# - ) - .down("DROP TABLE invitations"), - M::up( - r#" - CREATE TABLE refresh_tokens ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - token VARCHAR(255) NOT NULL COLLATE NOCASE, - expires_at TIMESTAMP NOT NULL, - created_at TIMESTAMP DEFAULT (DATETIME('now')), - CONSTRAINT `idx_token` UNIQUE (`token`) - ); - "# - ) - .down("DROP TABLE refresh_tokens"), - ]); -} - -#[allow(unused)] -pub struct User { - created_at: DateTime, - updated_at: DateTime, - - pub id: i32, - pub email: String, - pub password_encrypted: String, - pub is_admin: bool, -} - -impl User { - fn select(clause: &str) -> String { - r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "# - .to_owned() - + clause - } - - fn from_row(row: &Row<'_>) -> std::result::Result { - Ok(User { - id: row.get(0)?, - email: row.get(1)?, - password_encrypted: row.get(2)?, - is_admin: row.get(3)?, - created_at: row.get(4)?, - updated_at: row.get(5)?, - }) - } -} - -async fn db_path() -> Result { - let db_dir = tabby_root().join("ee"); - tokio::fs::create_dir_all(db_dir.clone()).await?; - Ok(db_dir.join("db.sqlite")) -} - -#[derive(Clone)] -pub struct DbConn { - conn: Arc, -} - -impl DbConn { - #[cfg(test)] - pub async fn new_in_memory() -> Result { - let conn = Connection::open_in_memory().await?; - DbConn::init_db(conn).await - } - - pub async fn new() -> Result { - let db_path = db_path().await?; - let conn = Connection::open(db_path).await?; - Self::init_db(conn).await - } - - /// Initialize database, create tables and insert first token if not exist - async fn init_db(mut conn: Connection) -> Result { - MIGRATIONS.to_latest(&mut conn).await?; - - let token = uuid::Uuid::new_v4().to_string(); - conn.call(move |c| { - c.execute( - r#"INSERT OR IGNORE INTO registration_token (id, token) VALUES (1, ?)"#, - params![token], - ) - }) - .await?; - - let res = Self { - conn: Arc::new(conn), - }; - run_offline_job(res.clone()); - - Ok(res) - } -} - -/// db read/write operations for `registration_token` table -impl DbConn { - /// Query token from database. - /// Since token is global unique for each tabby server, by right there's only one row in the table. - pub async fn read_registration_token(&self) -> Result { - let token = self - .conn - .call(|conn| { - conn.query_row( - r#"SELECT token FROM registration_token WHERE id = 1"#, - [], - |row| row.get(0), - ) - }) - .await?; - - Ok(token) - } - - /// Update token in database. - pub async fn reset_registration_token(&self) -> Result { - let token = uuid::Uuid::new_v4().to_string(); - let result = token.clone(); - let updated_at = chrono::Utc::now().timestamp() as u32; - - let res = self - .conn - .call(move |conn| { - conn.execute( - r#"UPDATE registration_token SET token = ?, updated_at = ? WHERE id = 1"#, - params![token, updated_at], - ) - }) - .await?; - if res != 1 { - return Err(anyhow::anyhow!("failed to update token")); - } - - Ok(result) - } -} - -/// db read/write operations for `users` table -impl DbConn { - pub async fn create_user( - &self, - email: String, - password_encrypted: String, - is_admin: bool, - ) -> Result { - let res = self - .conn - .call(move |c| { - let mut stmt = c.prepare( - r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#, - )?; - let id = stmt.insert((email, password_encrypted, is_admin))?; - Ok(id) - }) - .await?; - - Ok(res as i32) - } - - pub async fn get_user(&self, id: i32) -> Result> { - let user = self - .conn - .call(move |c| { - c.query_row(User::select("id = ?").as_str(), params![id], User::from_row) - .optional() - }) - .await?; - - Ok(user) - } - - pub async fn get_user_by_email(&self, email: &str) -> Result> { - let email = email.to_owned(); - let user = self - .conn - .call(move |c| { - c.query_row( - User::select("email = ?").as_str(), - params![email], - User::from_row, - ) - .optional() - }) - .await?; - - Ok(user) - } - - pub async fn list_admin_users(&self) -> Result> { - let users = self - .conn - .call(move |c| { - let mut stmt = c.prepare(&User::select("is_admin"))?; - let user_iter = stmt.query_map([], User::from_row)?; - Ok(user_iter.filter_map(|x| x.ok()).collect::>()) - }) - .await?; - - Ok(users) - } -} - -impl Invitation { - fn from_row(row: &Row<'_>) -> std::result::Result { - Ok(Self { - id: row.get(0)?, - email: row.get(1)?, - code: row.get(2)?, - created_at: row.get(3)?, - }) - } -} - -/// db read/write operations for `invitations` table -impl DbConn { - pub async fn list_invitations(&self) -> Result> { - let invitations = self - .conn - .call(move |c| { - let mut stmt = - c.prepare(r#"SELECT id, email, code, created_at FROM invitations"#)?; - let iter = stmt.query_map([], Invitation::from_row)?; - Ok(iter.filter_map(|x| x.ok()).collect::>()) - }) - .await?; - - Ok(invitations) - } - - pub async fn get_invitation_by_code(&self, code: &str) -> Result> { - let code = code.to_owned(); - let token = self - .conn - .call(|conn| { - conn.query_row( - r#"SELECT id, email, code, created_at FROM invitations WHERE code = ?"#, - [code], - Invitation::from_row, - ) - .optional() - }) - .await?; - - Ok(token) - } - - pub async fn create_invitation(&self, email: String) -> Result { - let code = Uuid::new_v4().to_string(); - let res = self - .conn - .call(move |c| { - let mut stmt = - c.prepare(r#"INSERT INTO invitations (email, code) VALUES (?, ?)"#)?; - let rowid = stmt.insert((email, code))?; - Ok(rowid) - }) - .await?; - if res != 1 { - return Err(anyhow!("failed to create invitation")); - } - - Ok(res as i32) - } - - pub async fn delete_invitation(&self, id: i32) -> Result { - let res = self - .conn - .call(move |c| c.execute(r#"DELETE FROM invitations WHERE id = ?"#, params![id])) - .await?; - if res != 1 { - return Err(anyhow!("failed to delete invitation")); - } - - Ok(id) - } -} - -#[allow(unused)] -pub struct RefreshToken { - id: u32, - created_at: DateTime, - - pub user_id: i32, - pub token: String, - pub expires_at: DateTime, -} - -impl RefreshToken { - fn select(clause: &str) -> String { - r#"SELECT id, user_id, token, expires_at, created_at FROM refresh_tokens WHERE "#.to_owned() - + clause - } - - fn from_row(row: &Row<'_>) -> std::result::Result { - Ok(RefreshToken { - id: row.get(0)?, - user_id: row.get(1)?, - token: row.get(2)?, - expires_at: row.get(3)?, - created_at: row.get(4)?, - }) - } - - pub fn is_expired(&self) -> bool { - let now = chrono::Utc::now(); - self.expires_at < now - } -} - -/// db read/write operations for `refresh_tokens` table -impl DbConn { - pub async fn create_refresh_token(&self, user_id: i32, token: &str) -> Result<()> { - let token = token.to_string(); - let res = self - .conn - .call(move |c| { - c.execute( - r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, datetime('now', '+7 days'))"#, - params![user_id, token], - ) - }) - .await?; - if res != 1 { - return Err(anyhow::anyhow!("failed to create refresh token")); - } - - Ok(()) - } - - pub async fn replace_refresh_token(&self, old: &str, new: &str) -> Result<()> { - let old = old.to_string(); - let new = new.to_string(); - let res = self - .conn - .call(move |c| { - c.execute( - r#"UPDATE refresh_tokens SET token = ? WHERE token = ?"#, - params![new, old], - ) - }) - .await?; - if res != 1 { - return Err(anyhow::anyhow!("failed to replace refresh token")); - } - - Ok(()) - } - - pub async fn delete_expired_token(&self, utc_ts: i64) -> Result { - let res = self - .conn - .call(move |c| { - c.execute( - r#"DELETE FROM refresh_tokens WHERE expires_at < ?"#, - params![utc_ts], - ) - }) - .await?; - - Ok(res as i32) - } - - pub async fn get_refresh_token(&self, token: &str) -> Result> { - let token = token.to_string(); - let token = self - .conn - .call(move |c| { - c.query_row( - RefreshToken::select("token = ?").as_str(), - params![token], - RefreshToken::from_row, - ) - .optional() - }) - .await?; - - Ok(token) - } -} - -#[cfg(test)] -mod tests { - - use std::ops::Add; - - use super::*; - use crate::schema::auth::AuthenticationService; - - async fn create_user(conn: &DbConn) -> i32 { - let email: &str = "test@example.com"; - let password: &str = "123456789"; - conn.create_user(email.to_string(), password.to_string(), true) - .await - .unwrap() - } - - #[tokio::test] - async fn migrations_test() { - assert!(MIGRATIONS.validate().await.is_ok()); - } - - #[tokio::test] - async fn test_token() { - let conn = DbConn::new_in_memory().await.unwrap(); - let token = conn.read_registration_token().await.unwrap(); - assert_eq!(token.len(), 36); - } - - #[tokio::test] - async fn test_update_token() { - let conn = DbConn::new_in_memory().await.unwrap(); - - let old_token = conn.read_registration_token().await.unwrap(); - conn.reset_registration_token().await.unwrap(); - let new_token = conn.read_registration_token().await.unwrap(); - assert_eq!(new_token.len(), 36); - assert_ne!(old_token, new_token); - } - - #[tokio::test] - async fn test_create_user() { - let conn = DbConn::new_in_memory().await.unwrap(); - - let id = create_user(&conn).await; - let user = conn.get_user(id).await.unwrap().unwrap(); - assert_eq!(user.id, 1); - } - - #[tokio::test] - async fn test_get_user_by_email() { - let conn = DbConn::new_in_memory().await.unwrap(); - - let email = "hello@example.com"; - let user = conn.get_user_by_email(email).await.unwrap(); - - assert!(user.is_none()); - } - - #[tokio::test] - async fn test_is_admin_initialized() { - let conn = DbConn::new_in_memory().await.unwrap(); - - assert!(!conn.is_admin_initialized().await.unwrap()); - create_user(&conn).await; - assert!(conn.is_admin_initialized().await.unwrap()); - } - - #[tokio::test] - async fn test_invitations() { - let conn = DbConn::new_in_memory().await.unwrap(); - - let email = "hello@example.com".to_owned(); - conn.create_invitation(email).await.unwrap(); - - let invitations = conn.list_invitations().await.unwrap(); - assert_eq!(1, invitations.len()); - - assert!(Uuid::parse_str(&invitations[0].code).is_ok()); - let invitation = conn - .get_invitation_by_code(&invitations[0].code) - .await - .ok() - .flatten() - .unwrap(); - assert_eq!(invitation.id, invitations[0].id); - - conn.delete_invitation(invitations[0].id).await.unwrap(); - - let invitations = conn.list_invitations().await.unwrap(); - assert!(invitations.is_empty()); - } - - #[tokio::test] - async fn test_create_refresh_token() { - let conn = DbConn::new_in_memory().await.unwrap(); - - conn.create_refresh_token(1, "test").await.unwrap(); - - let token = conn.get_refresh_token("test").await.unwrap().unwrap(); - - assert_eq!(token.user_id, 1); - assert_eq!(token.token, "test"); - assert!(token.expires_at > Utc::now().add(chrono::Duration::days(6))); - assert!(token.expires_at < Utc::now().add(chrono::Duration::days(7))); - } - - #[tokio::test] - async fn test_replace_refresh_token() { - let conn = DbConn::new_in_memory().await.unwrap(); - - conn.create_refresh_token(1, "test").await.unwrap(); - conn.replace_refresh_token("test", "test2").await.unwrap(); - - let token = conn.get_refresh_token("test").await.unwrap(); - assert!(token.is_none()); - - let token = conn.get_refresh_token("test2").await.unwrap().unwrap(); - assert_eq!(token.user_id, 1); - assert_eq!(token.token, "test2"); - } -} diff --git a/ee/tabby-webserver/src/service/db/invitations.rs b/ee/tabby-webserver/src/service/db/invitations.rs new file mode 100644 index 0000000..c3b7a5f --- /dev/null +++ b/ee/tabby-webserver/src/service/db/invitations.rs @@ -0,0 +1,111 @@ +use anyhow::{anyhow, Result}; +use rusqlite::{params, OptionalExtension, Row}; +use uuid::Uuid; + +use super::DbConn; +use crate::schema::auth::Invitation; + +impl Invitation { + fn from_row(row: &Row<'_>) -> std::result::Result { + Ok(Self { + id: row.get(0)?, + email: row.get(1)?, + code: row.get(2)?, + created_at: row.get(3)?, + }) + } +} + +/// db read/write operations for `invitations` table +impl DbConn { + pub async fn list_invitations(&self) -> Result> { + let invitations = self + .conn + .call(move |c| { + let mut stmt = + c.prepare(r#"SELECT id, email, code, created_at FROM invitations"#)?; + let iter = stmt.query_map([], Invitation::from_row)?; + Ok(iter.filter_map(|x| x.ok()).collect::>()) + }) + .await?; + + Ok(invitations) + } + + pub async fn get_invitation_by_code(&self, code: &str) -> Result> { + let code = code.to_owned(); + let token = self + .conn + .call(|conn| { + conn.query_row( + r#"SELECT id, email, code, created_at FROM invitations WHERE code = ?"#, + [code], + Invitation::from_row, + ) + .optional() + }) + .await?; + + Ok(token) + } + + pub async fn create_invitation(&self, email: String) -> Result { + let code = Uuid::new_v4().to_string(); + let res = self + .conn + .call(move |c| { + let mut stmt = + c.prepare(r#"INSERT INTO invitations (email, code) VALUES (?, ?)"#)?; + let rowid = stmt.insert((email, code))?; + Ok(rowid) + }) + .await?; + if res != 1 { + return Err(anyhow!("failed to create invitation")); + } + + Ok(res as i32) + } + + pub async fn delete_invitation(&self, id: i32) -> Result { + let res = self + .conn + .call(move |c| c.execute(r#"DELETE FROM invitations WHERE id = ?"#, params![id])) + .await?; + if res != 1 { + return Err(anyhow!("failed to delete invitation")); + } + + Ok(id) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_invitations() { + let conn = DbConn::new_in_memory().await.unwrap(); + + let email = "hello@example.com".to_owned(); + conn.create_invitation(email).await.unwrap(); + + let invitations = conn.list_invitations().await.unwrap(); + assert_eq!(1, invitations.len()); + + assert!(Uuid::parse_str(&invitations[0].code).is_ok()); + let invitation = conn + .get_invitation_by_code(&invitations[0].code) + .await + .ok() + .flatten() + .unwrap(); + assert_eq!(invitation.id, invitations[0].id); + + conn.delete_invitation(invitations[0].id).await.unwrap(); + + let invitations = conn.list_invitations().await.unwrap(); + assert!(invitations.is_empty()); + } +} diff --git a/ee/tabby-webserver/src/service/db/mod.rs b/ee/tabby-webserver/src/service/db/mod.rs new file mode 100644 index 0000000..1ac9453 --- /dev/null +++ b/ee/tabby-webserver/src/service/db/mod.rs @@ -0,0 +1,219 @@ +mod invitations; +mod refresh_tokens; +mod users; + +use std::{path::PathBuf, sync::Arc}; + +use anyhow::Result; +use lazy_static::lazy_static; +use rusqlite::params; +use rusqlite_migration::{AsyncMigrations, M}; +use tabby_common::path::tabby_root; +use tokio_rusqlite::Connection; + +use crate::service::cron::run_offline_job; + +lazy_static! { + static ref MIGRATIONS: AsyncMigrations = AsyncMigrations::new(vec![ + M::up( + r#" + CREATE TABLE registration_token ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + updated_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_token` UNIQUE (`token`) + ); + "# + ) + .down("DROP TABLE registration_token"), + M::up( + r#" + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email VARCHAR(150) NOT NULL COLLATE NOCASE, + password_encrypted VARCHAR(128) NOT NULL, + is_admin BOOLEAN NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + updated_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_email` UNIQUE (`email`) + ); + "# + ) + .down("DROP TABLE users"), + M::up( + r#" + CREATE TABLE invitations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email VARCHAR(150) NOT NULL COLLATE NOCASE, + code VARCHAR(36) NOT NULL, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_email` UNIQUE (`email`) + CONSTRAINT `idx_code` UNIQUE (`code`) + ); + "# + ) + .down("DROP TABLE invitations"), + M::up( + r#" + CREATE TABLE refresh_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + token VARCHAR(255) NOT NULL COLLATE NOCASE, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT (DATETIME('now')), + CONSTRAINT `idx_token` UNIQUE (`token`) + ); + "# + ) + .down("DROP TABLE refresh_tokens"), + ]); +} + +async fn db_path() -> Result { + let db_dir = tabby_root().join("ee"); + tokio::fs::create_dir_all(db_dir.clone()).await?; + Ok(db_dir.join("db.sqlite")) +} + +#[derive(Clone)] +pub struct DbConn { + conn: Arc, +} + +impl DbConn { + #[cfg(test)] + pub async fn new_in_memory() -> Result { + let conn = Connection::open_in_memory().await?; + DbConn::init_db(conn).await + } + + pub async fn new() -> Result { + let db_path = db_path().await?; + let conn = Connection::open(db_path).await?; + Self::init_db(conn).await + } + + /// Initialize database, create tables and insert first token if not exist + async fn init_db(mut conn: Connection) -> Result { + MIGRATIONS.to_latest(&mut conn).await?; + + let token = uuid::Uuid::new_v4().to_string(); + conn.call(move |c| { + c.execute( + r#"INSERT OR IGNORE INTO registration_token (id, token) VALUES (1, ?)"#, + params![token], + ) + }) + .await?; + + let res = Self { + conn: Arc::new(conn), + }; + run_offline_job(res.clone()); + + Ok(res) + } +} + +/// db read/write operations for `registration_token` table +impl DbConn { + /// Query token from database. + /// Since token is global unique for each tabby server, by right there's only one row in the table. + pub async fn read_registration_token(&self) -> Result { + let token = self + .conn + .call(|conn| { + conn.query_row( + r#"SELECT token FROM registration_token WHERE id = 1"#, + [], + |row| row.get(0), + ) + }) + .await?; + + Ok(token) + } + + /// Update token in database. + pub async fn reset_registration_token(&self) -> Result { + let token = uuid::Uuid::new_v4().to_string(); + let result = token.clone(); + let updated_at = chrono::Utc::now().timestamp() as u32; + + let res = self + .conn + .call(move |conn| { + conn.execute( + r#"UPDATE registration_token SET token = ?, updated_at = ? WHERE id = 1"#, + params![token, updated_at], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to update token")); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::schema::auth::AuthenticationService; + + async fn create_user(conn: &DbConn) -> i32 { + let email: &str = "test@example.com"; + let password: &str = "123456789"; + conn.create_user(email.to_string(), password.to_string(), true) + .await + .unwrap() + } + + #[tokio::test] + async fn migrations_test() { + assert!(MIGRATIONS.validate().await.is_ok()); + } + + #[tokio::test] + async fn test_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + let token = conn.read_registration_token().await.unwrap(); + assert_eq!(token.len(), 36); + } + + #[tokio::test] + async fn test_update_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + + let old_token = conn.read_registration_token().await.unwrap(); + conn.reset_registration_token().await.unwrap(); + let new_token = conn.read_registration_token().await.unwrap(); + assert_eq!(new_token.len(), 36); + assert_ne!(old_token, new_token); + } + + #[tokio::test] + async fn test_is_admin_initialized() { + let conn = DbConn::new_in_memory().await.unwrap(); + + assert!(!conn.is_admin_initialized().await.unwrap()); + create_user(&conn).await; + assert!(conn.is_admin_initialized().await.unwrap()); + } +} + +#[cfg(test)] +mod testutils { + use super::*; + + pub(crate) async fn create_user(conn: &DbConn) -> i32 { + let email: &str = "test@example.com"; + let password: &str = "123456789"; + conn.create_user(email.to_string(), password.to_string(), true) + .await + .unwrap() + } +} diff --git a/ee/tabby-webserver/src/service/db/refresh_tokens.rs b/ee/tabby-webserver/src/service/db/refresh_tokens.rs new file mode 100644 index 0000000..63dd452 --- /dev/null +++ b/ee/tabby-webserver/src/service/db/refresh_tokens.rs @@ -0,0 +1,145 @@ +use anyhow::Result; +use chrono::{DateTime, Utc}; +use rusqlite::{params, OptionalExtension, Row}; + +use super::DbConn; + +#[allow(unused)] +pub struct RefreshToken { + id: u32, + created_at: DateTime, + + pub user_id: i32, + pub token: String, + pub expires_at: DateTime, +} + +impl RefreshToken { + fn select(clause: &str) -> String { + r#"SELECT id, user_id, token, expires_at, created_at FROM refresh_tokens WHERE "#.to_owned() + + clause + } + + fn from_row(row: &Row<'_>) -> std::result::Result { + Ok(RefreshToken { + id: row.get(0)?, + user_id: row.get(1)?, + token: row.get(2)?, + expires_at: row.get(3)?, + created_at: row.get(4)?, + }) + } + + pub fn is_expired(&self) -> bool { + let now = chrono::Utc::now(); + self.expires_at < now + } +} + +/// db read/write operations for `refresh_tokens` table +impl DbConn { + pub async fn create_refresh_token(&self, user_id: i32, token: &str) -> Result<()> { + let token = token.to_string(); + let res = self + .conn + .call(move |c| { + c.execute( + r#"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, datetime('now', '+7 days'))"#, + params![user_id, token], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to create refresh token")); + } + + Ok(()) + } + + pub async fn replace_refresh_token(&self, old: &str, new: &str) -> Result<()> { + let old = old.to_string(); + let new = new.to_string(); + let res = self + .conn + .call(move |c| { + c.execute( + r#"UPDATE refresh_tokens SET token = ? WHERE token = ?"#, + params![new, old], + ) + }) + .await?; + if res != 1 { + return Err(anyhow::anyhow!("failed to replace refresh token")); + } + + Ok(()) + } + + pub async fn delete_expired_token(&self, utc_ts: i64) -> Result { + let res = self + .conn + .call(move |c| { + c.execute( + r#"DELETE FROM refresh_tokens WHERE expires_at < ?"#, + params![utc_ts], + ) + }) + .await?; + + Ok(res as i32) + } + + pub async fn get_refresh_token(&self, token: &str) -> Result> { + let token = token.to_string(); + let token = self + .conn + .call(move |c| { + c.query_row( + RefreshToken::select("token = ?").as_str(), + params![token], + RefreshToken::from_row, + ) + .optional() + }) + .await?; + + Ok(token) + } +} + +#[cfg(test)] +mod tests { + + use std::ops::Add; + + use super::*; + + #[tokio::test] + async fn test_create_refresh_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + + conn.create_refresh_token(1, "test").await.unwrap(); + + let token = conn.get_refresh_token("test").await.unwrap().unwrap(); + + assert_eq!(token.user_id, 1); + assert_eq!(token.token, "test"); + assert!(token.expires_at > Utc::now().add(chrono::Duration::days(6))); + assert!(token.expires_at < Utc::now().add(chrono::Duration::days(7))); + } + + #[tokio::test] + async fn test_replace_refresh_token() { + let conn = DbConn::new_in_memory().await.unwrap(); + + conn.create_refresh_token(1, "test").await.unwrap(); + conn.replace_refresh_token("test", "test2").await.unwrap(); + + let token = conn.get_refresh_token("test").await.unwrap(); + assert!(token.is_none()); + + let token = conn.get_refresh_token("test2").await.unwrap().unwrap(); + assert_eq!(token.user_id, 1); + assert_eq!(token.token, "test2"); + } +} diff --git a/ee/tabby-webserver/src/service/db/users.rs b/ee/tabby-webserver/src/service/db/users.rs new file mode 100644 index 0000000..7403afd --- /dev/null +++ b/ee/tabby-webserver/src/service/db/users.rs @@ -0,0 +1,126 @@ +// db read/write operations for `users` table + +use anyhow::Result; +use chrono::{DateTime, Utc}; +use rusqlite::{params, OptionalExtension, Row}; + +use super::DbConn; + +#[allow(unused)] +pub struct User { + created_at: DateTime, + updated_at: DateTime, + + pub id: i32, + pub email: String, + pub password_encrypted: String, + pub is_admin: bool, +} + +impl User { + fn select(clause: &str) -> String { + r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "# + .to_owned() + + clause + } + + fn from_row(row: &Row<'_>) -> std::result::Result { + Ok(User { + id: row.get(0)?, + email: row.get(1)?, + password_encrypted: row.get(2)?, + is_admin: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + }) + } +} + +impl DbConn { + pub async fn create_user( + &self, + email: String, + password_encrypted: String, + is_admin: bool, + ) -> Result { + let res = self + .conn + .call(move |c| { + let mut stmt = c.prepare( + r#"INSERT INTO users (email, password_encrypted, is_admin) VALUES (?, ?, ?)"#, + )?; + let id = stmt.insert((email, password_encrypted, is_admin))?; + Ok(id) + }) + .await?; + + Ok(res as i32) + } + + pub async fn get_user(&self, id: i32) -> Result> { + let user = self + .conn + .call(move |c| { + c.query_row(User::select("id = ?").as_str(), params![id], User::from_row) + .optional() + }) + .await?; + + Ok(user) + } + + pub async fn get_user_by_email(&self, email: &str) -> Result> { + let email = email.to_owned(); + let user = self + .conn + .call(move |c| { + c.query_row( + User::select("email = ?").as_str(), + params![email], + User::from_row, + ) + .optional() + }) + .await?; + + Ok(user) + } + + pub async fn list_admin_users(&self) -> Result> { + let users = self + .conn + .call(move |c| { + let mut stmt = c.prepare(&User::select("is_admin"))?; + let user_iter = stmt.query_map([], User::from_row)?; + Ok(user_iter.filter_map(|x| x.ok()).collect::>()) + }) + .await?; + + Ok(users) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::service::db::testutils::create_user; + + #[tokio::test] + async fn test_create_user() { + let conn = DbConn::new_in_memory().await.unwrap(); + + let id = create_user(&conn).await; + let user = conn.get_user(id).await.unwrap().unwrap(); + assert_eq!(user.id, 1); + } + + #[tokio::test] + async fn test_get_user_by_email() { + let conn = DbConn::new_in_memory().await.unwrap(); + + let email = "hello@example.com"; + let user = conn.get_user_by_email(email).await.unwrap(); + + assert!(user.is_none()); + } +}