From 5c52a71f77ed65a3b266577017c5bb832742362c Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 1 Dec 2023 22:16:59 +0800 Subject: [PATCH] refactor: extract ServiceLocator interface (#933) * refactor: cleanup the dependency chain - ServerContext should be the only thing being public in server mod * refactor ServiceLocator * refactor: extract worker.rs * refactor: move db as private repo of server * rename server -> service --- crates/juniper-axum/src/extract.rs | 2 +- crates/juniper-axum/src/lib.rs | 2 +- ee/tabby-webserver/src/api.rs | 34 +---- ee/tabby-webserver/src/lib.rs | 41 +++--- ee/tabby-webserver/src/schema.rs | 100 ------------- ee/tabby-webserver/src/schema/auth.rs | 87 ++++++++---- ee/tabby-webserver/src/schema/mod.rs | 133 ++++++++++++++++++ ee/tabby-webserver/src/schema/worker.rs | 51 +++++++ .../src/{server => service}/auth.rs | 99 ++++--------- ee/tabby-webserver/src/{ => service}/db.rs | 0 .../src/{server.rs => service/mod.rs} | 77 ++++++---- .../src/{server => service}/proxy.rs | 0 .../src/{server => service}/worker.rs | 4 +- 13 files changed, 347 insertions(+), 283 deletions(-) delete mode 100644 ee/tabby-webserver/src/schema.rs create mode 100644 ee/tabby-webserver/src/schema/mod.rs create mode 100644 ee/tabby-webserver/src/schema/worker.rs rename ee/tabby-webserver/src/{server => service}/auth.rs (69%) rename ee/tabby-webserver/src/{ => service}/db.rs (100%) rename ee/tabby-webserver/src/{server.rs => service/mod.rs} (69%) rename ee/tabby-webserver/src/{server => service}/proxy.rs (100%) rename ee/tabby-webserver/src/{server => service}/worker.rs (96%) diff --git a/crates/juniper-axum/src/extract.rs b/crates/juniper-axum/src/extract.rs index d32e852..63fa43d 100644 --- a/crates/juniper-axum/src/extract.rs +++ b/crates/juniper-axum/src/extract.rs @@ -50,7 +50,7 @@ where let split = authorization.split_once(' '); match split { // Found proper bearer - Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))), + Some(("Bearer", contents)) => Ok(Self(Some(contents.to_owned()))), _ => Ok(Self(None)), } } diff --git a/crates/juniper-axum/src/lib.rs b/crates/juniper-axum/src/lib.rs index f64ac2e..092f46f 100644 --- a/crates/juniper-axum/src/lib.rs +++ b/crates/juniper-axum/src/lib.rs @@ -1,7 +1,7 @@ pub mod extract; pub mod response; -use std::{future}; +use std::future; use axum::{ extract::{Extension, State}, diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs index fa5998f..f201a1c 100644 --- a/ee/tabby-webserver/src/api.rs +++ b/ee/tabby-webserver/src/api.rs @@ -1,45 +1,13 @@ use async_trait::async_trait; -use juniper::{GraphQLEnum, GraphQLObject}; -use serde::{Deserialize, Serialize}; use tabby_common::api::{ code::{CodeSearch, CodeSearchError, SearchResponse}, event::RawEventLogger, }; -use thiserror::Error; use tokio_tungstenite::connect_async; +pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind}; use crate::websocket::WebSocketTransport; -#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)] -pub enum WorkerKind { - Completion, - Chat, -} - -#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)] -pub struct Worker { - pub kind: WorkerKind, - pub name: String, - pub addr: String, - pub device: String, - pub arch: String, - pub cpu_info: String, - pub cpu_count: i32, - pub cuda_devices: Vec, -} - -#[derive(Serialize, Deserialize, Error, Debug)] -pub enum RegisterWorkerError { - #[error("Invalid token")] - InvalidToken(String), - - #[error("Feature requires enterprise license")] - RequiresEnterpriseLicense, - - #[error("Each hub client should only calls register_worker once")] - RegisterWorkerOnce, -} - #[tarpc::service] pub trait Hub { async fn register_worker( diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 07bb961..2331bd5 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -1,6 +1,7 @@ pub mod api; mod schema; +use api::Hub; pub use schema::create_schema; use tabby_common::api::{ code::{CodeSearch, SearchResponse}, @@ -10,15 +11,13 @@ use tokio::sync::Mutex; use tracing::{error, warn}; use websocket::WebSocketTransport; -mod db; mod repositories; -mod server; +mod service; mod ui; mod websocket; use std::{net::SocketAddr, sync::Arc}; -use api::{Hub, RegisterWorkerError, Worker, WorkerKind}; use axum::{ extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, http::Request, @@ -28,8 +27,11 @@ use axum::{ }; use hyper::Body; use juniper_axum::{graphiql, graphql, playground}; -use schema::Schema; -use server::ServerContext; +use schema::{ + worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService}, + Schema, ServiceLocator, +}; +use service::create_service_locator; use tarpc::server::{BaseChannel, Channel}; pub async fn attach_webserver( @@ -38,15 +40,14 @@ pub async fn attach_webserver( logger: Arc, code: Arc, ) -> (Router, Router) { - let conn = db::DbConn::new().await.unwrap(); - let ctx = Arc::new(ServerContext::new(conn, logger, code)); + let ctx = create_service_locator(logger, code).await; let schema = Arc::new(create_schema()); let api = api .layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer)) .route( "/graphql", - routing::post(graphql::, Arc>).with_state(ctx.clone()), + routing::post(graphql::, Arc>).with_state(ctx.clone()), ) .route("/graphql", routing::get(playground("/graphql", None))) .layer(Extension(schema)) @@ -61,22 +62,22 @@ pub async fn attach_webserver( } async fn distributed_tabby_layer( - State(ws): State>, + State(ws): State>, request: Request, next: Next, ) -> axum::response::Response { - ws.dispatch_request(request, next).await + ws.worker().dispatch_request(request, next).await } async fn ws_handler( ws: WebSocketUpgrade, - State(state): State>, + State(state): State>, ConnectInfo(addr): ConnectInfo, ) -> impl IntoResponse { ws.on_upgrade(move |socket| handle_socket(state, socket, addr)) } -async fn handle_socket(state: Arc, socket: WebSocket, addr: SocketAddr) { +async fn handle_socket(state: Arc, socket: WebSocket, addr: SocketAddr) { let transport = WebSocketTransport::from(socket); let server = BaseChannel::with_defaults(transport); let imp = Arc::new(HubImpl::new(state.clone(), addr)); @@ -84,14 +85,14 @@ async fn handle_socket(state: Arc, socket: WebSocket, addr: Socke } pub struct HubImpl { - ctx: Arc, + ctx: Arc, conn: SocketAddr, worker_addr: Arc>, } impl HubImpl { - pub fn new(ctx: Arc, conn: SocketAddr) -> Self { + pub fn new(ctx: Arc, conn: SocketAddr) -> Self { Self { ctx, conn, @@ -108,7 +109,7 @@ impl Drop for HubImpl { tokio::spawn(async move { let worker_addr = worker_addr.lock().await; if !worker_addr.is_empty() { - ctx.unregister_worker(worker_addr.as_str()).await; + ctx.worker().unregister_worker(worker_addr.as_str()).await; } }); } @@ -134,7 +135,7 @@ impl Hub for Arc { "Empty worker token".to_string(), )); } - let server_token = match self.ctx.read_registration_token().await { + let server_token = match self.ctx.worker().read_registration_token().await { Ok(t) => t, Err(err) => { error!("fetch server token: {}", err.to_string()); @@ -167,11 +168,11 @@ impl Hub for Arc { cpu_count, cuda_devices, }; - self.ctx.register_worker(worker).await + self.ctx.worker().register_worker(worker).await } async fn log_event(self, _context: tarpc::context::Context, content: String) { - self.ctx.logger.log(content) + self.ctx.logger().log(content) } async fn search( @@ -181,7 +182,7 @@ impl Hub for Arc { limit: usize, offset: usize, ) -> SearchResponse { - match self.ctx.code.search(&q, limit, offset).await { + match self.ctx.code().search(&q, limit, offset).await { Ok(serp) => serp, Err(err) => { warn!("Failed to search: {}", err); @@ -200,7 +201,7 @@ impl Hub for Arc { ) -> SearchResponse { match self .ctx - .code + .code() .search_in_language(&language, &tokens, limit, offset) .await { diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs deleted file mode 100644 index f02ba98..0000000 --- a/ee/tabby-webserver/src/schema.rs +++ /dev/null @@ -1,100 +0,0 @@ -pub mod auth; - -use std::sync::Arc; - - -use juniper::{ - graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode, -}; -use juniper_axum::FromAuth; - -use crate::{ - api::Worker, - schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse}, - server::{ - auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput}, - ServerContext, - }, -}; - -pub struct Context { - claims: Option, - server: Arc, -} - -impl FromAuth> for Context { - fn build(server: Arc, bearer: Option) -> Self { - let claims = bearer.and_then(|token| validate_jwt(&token).ok()); - Self { claims, server } - } -} - -// To make our context usable by Juniper, we have to implement a marker trait. -impl juniper::Context for Context {} - -#[derive(Default)] -pub struct Query; - -#[graphql_object(context = Context)] -impl Query { - async fn workers(ctx: &Context) -> Vec { - ctx.server.list_workers().await - } - - async fn registration_token(ctx: &Context) -> FieldResult { - let token = ctx.server.read_registration_token().await?; - Ok(token) - } -} - -#[derive(Default)] -pub struct Mutation; - -#[graphql_object(context = Context)] -impl Mutation { - async fn reset_registration_token(ctx: &Context) -> FieldResult { - if let Some(claims) = &ctx.claims { - if claims.user_info().is_admin() { - let reg_token = ctx.server.reset_registration_token().await?; - return Ok(reg_token); - } - } - Err(FieldError::new( - "Only admin is able to reset registration token", - graphql_value!("Unauthorized"), - )) - } - - async fn register( - ctx: &Context, - email: String, - password1: String, - password2: String, - ) -> FieldResult { - let input = RegisterInput { - email, - password1, - password2, - }; - ctx.server.auth().register(input).await - } - - async fn token_auth( - ctx: &Context, - email: String, - password: String, - ) -> FieldResult { - let input = TokenAuthInput { email, password }; - ctx.server.auth().token_auth(input).await - } - - async fn verify_token(ctx: &Context, token: String) -> FieldResult { - ctx.server.auth().verify_token(token).await - } -} - -pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; - -pub fn create_schema() -> Schema { - Schema::new(Query, Mutation, EmptySubscription::new()) -} diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 9e5aa61..48e79ea 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -1,38 +1,35 @@ use std::fmt::Debug; +use async_trait::async_trait; use jsonwebtoken as jwt; -use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value}; +use juniper::{FieldResult, GraphQLObject}; +use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use validator::ValidationError; -use crate::server::auth::JWT_DEFAULT_EXP; - -#[derive(Debug)] -pub struct ValidationErrors { - pub errors: Vec, +lazy_static! { + static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret( + jwt_token_secret().as_bytes() + ); + static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret( + jwt_token_secret().as_bytes() + ); + static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes } -impl IntoFieldError for ValidationErrors { - fn into_field_error(self) -> FieldError { - let errors = self - .errors - .into_iter() - .map(|err| { - let mut obj = Object::with_capacity(2); - obj.add_field("path", Value::scalar(err.code.to_string())); - obj.add_field( - "message", - Value::scalar(err.message.unwrap_or_default().to_string()), - ); - obj.into() - }) - .collect::>(); - let mut ext = Object::with_capacity(2); - ext.add_field("code", Value::scalar("validation-error".to_string())); - ext.add_field("errors", Value::list(errors)); +pub fn generate_jwt(claims: Claims) -> jwt::errors::Result { + let header = jwt::Header::default(); + let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?; + Ok(token) +} - FieldError::new("Invalid input parameters", ext.into()) - } +pub fn validate_jwt(token: &str) -> jwt::errors::Result { + let validation = jwt::Validation::default(); + let data = jwt::decode::(token, &JWT_DECODING_KEY, &validation)?; + Ok(data.claims) +} + +fn jwt_token_secret() -> String { + std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string()) } #[derive(Debug, GraphQLObject)] @@ -127,3 +124,39 @@ impl Claims { &self.user } } + +#[async_trait] +pub trait AuthenticationService: Send + Sync { + async fn register( + &self, + email: String, + password1: String, + password2: String, + ) -> FieldResult; + async fn token_auth(&self, email: String, password: String) -> FieldResult; + async fn refresh_token(&self, refresh_token: String) -> FieldResult; + async fn verify_token(&self, access_token: String) -> FieldResult; +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_generate_jwt() { + let claims = Claims::new(UserInfo::new("test".to_string(), false)); + let token = generate_jwt(claims).unwrap(); + + assert!(!token.is_empty()) + } + + #[test] + fn test_validate_jwt() { + let claims = Claims::new(UserInfo::new("test".to_string(), false)); + let token = generate_jwt(claims).unwrap(); + let claims = validate_jwt(&token).unwrap(); + assert_eq!( + claims.user_info(), + &UserInfo::new("test".to_string(), false) + ); + } +} diff --git a/ee/tabby-webserver/src/schema/mod.rs b/ee/tabby-webserver/src/schema/mod.rs new file mode 100644 index 0000000..0df47fa --- /dev/null +++ b/ee/tabby-webserver/src/schema/mod.rs @@ -0,0 +1,133 @@ +pub mod auth; +pub mod worker; + +use std::sync::Arc; + +use auth::AuthenticationService; +use juniper::{ + graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError, + Object, RootNode, ScalarValue, Value, +}; +use juniper_axum::FromAuth; +use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; +use validator::ValidationError; + +use self::{auth::validate_jwt, worker::WorkerService}; +use crate::schema::{ + auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse}, + worker::Worker, +}; + +pub trait ServiceLocator: Send + Sync { + fn auth(&self) -> &dyn AuthenticationService; + fn worker(&self) -> &dyn WorkerService; + fn code(&self) -> &dyn CodeSearch; + fn logger(&self) -> &dyn RawEventLogger; +} + +pub struct Context { + claims: Option, + server: Arc, +} + +impl FromAuth> for Context { + fn build(server: Arc, bearer: Option) -> Self { + let claims = bearer.and_then(|token| validate_jwt(&token).ok()); + Self { claims, server } + } +} + +// To make our context usable by Juniper, we have to implement a marker trait. +impl juniper::Context for Context {} + +#[derive(Default)] +pub struct Query; + +#[graphql_object(context = Context)] +impl Query { + async fn workers(ctx: &Context) -> Vec { + ctx.server.worker().list_workers().await + } + + async fn registration_token(ctx: &Context) -> FieldResult { + let token = ctx.server.worker().read_registration_token().await?; + Ok(token) + } +} + +#[derive(Default)] +pub struct Mutation; + +#[graphql_object(context = Context)] +impl Mutation { + async fn reset_registration_token(ctx: &Context) -> FieldResult { + if let Some(claims) = &ctx.claims { + if claims.user_info().is_admin() { + let reg_token = ctx.server.worker().reset_registration_token().await?; + return Ok(reg_token); + } + } + Err(FieldError::new( + "Only admin is able to reset registration token", + graphql_value!("Unauthorized"), + )) + } + + async fn register( + ctx: &Context, + email: String, + password1: String, + password2: String, + ) -> FieldResult { + ctx.server + .auth() + .register(email, password1, password2) + .await + } + + async fn token_auth( + ctx: &Context, + email: String, + password: String, + ) -> FieldResult { + ctx.server.auth().token_auth(email, password).await + } + + async fn verify_token(ctx: &Context, token: String) -> FieldResult { + ctx.server.auth().verify_token(token).await + } +} + +#[derive(Debug)] +pub struct ValidationErrors { + pub errors: Vec, +} + +impl IntoFieldError for ValidationErrors { + fn into_field_error(self) -> FieldError { + let errors = self + .errors + .into_iter() + .map(|err| { + let mut obj = Object::with_capacity(2); + obj.add_field("path", Value::scalar(err.code.to_string())); + obj.add_field( + "message", + Value::scalar(err.message.unwrap_or_default().to_string()), + ); + obj.into() + }) + .collect::>(); + let mut ext = Object::with_capacity(2); + ext.add_field("code", Value::scalar("validation-error".to_string())); + ext.add_field("errors", Value::list(errors)); + + FieldError::new("Invalid input parameters", ext.into()) + } +} + +pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; + +pub fn create_schema() -> Schema { + Schema::new(Query, Mutation, EmptySubscription::new()) +} diff --git a/ee/tabby-webserver/src/schema/worker.rs b/ee/tabby-webserver/src/schema/worker.rs new file mode 100644 index 0000000..78223ee --- /dev/null +++ b/ee/tabby-webserver/src/schema/worker.rs @@ -0,0 +1,51 @@ +use anyhow::Result; +use async_trait::async_trait; +use axum::middleware::Next; +use hyper::{Body, Request}; +use juniper::{GraphQLEnum, GraphQLObject}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)] +pub enum WorkerKind { + Completion, + Chat, +} + +#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)] +pub struct Worker { + pub kind: WorkerKind, + pub name: String, + pub addr: String, + pub device: String, + pub arch: String, + pub cpu_info: String, + pub cpu_count: i32, + pub cuda_devices: Vec, +} + +#[derive(Serialize, Deserialize, Error, Debug)] +pub enum RegisterWorkerError { + #[error("Invalid token")] + InvalidToken(String), + + #[error("Feature requires enterprise license")] + RequiresEnterpriseLicense, + + #[error("Each hub client should only calls register_worker once")] + RegisterWorkerOnce, +} + +#[async_trait] +pub trait WorkerService: Send + Sync { + async fn read_registration_token(&self) -> Result; + async fn reset_registration_token(&self) -> Result; + async fn list_workers(&self) -> Vec; + async fn register_worker(&self, worker: Worker) -> Result; + async fn unregister_worker(&self, worker_addr: &str); + async fn dispatch_request( + &self, + request: Request, + next: Next, + ) -> axum::response::Response; +} diff --git a/ee/tabby-webserver/src/server/auth.rs b/ee/tabby-webserver/src/service/auth.rs similarity index 69% rename from ee/tabby-webserver/src/server/auth.rs rename to ee/tabby-webserver/src/service/auth.rs index df54e88..447d184 100644 --- a/ee/tabby-webserver/src/server/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -1,48 +1,35 @@ -use std::env; - use argon2::{ password_hash, password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHasher, PasswordVerifier, }; use async_trait::async_trait; -use jsonwebtoken as jwt; use juniper::{FieldResult, IntoFieldError}; -use lazy_static::lazy_static; use validator::Validate; -use crate::{ - db::DbConn, - schema::auth::{ - Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, - ValidationErrors, VerifyTokenResponse, +use super::db::DbConn; +use crate::schema::{ + auth::{ + generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse, + RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse, }, + ValidationErrors, }; -lazy_static! { - static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret( - jwt_token_secret().as_bytes() - ); - static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret( - jwt_token_secret().as_bytes() - ); - pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes -} - /// Input parameters for register mutation /// `validate` attribute is used to validate the input parameters /// - `code` argument specifies which parameter causes the failure /// - `message` argument provides client friendly error message /// #[derive(Validate)] -pub struct RegisterInput { +struct RegisterInput { #[validate(email(code = "email", message = "Email is invalid"))] #[validate(length( max = 128, code = "email", message = "Email must be at most 128 characters" ))] - pub email: String, + email: String, #[validate(length( min = 8, code = "password1", @@ -58,7 +45,7 @@ pub struct RegisterInput { message = "Passwords do not match", other = "password2" ))] - pub password1: String, + password1: String, #[validate(length( min = 8, code = "password2", @@ -69,7 +56,7 @@ pub struct RegisterInput { code = "password2", message = "Password must be at most 20 characters" ))] - pub password2: String, + password2: String, } impl std::fmt::Debug for RegisterInput { @@ -85,14 +72,14 @@ impl std::fmt::Debug for RegisterInput { /// Input parameters for token_auth mutation /// See `RegisterInput` for `validate` attribute usage #[derive(Validate)] -pub struct TokenAuthInput { +struct TokenAuthInput { #[validate(email(code = "email", message = "Email is invalid"))] #[validate(length( max = 128, code = "email", message = "Email must be at most 128 characters" ))] - pub email: String, + email: String, #[validate(length( min = 8, code = "password", @@ -103,7 +90,7 @@ pub struct TokenAuthInput { code = "password", message = "Password must be at most 20 characters" ))] - pub password: String, + password: String, } impl std::fmt::Debug for TokenAuthInput { @@ -115,17 +102,19 @@ impl std::fmt::Debug for TokenAuthInput { } } -#[async_trait] -pub trait AuthenticationService { - async fn register(&self, input: RegisterInput) -> FieldResult; - async fn token_auth(&self, input: TokenAuthInput) -> FieldResult; - async fn refresh_token(&self, refresh_token: String) -> FieldResult; - async fn verify_token(&self, access_token: String) -> FieldResult; -} - #[async_trait] impl AuthenticationService for DbConn { - async fn register(&self, input: RegisterInput) -> FieldResult { + async fn register( + &self, + email: String, + password1: String, + password2: String, + ) -> FieldResult { + let input = RegisterInput { + email, + password1, + password2, + }; input.validate().map_err(|err| { let errors = err .field_errors() @@ -138,7 +127,7 @@ impl AuthenticationService for DbConn { })?; // check if email exists - if let Some(_) = self.get_user_by_email(&input.email).await? { + if self.get_user_by_email(&input.email).await?.is_some() { return Err("Email already exists".into()); } @@ -157,7 +146,8 @@ impl AuthenticationService for DbConn { Ok(resp) } - async fn token_auth(&self, input: TokenAuthInput) -> FieldResult { + async fn token_auth(&self, email: String, password: String) -> FieldResult { + let input = TokenAuthInput { email, password }; input.validate().map_err(|err| { let errors = err .field_errors() @@ -217,22 +207,6 @@ fn password_verify(raw: &str, hash: &str) -> bool { } } -fn generate_jwt(claims: Claims) -> jwt::errors::Result { - let header = jwt::Header::default(); - let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?; - Ok(token) -} - -pub fn validate_jwt(token: &str) -> jwt::errors::Result { - let validation = jwt::Validation::default(); - let data = jwt::decode::(token, &JWT_DECODING_KEY, &validation)?; - Ok(data.claims) -} - -fn jwt_token_secret() -> String { - env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string()) -} - #[cfg(test)] mod tests { use super::*; @@ -254,23 +228,4 @@ mod tests { assert!(password_verify(raw, &hash)); assert!(!password_verify(raw, "invalid hash")); } - - #[test] - fn test_generate_jwt() { - let claims = Claims::new(UserInfo::new("test".to_string(), false)); - let token = generate_jwt(claims).unwrap(); - - assert!(!token.is_empty()) - } - - #[test] - fn test_validate_jwt() { - let claims = Claims::new(UserInfo::new("test".to_string(), false)); - let token = generate_jwt(claims).unwrap(); - let claims = validate_jwt(&token).unwrap(); - assert_eq!( - claims.user_info(), - &UserInfo::new("test".to_string(), false) - ); - } } diff --git a/ee/tabby-webserver/src/db.rs b/ee/tabby-webserver/src/service/db.rs similarity index 100% rename from ee/tabby-webserver/src/db.rs rename to ee/tabby-webserver/src/service/db.rs diff --git a/ee/tabby-webserver/src/server.rs b/ee/tabby-webserver/src/service/mod.rs similarity index 69% rename from ee/tabby-webserver/src/server.rs rename to ee/tabby-webserver/src/service/mod.rs index 07ef311..bd47418 100644 --- a/ee/tabby-webserver/src/server.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -1,63 +1,65 @@ -pub mod auth; +mod auth; +mod db; mod proxy; mod worker; use std::{net::SocketAddr, sync::Arc}; use anyhow::Result; +use async_trait::async_trait; use axum::{http::Request, middleware::Next, response::IntoResponse}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tracing::{info, warn}; -use crate::{ - api::{RegisterWorkerError, Worker, WorkerKind}, - db::DbConn, - server::auth::AuthenticationService, +use self::db::DbConn; +use crate::schema::{ + auth::AuthenticationService, + worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService}, + ServiceLocator, }; -pub struct ServerContext { +struct ServerContext { client: Client, completion: worker::WorkerGroup, chat: worker::WorkerGroup, db_conn: DbConn, - pub logger: Arc, - pub code: Arc, + logger: Arc, + code: Arc, } impl ServerContext { - pub fn new( - db_conn: DbConn, - logger: Arc, - code: Arc, - ) -> Self { + pub async fn new(logger: Arc, code: Arc) -> Self { Self { client: Client::default(), completion: worker::WorkerGroup::default(), chat: worker::WorkerGroup::default(), - db_conn, + db_conn: DbConn::new().await.unwrap(), logger, code, } } +} - pub fn auth(&self) -> impl AuthenticationService { - self.db_conn.clone() - } - +#[async_trait] +impl WorkerService for ServerContext { /// Query current token from the database. - pub async fn read_registration_token(&self) -> Result { + async fn read_registration_token(&self) -> Result { self.db_conn.read_registration_token().await } /// Generate new token, and update it in the database. /// Return new token after update is done - pub async fn reset_registration_token(&self) -> Result { + async fn reset_registration_token(&self) -> Result { self.db_conn.reset_registration_token().await } - pub async fn register_worker(&self, worker: Worker) -> Result { + async fn list_workers(&self) -> Vec { + [self.completion.list().await, self.chat.list().await].concat() + } + + async fn register_worker(&self, worker: Worker) -> Result { let worker = match worker.kind { WorkerKind::Completion => self.completion.register(worker).await, WorkerKind::Chat => self.chat.register(worker).await, @@ -74,7 +76,7 @@ impl ServerContext { } } - pub async fn unregister_worker(&self, worker_addr: &str) { + async fn unregister_worker(&self, worker_addr: &str) { let kind = if self.chat.unregister(worker_addr).await { WorkerKind::Chat } else if self.completion.unregister(worker_addr).await { @@ -87,11 +89,7 @@ impl ServerContext { info!("unregistering <{:?}> worker at {}", kind, worker_addr); } - pub async fn list_workers(&self) -> Vec { - [self.completion.list().await, self.chat.list().await].concat() - } - - pub async fn dispatch_request( + async fn dispatch_request( &self, request: Request, next: Next, @@ -129,3 +127,28 @@ impl ServerContext { } } } + +impl ServiceLocator for ServerContext { + fn auth(&self) -> &dyn AuthenticationService { + &self.db_conn + } + + fn worker(&self) -> &dyn WorkerService { + self + } + + fn code(&self) -> &dyn CodeSearch { + &*self.code + } + + fn logger(&self) -> &dyn RawEventLogger { + &*self.logger + } +} + +pub async fn create_service_locator( + logger: Arc, + code: Arc, +) -> Arc { + Arc::new(ServerContext::new(logger, code).await) +} diff --git a/ee/tabby-webserver/src/server/proxy.rs b/ee/tabby-webserver/src/service/proxy.rs similarity index 100% rename from ee/tabby-webserver/src/server/proxy.rs rename to ee/tabby-webserver/src/service/proxy.rs diff --git a/ee/tabby-webserver/src/server/worker.rs b/ee/tabby-webserver/src/service/worker.rs similarity index 96% rename from ee/tabby-webserver/src/server/worker.rs rename to ee/tabby-webserver/src/service/worker.rs index 128da9c..3701129 100644 --- a/ee/tabby-webserver/src/server/worker.rs +++ b/ee/tabby-webserver/src/service/worker.rs @@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use tracing::error; -use crate::api::Worker; +use crate::schema::worker::Worker; #[derive(Default)] pub struct WorkerGroup { @@ -61,7 +61,7 @@ fn random_index(size: usize) -> usize { mod tests { use super::*; - use crate::api::WorkerKind; + use crate::schema::worker::WorkerKind; #[tokio::test] async fn test_worker_group() {