refactor: extract ServiceLocator interface (#933)

* refactor: cleanup the dependency chain - ServerContext should be the only thing being public in server mod

* refactor ServiceLocator

* refactor: extract worker.rs

* refactor: move db as private repo of server

* rename server -> service
add-signin-page
Meng Zhang 2023-12-01 22:16:59 +08:00 committed by GitHub
parent 1a9cbdcc3c
commit 5c52a71f77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 347 additions and 283 deletions

View File

@ -50,7 +50,7 @@ where
let split = authorization.split_once(' '); let split = authorization.split_once(' ');
match split { match split {
// Found proper bearer // Found proper bearer
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))), Some(("Bearer", contents)) => Ok(Self(Some(contents.to_owned()))),
_ => Ok(Self(None)), _ => Ok(Self(None)),
} }
} }

View File

@ -1,7 +1,7 @@
pub mod extract; pub mod extract;
pub mod response; pub mod response;
use std::{future}; use std::future;
use axum::{ use axum::{
extract::{Extension, State}, extract::{Extension, State},

View File

@ -1,45 +1,13 @@
use async_trait::async_trait; use async_trait::async_trait;
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use tabby_common::api::{ use tabby_common::api::{
code::{CodeSearch, CodeSearchError, SearchResponse}, code::{CodeSearch, CodeSearchError, SearchResponse},
event::RawEventLogger, event::RawEventLogger,
}; };
use thiserror::Error;
use tokio_tungstenite::connect_async; use tokio_tungstenite::connect_async;
pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind};
use crate::websocket::WebSocketTransport; use crate::websocket::WebSocketTransport;
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum RegisterWorkerError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
}
#[tarpc::service] #[tarpc::service]
pub trait Hub { pub trait Hub {
async fn register_worker( async fn register_worker(

View File

@ -1,6 +1,7 @@
pub mod api; pub mod api;
mod schema; mod schema;
use api::Hub;
pub use schema::create_schema; pub use schema::create_schema;
use tabby_common::api::{ use tabby_common::api::{
code::{CodeSearch, SearchResponse}, code::{CodeSearch, SearchResponse},
@ -10,15 +11,13 @@ use tokio::sync::Mutex;
use tracing::{error, warn}; use tracing::{error, warn};
use websocket::WebSocketTransport; use websocket::WebSocketTransport;
mod db;
mod repositories; mod repositories;
mod server; mod service;
mod ui; mod ui;
mod websocket; mod websocket;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
use axum::{ use axum::{
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
http::Request, http::Request,
@ -28,8 +27,11 @@ use axum::{
}; };
use hyper::Body; use hyper::Body;
use juniper_axum::{graphiql, graphql, playground}; use juniper_axum::{graphiql, graphql, playground};
use schema::Schema; use schema::{
use server::ServerContext; worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
Schema, ServiceLocator,
};
use service::create_service_locator;
use tarpc::server::{BaseChannel, Channel}; use tarpc::server::{BaseChannel, Channel};
pub async fn attach_webserver( pub async fn attach_webserver(
@ -38,15 +40,14 @@ pub async fn attach_webserver(
logger: Arc<dyn RawEventLogger>, logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>, code: Arc<dyn CodeSearch>,
) -> (Router, Router) { ) -> (Router, Router) {
let conn = db::DbConn::new().await.unwrap(); let ctx = create_service_locator(logger, code).await;
let ctx = Arc::new(ServerContext::new(conn, logger, code));
let schema = Arc::new(create_schema()); let schema = Arc::new(create_schema());
let api = api let api = api
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer)) .layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
.route( .route(
"/graphql", "/graphql",
routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()), routing::post(graphql::<Arc<Schema>, Arc<dyn ServiceLocator>>).with_state(ctx.clone()),
) )
.route("/graphql", routing::get(playground("/graphql", None))) .route("/graphql", routing::get(playground("/graphql", None)))
.layer(Extension(schema)) .layer(Extension(schema))
@ -61,22 +62,22 @@ pub async fn attach_webserver(
} }
async fn distributed_tabby_layer( async fn distributed_tabby_layer(
State(ws): State<Arc<ServerContext>>, State(ws): State<Arc<dyn ServiceLocator>>,
request: Request<Body>, request: Request<Body>,
next: Next<Body>, next: Next<Body>,
) -> axum::response::Response { ) -> axum::response::Response {
ws.dispatch_request(request, next).await ws.worker().dispatch_request(request, next).await
} }
async fn ws_handler( async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
State(state): State<Arc<ServerContext>>, State(state): State<Arc<dyn ServiceLocator>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(state, socket, addr)) ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
} }
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) { async fn handle_socket(state: Arc<dyn ServiceLocator>, socket: WebSocket, addr: SocketAddr) {
let transport = WebSocketTransport::from(socket); let transport = WebSocketTransport::from(socket);
let server = BaseChannel::with_defaults(transport); let server = BaseChannel::with_defaults(transport);
let imp = Arc::new(HubImpl::new(state.clone(), addr)); let imp = Arc::new(HubImpl::new(state.clone(), addr));
@ -84,14 +85,14 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
} }
pub struct HubImpl { pub struct HubImpl {
ctx: Arc<ServerContext>, ctx: Arc<dyn ServiceLocator>,
conn: SocketAddr, conn: SocketAddr,
worker_addr: Arc<Mutex<String>>, worker_addr: Arc<Mutex<String>>,
} }
impl HubImpl { impl HubImpl {
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self { pub fn new(ctx: Arc<dyn ServiceLocator>, conn: SocketAddr) -> Self {
Self { Self {
ctx, ctx,
conn, conn,
@ -108,7 +109,7 @@ impl Drop for HubImpl {
tokio::spawn(async move { tokio::spawn(async move {
let worker_addr = worker_addr.lock().await; let worker_addr = worker_addr.lock().await;
if !worker_addr.is_empty() { if !worker_addr.is_empty() {
ctx.unregister_worker(worker_addr.as_str()).await; ctx.worker().unregister_worker(worker_addr.as_str()).await;
} }
}); });
} }
@ -134,7 +135,7 @@ impl Hub for Arc<HubImpl> {
"Empty worker token".to_string(), "Empty worker token".to_string(),
)); ));
} }
let server_token = match self.ctx.read_registration_token().await { let server_token = match self.ctx.worker().read_registration_token().await {
Ok(t) => t, Ok(t) => t,
Err(err) => { Err(err) => {
error!("fetch server token: {}", err.to_string()); error!("fetch server token: {}", err.to_string());
@ -167,11 +168,11 @@ impl Hub for Arc<HubImpl> {
cpu_count, cpu_count,
cuda_devices, cuda_devices,
}; };
self.ctx.register_worker(worker).await self.ctx.worker().register_worker(worker).await
} }
async fn log_event(self, _context: tarpc::context::Context, content: String) { async fn log_event(self, _context: tarpc::context::Context, content: String) {
self.ctx.logger.log(content) self.ctx.logger().log(content)
} }
async fn search( async fn search(
@ -181,7 +182,7 @@ impl Hub for Arc<HubImpl> {
limit: usize, limit: usize,
offset: usize, offset: usize,
) -> SearchResponse { ) -> SearchResponse {
match self.ctx.code.search(&q, limit, offset).await { match self.ctx.code().search(&q, limit, offset).await {
Ok(serp) => serp, Ok(serp) => serp,
Err(err) => { Err(err) => {
warn!("Failed to search: {}", err); warn!("Failed to search: {}", err);
@ -200,7 +201,7 @@ impl Hub for Arc<HubImpl> {
) -> SearchResponse { ) -> SearchResponse {
match self match self
.ctx .ctx
.code .code()
.search_in_language(&language, &tokens, limit, offset) .search_in_language(&language, &tokens, limit, offset)
.await .await
{ {

View File

@ -1,100 +0,0 @@
pub mod auth;
use std::sync::Arc;
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
};
use juniper_axum::FromAuth;
use crate::{
api::Worker,
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
server::{
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
ServerContext,
},
};
pub struct Context {
claims: Option<auth::Claims>,
server: Arc<ServerContext>,
}
impl FromAuth<Arc<ServerContext>> for Context {
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Context {}
#[derive(Default)]
pub struct Query;
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.list_workers().await
}
async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.read_registration_token().await?;
Ok(token)
}
}
#[derive(Default)]
pub struct Mutation;
#[graphql_object(context = Context)]
impl Mutation {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.server.reset_registration_token().await?;
return Ok(reg_token);
}
}
Err(FieldError::new(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}
async fn register(
ctx: &Context,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
ctx.server.auth().register(input).await
}
async fn token_auth(
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.server.auth().token_auth(input).await
}
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
}

View File

@ -1,38 +1,35 @@
use std::fmt::Debug; use std::fmt::Debug;
use async_trait::async_trait;
use jsonwebtoken as jwt; use jsonwebtoken as jwt;
use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value}; use juniper::{FieldResult, GraphQLObject};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::ValidationError;
use crate::server::auth::JWT_DEFAULT_EXP; lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
#[derive(Debug)] jwt_token_secret().as_bytes()
pub struct ValidationErrors { );
pub errors: Vec<ValidationError>, static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
} }
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors { pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
fn into_field_error(self) -> FieldError<S> { let header = jwt::Header::default();
let errors = self let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
.errors Ok(token)
.into_iter() }
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));
FieldError::new("Invalid input parameters", ext.into()) pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
} let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}
fn jwt_token_secret() -> String {
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
} }
#[derive(Debug, GraphQLObject)] #[derive(Debug, GraphQLObject)]
@ -127,3 +124,39 @@ impl Claims {
&self.user &self.user
} }
} }
#[async_trait]
pub trait AuthenticationService: Send + Sync {
async fn register(
&self,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
assert!(!token.is_empty())
}
#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}

View File

@ -0,0 +1,133 @@
pub mod auth;
pub mod worker;
use std::sync::Arc;
use auth::AuthenticationService;
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError,
Object, RootNode, ScalarValue, Value,
};
use juniper_axum::FromAuth;
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use validator::ValidationError;
use self::{auth::validate_jwt, worker::WorkerService};
use crate::schema::{
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
worker::Worker,
};
pub trait ServiceLocator: Send + Sync {
fn auth(&self) -> &dyn AuthenticationService;
fn worker(&self) -> &dyn WorkerService;
fn code(&self) -> &dyn CodeSearch;
fn logger(&self) -> &dyn RawEventLogger;
}
pub struct Context {
claims: Option<auth::Claims>,
server: Arc<dyn ServiceLocator>,
}
impl FromAuth<Arc<dyn ServiceLocator>> for Context {
fn build(server: Arc<dyn ServiceLocator>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Context {}
#[derive(Default)]
pub struct Query;
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.worker().list_workers().await
}
async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.worker().read_registration_token().await?;
Ok(token)
}
}
#[derive(Default)]
pub struct Mutation;
#[graphql_object(context = Context)]
impl Mutation {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.server.worker().reset_registration_token().await?;
return Ok(reg_token);
}
}
Err(FieldError::new(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}
async fn register(
ctx: &Context,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
ctx.server
.auth()
.register(email, password1, password2)
.await
}
async fn token_auth(
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
ctx.server.auth().token_auth(email, password).await
}
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}
#[derive(Debug)]
pub struct ValidationErrors {
pub errors: Vec<ValidationError>,
}
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
fn into_field_error(self) -> FieldError<S> {
let errors = self
.errors
.into_iter()
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));
FieldError::new("Invalid input parameters", ext.into())
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
}

