refactor: extract ServiceLocator interface (#933)
* refactor: cleanup the dependency chain - ServerContext should be the only thing being public in server mod * refactor ServiceLocator * refactor: extract worker.rs * refactor: move db as private repo of server * rename server -> serviceadd-signin-page
parent
1a9cbdcc3c
commit
5c52a71f77
|
|
@ -50,7 +50,7 @@ where
|
|||
let split = authorization.split_once(' ');
|
||||
match split {
|
||||
// Found proper bearer
|
||||
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),
|
||||
Some(("Bearer", contents)) => Ok(Self(Some(contents.to_owned()))),
|
||||
_ => Ok(Self(None)),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub mod extract;
|
||||
pub mod response;
|
||||
|
||||
use std::{future};
|
||||
use std::future;
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, State},
|
||||
|
|
|
|||
|
|
@ -1,45 +1,13 @@
|
|||
use async_trait::async_trait;
|
||||
use juniper::{GraphQLEnum, GraphQLObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::api::{
|
||||
code::{CodeSearch, CodeSearchError, SearchResponse},
|
||||
event::RawEventLogger,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind};
|
||||
use crate::websocket::WebSocketTransport;
|
||||
|
||||
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
||||
pub enum WorkerKind {
|
||||
Completion,
|
||||
Chat,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Worker {
|
||||
pub kind: WorkerKind,
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub device: String,
|
||||
pub arch: String,
|
||||
pub cpu_info: String,
|
||||
pub cpu_count: i32,
|
||||
pub cuda_devices: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Error, Debug)]
|
||||
pub enum RegisterWorkerError {
|
||||
#[error("Invalid token")]
|
||||
InvalidToken(String),
|
||||
|
||||
#[error("Feature requires enterprise license")]
|
||||
RequiresEnterpriseLicense,
|
||||
|
||||
#[error("Each hub client should only calls register_worker once")]
|
||||
RegisterWorkerOnce,
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait Hub {
|
||||
async fn register_worker(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod api;
|
||||
|
||||
mod schema;
|
||||
use api::Hub;
|
||||
pub use schema::create_schema;
|
||||
use tabby_common::api::{
|
||||
code::{CodeSearch, SearchResponse},
|
||||
|
|
@ -10,15 +11,13 @@ use tokio::sync::Mutex;
|
|||
use tracing::{error, warn};
|
||||
use websocket::WebSocketTransport;
|
||||
|
||||
mod db;
|
||||
mod repositories;
|
||||
mod server;
|
||||
mod service;
|
||||
mod ui;
|
||||
mod websocket;
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
|
||||
use axum::{
|
||||
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
||||
http::Request,
|
||||
|
|
@ -28,8 +27,11 @@ use axum::{
|
|||
};
|
||||
use hyper::Body;
|
||||
use juniper_axum::{graphiql, graphql, playground};
|
||||
use schema::Schema;
|
||||
use server::ServerContext;
|
||||
use schema::{
|
||||
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
|
||||
Schema, ServiceLocator,
|
||||
};
|
||||
use service::create_service_locator;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
|
||||
pub async fn attach_webserver(
|
||||
|
|
@ -38,15 +40,14 @@ pub async fn attach_webserver(
|
|||
logger: Arc<dyn RawEventLogger>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
) -> (Router, Router) {
|
||||
let conn = db::DbConn::new().await.unwrap();
|
||||
let ctx = Arc::new(ServerContext::new(conn, logger, code));
|
||||
let ctx = create_service_locator(logger, code).await;
|
||||
let schema = Arc::new(create_schema());
|
||||
|
||||
let api = api
|
||||
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
|
||||
.route(
|
||||
"/graphql",
|
||||
routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()),
|
||||
routing::post(graphql::<Arc<Schema>, Arc<dyn ServiceLocator>>).with_state(ctx.clone()),
|
||||
)
|
||||
.route("/graphql", routing::get(playground("/graphql", None)))
|
||||
.layer(Extension(schema))
|
||||
|
|
@ -61,22 +62,22 @@ pub async fn attach_webserver(
|
|||
}
|
||||
|
||||
async fn distributed_tabby_layer(
|
||||
State(ws): State<Arc<ServerContext>>,
|
||||
State(ws): State<Arc<dyn ServiceLocator>>,
|
||||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> axum::response::Response {
|
||||
ws.dispatch_request(request, next).await
|
||||
ws.worker().dispatch_request(request, next).await
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<ServerContext>>,
|
||||
State(state): State<Arc<dyn ServiceLocator>>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
|
||||
}
|
||||
|
||||
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
|
||||
async fn handle_socket(state: Arc<dyn ServiceLocator>, socket: WebSocket, addr: SocketAddr) {
|
||||
let transport = WebSocketTransport::from(socket);
|
||||
let server = BaseChannel::with_defaults(transport);
|
||||
let imp = Arc::new(HubImpl::new(state.clone(), addr));
|
||||
|
|
@ -84,14 +85,14 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
|
|||
}
|
||||
|
||||
pub struct HubImpl {
|
||||
ctx: Arc<ServerContext>,
|
||||
ctx: Arc<dyn ServiceLocator>,
|
||||
conn: SocketAddr,
|
||||
|
||||
worker_addr: Arc<Mutex<String>>,
|
||||
}
|
||||
|
||||
impl HubImpl {
|
||||
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
||||
pub fn new(ctx: Arc<dyn ServiceLocator>, conn: SocketAddr) -> Self {
|
||||
Self {
|
||||
ctx,
|
||||
conn,
|
||||
|
|
@ -108,7 +109,7 @@ impl Drop for HubImpl {
|
|||
tokio::spawn(async move {
|
||||
let worker_addr = worker_addr.lock().await;
|
||||
if !worker_addr.is_empty() {
|
||||
ctx.unregister_worker(worker_addr.as_str()).await;
|
||||
ctx.worker().unregister_worker(worker_addr.as_str()).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
@ -134,7 +135,7 @@ impl Hub for Arc<HubImpl> {
|
|||
"Empty worker token".to_string(),
|
||||
));
|
||||
}
|
||||
let server_token = match self.ctx.read_registration_token().await {
|
||||
let server_token = match self.ctx.worker().read_registration_token().await {
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
error!("fetch server token: {}", err.to_string());
|
||||
|
|
@ -167,11 +168,11 @@ impl Hub for Arc<HubImpl> {
|
|||
cpu_count,
|
||||
cuda_devices,
|
||||
};
|
||||
self.ctx.register_worker(worker).await
|
||||
self.ctx.worker().register_worker(worker).await
|
||||
}
|
||||
|
||||
async fn log_event(self, _context: tarpc::context::Context, content: String) {
|
||||
self.ctx.logger.log(content)
|
||||
self.ctx.logger().log(content)
|
||||
}
|
||||
|
||||
async fn search(
|
||||
|
|
@ -181,7 +182,7 @@ impl Hub for Arc<HubImpl> {
|
|||
limit: usize,
|
||||
offset: usize,
|
||||
) -> SearchResponse {
|
||||
match self.ctx.code.search(&q, limit, offset).await {
|
||||
match self.ctx.code().search(&q, limit, offset).await {
|
||||
Ok(serp) => serp,
|
||||
Err(err) => {
|
||||
warn!("Failed to search: {}", err);
|
||||
|
|
@ -200,7 +201,7 @@ impl Hub for Arc<HubImpl> {
|
|||
) -> SearchResponse {
|
||||
match self
|
||||
.ctx
|
||||
.code
|
||||
.code()
|
||||
.search_in_language(&language, &tokens, limit, offset)
|
||||
.await
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,100 +0,0 @@
|
|||
pub mod auth;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
|
||||
use juniper::{
|
||||
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
|
||||
};
|
||||
use juniper_axum::FromAuth;
|
||||
|
||||
use crate::{
|
||||
api::Worker,
|
||||
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
|
||||
server::{
|
||||
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
|
||||
ServerContext,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct Context {
|
||||
claims: Option<auth::Claims>,
|
||||
server: Arc<ServerContext>,
|
||||
}
|
||||
|
||||
impl FromAuth<Arc<ServerContext>> for Context {
|
||||
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
|
||||
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
|
||||
Self { claims, server }
|
||||
}
|
||||
}
|
||||
|
||||
// To make our context usable by Juniper, we have to implement a marker trait.
|
||||
impl juniper::Context for Context {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Query;
|
||||
|
||||
#[graphql_object(context = Context)]
|
||||
impl Query {
|
||||
async fn workers(ctx: &Context) -> Vec<Worker> {
|
||||
ctx.server.list_workers().await
|
||||
}
|
||||
|
||||
async fn registration_token(ctx: &Context) -> FieldResult<String> {
|
||||
let token = ctx.server.read_registration_token().await?;
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Mutation;
|
||||
|
||||
#[graphql_object(context = Context)]
|
||||
impl Mutation {
|
||||
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.user_info().is_admin() {
|
||||
let reg_token = ctx.server.reset_registration_token().await?;
|
||||
return Ok(reg_token);
|
||||
}
|
||||
}
|
||||
Err(FieldError::new(
|
||||
"Only admin is able to reset registration token",
|
||||
graphql_value!("Unauthorized"),
|
||||
))
|
||||
}
|
||||
|
||||
async fn register(
|
||||
ctx: &Context,
|
||||
email: String,
|
||||
password1: String,
|
||||
password2: String,
|
||||
) -> FieldResult<RegisterResponse> {
|
||||
let input = RegisterInput {
|
||||
email,
|
||||
password1,
|
||||
password2,
|
||||
};
|
||||
ctx.server.auth().register(input).await
|
||||
}
|
||||
|
||||
async fn token_auth(
|
||||
ctx: &Context,
|
||||
email: String,
|
||||
password: String,
|
||||
) -> FieldResult<TokenAuthResponse> {
|
||||
let input = TokenAuthInput { email, password };
|
||||
ctx.server.auth().token_auth(input).await
|
||||
}
|
||||
|
||||
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
|
||||
ctx.server.auth().verify_token(token).await
|
||||
}
|
||||
}
|
||||
|
||||
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
|
||||
|
||||
pub fn create_schema() -> Schema {
|
||||
Schema::new(Query, Mutation, EmptySubscription::new())
|
||||
}
|
||||
|
|
@ -1,38 +1,35 @@
|
|||
use std::fmt::Debug;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use jsonwebtoken as jwt;
|
||||
use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value};
|
||||
use juniper::{FieldResult, GraphQLObject};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validator::ValidationError;
|
||||
|
||||
use crate::server::auth::JWT_DEFAULT_EXP;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ValidationErrors {
|
||||
pub errors: Vec<ValidationError>,
|
||||
lazy_static! {
|
||||
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
||||
jwt_token_secret().as_bytes()
|
||||
);
|
||||
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
|
||||
jwt_token_secret().as_bytes()
|
||||
);
|
||||
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
|
||||
}
|
||||
|
||||
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
|
||||
fn into_field_error(self) -> FieldError<S> {
|
||||
let errors = self
|
||||
.errors
|
||||
.into_iter()
|
||||
.map(|err| {
|
||||
let mut obj = Object::with_capacity(2);
|
||||
obj.add_field("path", Value::scalar(err.code.to_string()));
|
||||
obj.add_field(
|
||||
"message",
|
||||
Value::scalar(err.message.unwrap_or_default().to_string()),
|
||||
);
|
||||
obj.into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut ext = Object::with_capacity(2);
|
||||
ext.add_field("code", Value::scalar("validation-error".to_string()));
|
||||
ext.add_field("errors", Value::list(errors));
|
||||
pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
|
||||
let header = jwt::Header::default();
|
||||
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
FieldError::new("Invalid input parameters", ext.into())
|
||||
}
|
||||
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
|
||||
let validation = jwt::Validation::default();
|
||||
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
|
||||
Ok(data.claims)
|
||||
}
|
||||
|
||||
fn jwt_token_secret() -> String {
|
||||
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
|
||||
}
|
||||
|
||||
#[derive(Debug, GraphQLObject)]
|
||||
|
|
@ -127,3 +124,39 @@ impl Claims {
|
|||
&self.user
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait AuthenticationService: Send + Sync {
|
||||
async fn register(
|
||||
&self,
|
||||
email: String,
|
||||
password1: String,
|
||||
password2: String,
|
||||
) -> FieldResult<RegisterResponse>;
|
||||
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
|
||||
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
|
||||
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_generate_jwt() {
|
||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
||||
let token = generate_jwt(claims).unwrap();
|
||||
|
||||
assert!(!token.is_empty())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_jwt() {
|
||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
||||
let token = generate_jwt(claims).unwrap();
|
||||
let claims = validate_jwt(&token).unwrap();
|
||||
assert_eq!(
|
||||
claims.user_info(),
|
||||
&UserInfo::new("test".to_string(), false)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
|||
pub mod auth;
|
||||
pub mod worker;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use auth::AuthenticationService;
|
||||
use juniper::{
|
||||
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError,
|
||||
Object, RootNode, ScalarValue, Value,
|
||||
};
|
||||
use juniper_axum::FromAuth;
|
||||
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||
use validator::ValidationError;
|
||||
|
||||
use self::{auth::validate_jwt, worker::WorkerService};
|
||||
use crate::schema::{
|
||||
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
|
||||
worker::Worker,
|
||||
};
|
||||
|
||||
pub trait ServiceLocator: Send + Sync {
|
||||
fn auth(&self) -> &dyn AuthenticationService;
|
||||
fn worker(&self) -> &dyn WorkerService;
|
||||
fn code(&self) -> &dyn CodeSearch;
|
||||
fn logger(&self) -> &dyn RawEventLogger;
|
||||
}
|
||||
|
||||
pub struct Context {
|
||||
claims: Option<auth::Claims>,
|
||||
server: Arc<dyn ServiceLocator>,
|
||||
}
|
||||
|
||||
impl FromAuth<Arc<dyn ServiceLocator>> for Context {
|
||||
fn build(server: Arc<dyn ServiceLocator>, bearer: Option<String>) -> Self {
|
||||
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
|
||||
Self { claims, server }
|
||||
}
|
||||
}
|
||||
|
||||
// To make our context usable by Juniper, we have to implement a marker trait.
|
||||
impl juniper::Context for Context {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Query;
|
||||
|
||||
#[graphql_object(context = Context)]
|
||||
impl Query {
|
||||
async fn workers(ctx: &Context) -> Vec<Worker> {
|
||||
ctx.server.worker().list_workers().await
|
||||
}
|
||||
|
||||
async fn registration_token(ctx: &Context) -> FieldResult<String> {
|
||||
let token = ctx.server.worker().read_registration_token().await?;
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Mutation;
|
||||
|
||||
#[graphql_object(context = Context)]
|
||||
impl Mutation {
|
||||
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.user_info().is_admin() {
|
||||
let reg_token = ctx.server.worker().reset_registration_token().await?;
|
||||
return Ok(reg_token);
|
||||
}
|
||||
}
|
||||
Err(FieldError::new(
|
||||
"Only admin is able to reset registration token",
|
||||
graphql_value!("Unauthorized"),
|
||||
))
|
||||
}
|
||||
|
||||
async fn register(
|
||||
ctx: &Context,
|
||||
email: String,
|
||||
password1: String,
|
||||
password2: String,
|
||||
) -> FieldResult<RegisterResponse> {
|
||||
ctx.server
|
||||
.auth()
|
||||
.register(email, password1, password2)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn token_auth(
|
||||
ctx: &Context,
|
||||
email: String,
|
||||
password: String,
|
||||
) -> FieldResult<TokenAuthResponse> {
|
||||
ctx.server.auth().token_auth(email, password).await
|
||||
}
|
||||
|
||||
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
|
||||
ctx.server.auth().verify_token(token).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ValidationErrors {
|
||||
pub errors: Vec<ValidationError>,
|
||||
}
|
||||
|
||||
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
|
||||
fn into_field_error(self) -> FieldError<S> {
|
||||
let errors = self
|
||||
.errors
|
||||
.into_iter()
|
||||
.map(|err| {
|
||||
let mut obj = Object::with_capacity(2);
|
||||
obj.add_field("path", Value::scalar(err.code.to_string()));
|
||||
obj.add_field(
|
||||
"message",
|
||||
Value::scalar(err.message.unwrap_or_default().to_string()),
|
||||
);
|
||||
obj.into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut ext = Object::with_capacity(2);
|
||||
ext.add_field("code", Value::scalar("validation-error".to_string()));
|
||||
ext.add_field("errors", Value::list(errors));
|
||||
|
||||
FieldError::new("Invalid input parameters", ext.into())
|
||||
}
|
||||
}
|
||||
|
||||
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
|
||||
|
||||
pub fn create_schema() -> Schema {
|
||||
Schema::new(Query, Mutation, EmptySubscription::new())
|
||||
}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use axum::middleware::Next;
|
||||
use hyper::{Body, Request};
|
||||
use juniper::{GraphQLEnum, GraphQLObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
||||
pub enum WorkerKind {
|
||||
Completion,
|
||||
Chat,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Worker {
|
||||
pub kind: WorkerKind,
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub device: String,
|
||||
pub arch: String,
|
||||
pub cpu_info: String,
|
||||
pub cpu_count: i32,
|
||||
pub cuda_devices: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Error, Debug)]
|
||||
pub enum RegisterWorkerError {
|
||||
#[error("Invalid token")]
|
||||
InvalidToken(String),
|
||||
|
||||
#[error("Feature requires enterprise license")]
|
||||
RequiresEnterpriseLicense,
|
||||
|
||||
#[error("Each hub client should only calls register_worker once")]
|
||||
RegisterWorkerOnce,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait WorkerService: Send + Sync {
|
||||
async fn read_registration_token(&self) -> Result<String>;
|
||||
async fn reset_registration_token(&self) -> Result<String>;
|
||||
async fn list_workers(&self) -> Vec<Worker>;
|
||||
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError>;
|
||||
async fn unregister_worker(&self, worker_addr: &str);
|
||||
async fn dispatch_request(
|
||||
&self,
|
||||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> axum::response::Response;
|
||||
}
|
||||
|
|
@ -1,48 +1,35 @@
|
|||
use std::env;
|
||||
|
||||
use argon2::{
|
||||
password_hash,
|
||||
password_hash::{rand_core::OsRng, SaltString},
|
||||
Argon2, PasswordHasher, PasswordVerifier,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use jsonwebtoken as jwt;
|
||||
use juniper::{FieldResult, IntoFieldError};
|
||||
use lazy_static::lazy_static;
|
||||
use validator::Validate;
|
||||
|
||||
use crate::{
|
||||
db::DbConn,
|
||||
schema::auth::{
|
||||
Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
|
||||
ValidationErrors, VerifyTokenResponse,
|
||||
use super::db::DbConn;
|
||||
use crate::schema::{
|
||||
auth::{
|
||||
generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse,
|
||||
RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
|
||||
},
|
||||
ValidationErrors,
|
||||
};
|
||||
|
||||
lazy_static! {
|
||||
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
||||
jwt_token_secret().as_bytes()
|
||||
);
|
||||
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
|
||||
jwt_token_secret().as_bytes()
|
||||
);
|
||||
pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
|
||||
}
|
||||
|
||||
/// Input parameters for register mutation
|
||||
/// `validate` attribute is used to validate the input parameters
|
||||
/// - `code` argument specifies which parameter causes the failure
|
||||
/// - `message` argument provides client friendly error message
|
||||
///
|
||||
#[derive(Validate)]
|
||||
pub struct RegisterInput {
|
||||
struct RegisterInput {
|
||||
#[validate(email(code = "email", message = "Email is invalid"))]
|
||||
#[validate(length(
|
||||
max = 128,
|
||||
code = "email",
|
||||
message = "Email must be at most 128 characters"
|
||||
))]
|
||||
pub email: String,
|
||||
email: String,
|
||||
#[validate(length(
|
||||
min = 8,
|
||||
code = "password1",
|
||||
|
|
@ -58,7 +45,7 @@ pub struct RegisterInput {
|
|||
message = "Passwords do not match",
|
||||
other = "password2"
|
||||
))]
|
||||
pub password1: String,
|
||||
password1: String,
|
||||
#[validate(length(
|
||||
min = 8,
|
||||
code = "password2",
|
||||
|
|
@ -69,7 +56,7 @@ pub struct RegisterInput {
|
|||
code = "password2",
|
||||
message = "Password must be at most 20 characters"
|
||||
))]
|
||||
pub password2: String,
|
||||
password2: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RegisterInput {
|
||||
|
|
@ -85,14 +72,14 @@ impl std::fmt::Debug for RegisterInput {
|
|||
/// Input parameters for token_auth mutation
|
||||
/// See `RegisterInput` for `validate` attribute usage
|
||||
#[derive(Validate)]
|
||||
pub struct TokenAuthInput {
|
||||
struct TokenAuthInput {
|
||||
#[validate(email(code = "email", message = "Email is invalid"))]
|
||||
#[validate(length(
|
||||
max = 128,
|
||||
code = "email",
|
||||
message = "Email must be at most 128 characters"
|
||||
))]
|
||||
pub email: String,
|
||||
email: String,
|
||||
#[validate(length(
|
||||
min = 8,
|
||||
code = "password",
|
||||
|
|
@ -103,7 +90,7 @@ pub struct TokenAuthInput {
|
|||
code = "password",
|
||||
message = "Password must be at most 20 characters"
|
||||
))]
|
||||
pub password: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TokenAuthInput {
|
||||
|
|
@ -115,17 +102,19 @@ impl std::fmt::Debug for TokenAuthInput {
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait AuthenticationService {
|
||||
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse>;
|
||||
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse>;
|
||||
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
|
||||
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthenticationService for DbConn {
|
||||
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse> {
|
||||
async fn register(
|
||||
&self,
|
||||
email: String,
|
||||
password1: String,
|
||||
password2: String,
|
||||
) -> FieldResult<RegisterResponse> {
|
||||
let input = RegisterInput {
|
||||
email,
|
||||
password1,
|
||||
password2,
|
||||
};
|
||||
input.validate().map_err(|err| {
|
||||
let errors = err
|
||||
.field_errors()
|
||||
|
|
@ -138,7 +127,7 @@ impl AuthenticationService for DbConn {
|
|||
})?;
|
||||
|
||||
// check if email exists
|
||||
if let Some(_) = self.get_user_by_email(&input.email).await? {
|
||||
if self.get_user_by_email(&input.email).await?.is_some() {
|
||||
return Err("Email already exists".into());
|
||||
}
|
||||
|
||||
|
|
@ -157,7 +146,8 @@ impl AuthenticationService for DbConn {
|
|||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse> {
|
||||
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse> {
|
||||
let input = TokenAuthInput { email, password };
|
||||
input.validate().map_err(|err| {
|
||||
let errors = err
|
||||
.field_errors()
|
||||
|
|
@ -217,22 +207,6 @@ fn password_verify(raw: &str, hash: &str) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
|
||||
let header = jwt::Header::default();
|
||||
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
|
||||
let validation = jwt::Validation::default();
|
||||
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
|
||||
Ok(data.claims)
|
||||
}
|
||||
|
||||
fn jwt_token_secret() -> String {
|
||||
env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -254,23 +228,4 @@ mod tests {
|
|||
assert!(password_verify(raw, &hash));
|
||||
assert!(!password_verify(raw, "invalid hash"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_jwt() {
|
||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
||||
let token = generate_jwt(claims).unwrap();
|
||||
|
||||
assert!(!token.is_empty())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_jwt() {
|
||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
||||
let token = generate_jwt(claims).unwrap();
|
||||
let claims = validate_jwt(&token).unwrap();
|
||||
assert_eq!(
|
||||
claims.user_info(),
|
||||
&UserInfo::new("test".to_string(), false)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,63 +1,65 @@
|
|||
pub mod auth;
|
||||
mod auth;
|
||||
mod db;
|
||||
mod proxy;
|
||||
mod worker;
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
api::{RegisterWorkerError, Worker, WorkerKind},
|
||||
db::DbConn,
|
||||
server::auth::AuthenticationService,
|
||||
use self::db::DbConn;
|
||||
use crate::schema::{
|
||||
auth::AuthenticationService,
|
||||
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
|
||||
ServiceLocator,
|
||||
};
|
||||
|
||||
pub struct ServerContext {
|
||||
struct ServerContext {
|
||||
client: Client<HttpConnector>,
|
||||
completion: worker::WorkerGroup,
|
||||
chat: worker::WorkerGroup,
|
||||
db_conn: DbConn,
|
||||
|
||||
pub logger: Arc<dyn RawEventLogger>,
|
||||
pub code: Arc<dyn CodeSearch>,
|
||||
logger: Arc<dyn RawEventLogger>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
}
|
||||
|
||||
impl ServerContext {
|
||||
pub fn new(
|
||||
db_conn: DbConn,
|
||||
logger: Arc<dyn RawEventLogger>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
) -> Self {
|
||||
pub async fn new(logger: Arc<dyn RawEventLogger>, code: Arc<dyn CodeSearch>) -> Self {
|
||||
Self {
|
||||
client: Client::default(),
|
||||
completion: worker::WorkerGroup::default(),
|
||||
chat: worker::WorkerGroup::default(),
|
||||
db_conn,
|
||||
db_conn: DbConn::new().await.unwrap(),
|
||||
logger,
|
||||
code,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth(&self) -> impl AuthenticationService {
|
||||
self.db_conn.clone()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerService for ServerContext {
|
||||
/// Query current token from the database.
|
||||
pub async fn read_registration_token(&self) -> Result<String> {
|
||||
async fn read_registration_token(&self) -> Result<String> {
|
||||
self.db_conn.read_registration_token().await
|
||||
}
|
||||
|
||||
/// Generate new token, and update it in the database.
|
||||
/// Return new token after update is done
|
||||
pub async fn reset_registration_token(&self) -> Result<String> {
|
||||
async fn reset_registration_token(&self) -> Result<String> {
|
||||
self.db_conn.reset_registration_token().await
|
||||
}
|
||||
|
||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
|
||||
async fn list_workers(&self) -> Vec<Worker> {
|
||||
[self.completion.list().await, self.chat.list().await].concat()
|
||||
}
|
||||
|
||||
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
|
||||
let worker = match worker.kind {
|
||||
WorkerKind::Completion => self.completion.register(worker).await,
|
||||
WorkerKind::Chat => self.chat.register(worker).await,
|
||||
|
|
@ -74,7 +76,7 @@ impl ServerContext {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn unregister_worker(&self, worker_addr: &str) {
|
||||
async fn unregister_worker(&self, worker_addr: &str) {
|
||||
let kind = if self.chat.unregister(worker_addr).await {
|
||||
WorkerKind::Chat
|
||||
} else if self.completion.unregister(worker_addr).await {
|
||||
|
|
@ -87,11 +89,7 @@ impl ServerContext {
|
|||
info!("unregistering <{:?}> worker at {}", kind, worker_addr);
|
||||
}
|
||||
|
||||
pub async fn list_workers(&self) -> Vec<Worker> {
|
||||
[self.completion.list().await, self.chat.list().await].concat()
|
||||
}
|
||||
|
||||
pub async fn dispatch_request(
|
||||
async fn dispatch_request(
|
||||
&self,
|
||||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
|
|
@ -129,3 +127,28 @@ impl ServerContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServiceLocator for ServerContext {
|
||||
fn auth(&self) -> &dyn AuthenticationService {
|
||||
&self.db_conn
|
||||
}
|
||||
|
||||
fn worker(&self) -> &dyn WorkerService {
|
||||
self
|
||||
}
|
||||
|
||||
fn code(&self) -> &dyn CodeSearch {
|
||||
&*self.code
|
||||
}
|
||||
|
||||
fn logger(&self) -> &dyn RawEventLogger {
|
||||
&*self.logger
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_service_locator(
|
||||
logger: Arc<dyn RawEventLogger>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
) -> Arc<dyn ServiceLocator> {
|
||||
Arc::new(ServerContext::new(logger, code).await)
|
||||
}
|
||||
|
|
@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::error;
|
||||
|
||||
use crate::api::Worker;
|
||||
use crate::schema::worker::Worker;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct WorkerGroup {
|
||||
|
|
@ -61,7 +61,7 @@ fn random_index(size: usize) -> usize {
|
|||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::api::WorkerKind;
|
||||
use crate::schema::worker::WorkerKind;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_worker_group() {
|
||||
Loading…
Reference in New Issue