View File

@ -0,0 +1,51 @@
use anyhow::Result;
use async_trait::async_trait;
use axum::middleware::Next;
use hyper::{Body, Request};
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum RegisterWorkerError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
}
#[async_trait]
pub trait WorkerService: Send + Sync {
async fn read_registration_token(&self) -> Result<String>;
async fn reset_registration_token(&self) -> Result<String>;
async fn list_workers(&self) -> Vec<Worker>;
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError>;
async fn unregister_worker(&self, worker_addr: &str);
async fn dispatch_request(
&self,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response;
}

View File

@ -1,48 +1,35 @@
use std::env;
use argon2::{ use argon2::{
password_hash, password_hash,
password_hash::{rand_core::OsRng, SaltString}, password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher, PasswordVerifier, Argon2, PasswordHasher, PasswordVerifier,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldResult, IntoFieldError}; use juniper::{FieldResult, IntoFieldError};
use lazy_static::lazy_static;
use validator::Validate; use validator::Validate;
use crate::{ use super::db::DbConn;
db::DbConn, use crate::schema::{
schema::auth::{ auth::{
Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse,
ValidationErrors, VerifyTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
}, },
ValidationErrors,
}; };
lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
}
/// Input parameters for register mutation /// Input parameters for register mutation
/// `validate` attribute is used to validate the input parameters /// `validate` attribute is used to validate the input parameters
/// - `code` argument specifies which parameter causes the failure /// - `code` argument specifies which parameter causes the failure
/// - `message` argument provides client friendly error message /// - `message` argument provides client friendly error message
/// ///
#[derive(Validate)] #[derive(Validate)]
pub struct RegisterInput { struct RegisterInput {
#[validate(email(code = "email", message = "Email is invalid"))] #[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length( #[validate(length(
max = 128, max = 128,
code = "email", code = "email",
message = "Email must be at most 128 characters" message = "Email must be at most 128 characters"
))] ))]
pub email: String, email: String,
#[validate(length( #[validate(length(
min = 8, min = 8,
code = "password1", code = "password1",
@ -58,7 +45,7 @@ pub struct RegisterInput {
message = "Passwords do not match", message = "Passwords do not match",
other = "password2" other = "password2"
))] ))]
pub password1: String, password1: String,
#[validate(length( #[validate(length(
min = 8, min = 8,
code = "password2", code = "password2",
@ -69,7 +56,7 @@ pub struct RegisterInput {
code = "password2", code = "password2",
message = "Password must be at most 20 characters" message = "Password must be at most 20 characters"
))] ))]
pub password2: String, password2: String,
} }
impl std::fmt::Debug for RegisterInput { impl std::fmt::Debug for RegisterInput {
@ -85,14 +72,14 @@ impl std::fmt::Debug for RegisterInput {
/// Input parameters for token_auth mutation /// Input parameters for token_auth mutation
/// See `RegisterInput` for `validate` attribute usage /// See `RegisterInput` for `validate` attribute usage
#[derive(Validate)] #[derive(Validate)]
pub struct TokenAuthInput { struct TokenAuthInput {
#[validate(email(code = "email", message = "Email is invalid"))] #[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length( #[validate(length(
max = 128, max = 128,
code = "email", code = "email",
message = "Email must be at most 128 characters" message = "Email must be at most 128 characters"
))] ))]
pub email: String, email: String,
#[validate(length( #[validate(length(
min = 8, min = 8,
code = "password", code = "password",
@ -103,7 +90,7 @@ pub struct TokenAuthInput {
code = "password", code = "password",
message = "Password must be at most 20 characters" message = "Password must be at most 20 characters"
))] ))]
pub password: String, password: String,
} }
impl std::fmt::Debug for TokenAuthInput { impl std::fmt::Debug for TokenAuthInput {
@ -115,17 +102,19 @@ impl std::fmt::Debug for TokenAuthInput {
} }
} }
#[async_trait]
pub trait AuthenticationService {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}
#[async_trait] #[async_trait]
impl AuthenticationService for DbConn { impl AuthenticationService for DbConn {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse> { async fn register(
&self,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
input.validate().map_err(|err| { input.validate().map_err(|err| {
let errors = err let errors = err
.field_errors() .field_errors()
@ -138,7 +127,7 @@ impl AuthenticationService for DbConn {
})?; })?;
// check if email exists // check if email exists
if let Some(_) = self.get_user_by_email(&input.email).await? { if self.get_user_by_email(&input.email).await?.is_some() {
return Err("Email already exists".into()); return Err("Email already exists".into());
} }
@ -157,7 +146,8 @@ impl AuthenticationService for DbConn {
Ok(resp) Ok(resp)
} }
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse> { async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
input.validate().map_err(|err| { input.validate().map_err(|err| {
let errors = err let errors = err
.field_errors() .field_errors()
@ -217,22 +207,6 @@ fn password_verify(raw: &str, hash: &str) -> bool {
} }
} }
fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
let header = jwt::Header::default();
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
Ok(token)
}
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}
fn jwt_token_secret() -> String {
env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -254,23 +228,4 @@ mod tests {
assert!(password_verify(raw, &hash)); assert!(password_verify(raw, &hash));
assert!(!password_verify(raw, "invalid hash")); assert!(!password_verify(raw, "invalid hash"));
} }
#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
assert!(!token.is_empty())
}
#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
} }

View File

@ -1,63 +1,65 @@
pub mod auth; mod auth;
mod db;
mod proxy; mod proxy;
mod worker; mod worker;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait;
use axum::{http::Request, middleware::Next, response::IntoResponse}; use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode}; use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{ use self::db::DbConn;
api::{RegisterWorkerError, Worker, WorkerKind}, use crate::schema::{
db::DbConn, auth::AuthenticationService,
server::auth::AuthenticationService, worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
ServiceLocator,
}; };
pub struct ServerContext { struct ServerContext {
client: Client<HttpConnector>, client: Client<HttpConnector>,
completion: worker::WorkerGroup, completion: worker::WorkerGroup,
chat: worker::WorkerGroup, chat: worker::WorkerGroup,
db_conn: DbConn, db_conn: DbConn,
pub logger: Arc<dyn RawEventLogger>, logger: Arc<dyn RawEventLogger>,
pub code: Arc<dyn CodeSearch>, code: Arc<dyn CodeSearch>,
} }
impl ServerContext { impl ServerContext {
pub fn new( pub async fn new(logger: Arc<dyn RawEventLogger>, code: Arc<dyn CodeSearch>) -> Self {
db_conn: DbConn,
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> Self {
Self { Self {
client: Client::default(), client: Client::default(),
completion: worker::WorkerGroup::default(), completion: worker::WorkerGroup::default(),
chat: worker::WorkerGroup::default(), chat: worker::WorkerGroup::default(),
db_conn, db_conn: DbConn::new().await.unwrap(),
logger, logger,
code, code,
} }
} }
}
pub fn auth(&self) -> impl AuthenticationService { #[async_trait]
self.db_conn.clone() impl WorkerService for ServerContext {
}
/// Query current token from the database. /// Query current token from the database.
pub async fn read_registration_token(&self) -> Result<String> { async fn read_registration_token(&self) -> Result<String> {
self.db_conn.read_registration_token().await self.db_conn.read_registration_token().await
} }
/// Generate new token, and update it in the database. /// Generate new token, and update it in the database.
/// Return new token after update is done /// Return new token after update is done
pub async fn reset_registration_token(&self) -> Result<String> { async fn reset_registration_token(&self) -> Result<String> {
self.db_conn.reset_registration_token().await self.db_conn.reset_registration_token().await
} }
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> { async fn list_workers(&self) -> Vec<Worker> {
[self.completion.list().await, self.chat.list().await].concat()
}
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
let worker = match worker.kind { let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await, WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await, WorkerKind::Chat => self.chat.register(worker).await,
@ -74,7 +76,7 @@ impl ServerContext {
} }
} }
pub async fn unregister_worker(&self, worker_addr: &str) { async fn unregister_worker(&self, worker_addr: &str) {
let kind = if self.chat.unregister(worker_addr).await { let kind = if self.chat.unregister(worker_addr).await {
WorkerKind::Chat WorkerKind::Chat
} else if self.completion.unregister(worker_addr).await { } else if self.completion.unregister(worker_addr).await {
@ -87,11 +89,7 @@ impl ServerContext {
info!("unregistering <{:?}> worker at {}", kind, worker_addr); info!("unregistering <{:?}> worker at {}", kind, worker_addr);
} }
pub async fn list_workers(&self) -> Vec<Worker> { async fn dispatch_request(
[self.completion.list().await, self.chat.list().await].concat()
}
pub async fn dispatch_request(
&self, &self,
request: Request<Body>, request: Request<Body>,
next: Next<Body>, next: Next<Body>,
@ -129,3 +127,28 @@ impl ServerContext {
} }
} }
} }
impl ServiceLocator for ServerContext {
fn auth(&self) -> &dyn AuthenticationService {
&self.db_conn
}
fn worker(&self) -> &dyn WorkerService {
self
}
fn code(&self) -> &dyn CodeSearch {
&*self.code
}
fn logger(&self) -> &dyn RawEventLogger {
&*self.logger
}
}
pub async fn create_service_locator(
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> Arc<dyn ServiceLocator> {
Arc::new(ServerContext::new(logger, code).await)
}

View File

@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::error; use tracing::error;
use crate::api::Worker; use crate::schema::worker::Worker;
#[derive(Default)] #[derive(Default)]
pub struct WorkerGroup { pub struct WorkerGroup {
@ -61,7 +61,7 @@ fn random_index(size: usize) -> usize {
mod tests { mod tests {
use super::*; use super::*;
use crate::api::WorkerKind; use crate::schema::worker::WorkerKind;
#[tokio::test] #[tokio::test]
async fn test_worker_group() { async fn test_worker_group() {