refactor: extract ServiceLocator interface (#933)

* refactor: cleanup the dependency chain - ServerContext should be the only thing being public in server mod

* refactor ServiceLocator

* refactor: extract worker.rs

* refactor: move db as private repo of server

* rename server -> service
add-signin-page
Meng Zhang 2023-12-01 22:16:59 +08:00 committed by GitHub
parent 1a9cbdcc3c
commit 5c52a71f77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 347 additions and 283 deletions

View File

@ -50,7 +50,7 @@ where
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),
Some(("Bearer", contents)) => Ok(Self(Some(contents.to_owned()))),
_ => Ok(Self(None)),
}
}

View File

@ -1,7 +1,7 @@
pub mod extract;
pub mod response;
use std::{future};
use std::future;
use axum::{
extract::{Extension, State},

View File

@ -1,45 +1,13 @@
use async_trait::async_trait;
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use tabby_common::api::{
code::{CodeSearch, CodeSearchError, SearchResponse},
event::RawEventLogger,
};
use thiserror::Error;
use tokio_tungstenite::connect_async;
pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind};
use crate::websocket::WebSocketTransport;
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum RegisterWorkerError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
}
#[tarpc::service]
pub trait Hub {
async fn register_worker(

View File

@ -1,6 +1,7 @@
pub mod api;
mod schema;
use api::Hub;
pub use schema::create_schema;
use tabby_common::api::{
code::{CodeSearch, SearchResponse},
@ -10,15 +11,13 @@ use tokio::sync::Mutex;
use tracing::{error, warn};
use websocket::WebSocketTransport;
mod db;
mod repositories;
mod server;
mod service;
mod ui;
mod websocket;
use std::{net::SocketAddr, sync::Arc};
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
use axum::{
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
http::Request,
@ -28,8 +27,11 @@ use axum::{
};
use hyper::Body;
use juniper_axum::{graphiql, graphql, playground};
use schema::Schema;
use server::ServerContext;
use schema::{
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
Schema, ServiceLocator,
};
use service::create_service_locator;
use tarpc::server::{BaseChannel, Channel};
pub async fn attach_webserver(
@ -38,15 +40,14 @@ pub async fn attach_webserver(
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> (Router, Router) {
let conn = db::DbConn::new().await.unwrap();
let ctx = Arc::new(ServerContext::new(conn, logger, code));
let ctx = create_service_locator(logger, code).await;
let schema = Arc::new(create_schema());
let api = api
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
.route(
"/graphql",
routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()),
routing::post(graphql::<Arc<Schema>, Arc<dyn ServiceLocator>>).with_state(ctx.clone()),
)
.route("/graphql", routing::get(playground("/graphql", None)))
.layer(Extension(schema))
@ -61,22 +62,22 @@ pub async fn attach_webserver(
}
async fn distributed_tabby_layer(
State(ws): State<Arc<ServerContext>>,
State(ws): State<Arc<dyn ServiceLocator>>,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
ws.dispatch_request(request, next).await
ws.worker().dispatch_request(request, next).await
}
async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<ServerContext>>,
State(state): State<Arc<dyn ServiceLocator>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
}
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
async fn handle_socket(state: Arc<dyn ServiceLocator>, socket: WebSocket, addr: SocketAddr) {
let transport = WebSocketTransport::from(socket);
let server = BaseChannel::with_defaults(transport);
let imp = Arc::new(HubImpl::new(state.clone(), addr));
@ -84,14 +85,14 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
}
pub struct HubImpl {
ctx: Arc<ServerContext>,
ctx: Arc<dyn ServiceLocator>,
conn: SocketAddr,
worker_addr: Arc<Mutex<String>>,
}
impl HubImpl {
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
pub fn new(ctx: Arc<dyn ServiceLocator>, conn: SocketAddr) -> Self {
Self {
ctx,
conn,
@ -108,7 +109,7 @@ impl Drop for HubImpl {
tokio::spawn(async move {
let worker_addr = worker_addr.lock().await;
if !worker_addr.is_empty() {
ctx.unregister_worker(worker_addr.as_str()).await;
ctx.worker().unregister_worker(worker_addr.as_str()).await;
}
});
}
@ -134,7 +135,7 @@ impl Hub for Arc<HubImpl> {
"Empty worker token".to_string(),
));
}
let server_token = match self.ctx.read_registration_token().await {
let server_token = match self.ctx.worker().read_registration_token().await {
Ok(t) => t,
Err(err) => {
error!("fetch server token: {}", err.to_string());
@ -167,11 +168,11 @@ impl Hub for Arc<HubImpl> {
cpu_count,
cuda_devices,
};
self.ctx.register_worker(worker).await
self.ctx.worker().register_worker(worker).await
}
async fn log_event(self, _context: tarpc::context::Context, content: String) {
self.ctx.logger.log(content)
self.ctx.logger().log(content)
}
async fn search(
@ -181,7 +182,7 @@ impl Hub for Arc<HubImpl> {
limit: usize,
offset: usize,
) -> SearchResponse {
match self.ctx.code.search(&q, limit, offset).await {
match self.ctx.code().search(&q, limit, offset).await {
Ok(serp) => serp,
Err(err) => {
warn!("Failed to search: {}", err);
@ -200,7 +201,7 @@ impl Hub for Arc<HubImpl> {
) -> SearchResponse {
match self
.ctx
.code
.code()
.search_in_language(&language, &tokens, limit, offset)
.await
{

View File

@ -1,100 +0,0 @@
pub mod auth;
use std::sync::Arc;
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
};
use juniper_axum::FromAuth;
use crate::{
api::Worker,
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
server::{
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
ServerContext,
},
};
pub struct Context {
claims: Option<auth::Claims>,
server: Arc<ServerContext>,
}
impl FromAuth<Arc<ServerContext>> for Context {
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Context {}
#[derive(Default)]
pub struct Query;
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.list_workers().await
}
async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.read_registration_token().await?;
Ok(token)
}
}
#[derive(Default)]
pub struct Mutation;
#[graphql_object(context = Context)]
impl Mutation {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.server.reset_registration_token().await?;
return Ok(reg_token);
}
}
Err(FieldError::new(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}
async fn register(
ctx: &Context,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
ctx.server.auth().register(input).await
}
async fn token_auth(
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
ctx.server.auth().token_auth(input).await
}
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
}

View File

@ -1,38 +1,35 @@
use std::fmt::Debug;
use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value};
use juniper::{FieldResult, GraphQLObject};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use validator::ValidationError;
use crate::server::auth::JWT_DEFAULT_EXP;
#[derive(Debug)]
pub struct ValidationErrors {
pub errors: Vec<ValidationError>,
lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
}
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
fn into_field_error(self) -> FieldError<S> {
let errors = self
.errors
.into_iter()
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));
pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
let header = jwt::Header::default();
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
Ok(token)
}
FieldError::new("Invalid input parameters", ext.into())
}
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}
fn jwt_token_secret() -> String {
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}
#[derive(Debug, GraphQLObject)]
@ -127,3 +124,39 @@ impl Claims {
&self.user
}
}
#[async_trait]
pub trait AuthenticationService: Send + Sync {
async fn register(
&self,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
assert!(!token.is_empty())
}
#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}

View File

@ -0,0 +1,133 @@
pub mod auth;
pub mod worker;
use std::sync::Arc;
use auth::AuthenticationService;
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError,
Object, RootNode, ScalarValue, Value,
};
use juniper_axum::FromAuth;
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use validator::ValidationError;
use self::{auth::validate_jwt, worker::WorkerService};
use crate::schema::{
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
worker::Worker,
};
pub trait ServiceLocator: Send + Sync {
fn auth(&self) -> &dyn AuthenticationService;
fn worker(&self) -> &dyn WorkerService;
fn code(&self) -> &dyn CodeSearch;
fn logger(&self) -> &dyn RawEventLogger;
}
pub struct Context {
claims: Option<auth::Claims>,
server: Arc<dyn ServiceLocator>,
}
impl FromAuth<Arc<dyn ServiceLocator>> for Context {
fn build(server: Arc<dyn ServiceLocator>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Context {}
#[derive(Default)]
pub struct Query;
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.worker().list_workers().await
}
async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.worker().read_registration_token().await?;
Ok(token)
}
}
#[derive(Default)]
pub struct Mutation;
#[graphql_object(context = Context)]
impl Mutation {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.server.worker().reset_registration_token().await?;
return Ok(reg_token);
}
}
Err(FieldError::new(
"Only admin is able to reset registration token",
graphql_value!("Unauthorized"),
))
}
async fn register(
ctx: &Context,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
ctx.server
.auth()
.register(email, password1, password2)
.await
}
async fn token_auth(
ctx: &Context,
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
ctx.server.auth().token_auth(email, password).await
}
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
}
}
#[derive(Debug)]
pub struct ValidationErrors {
pub errors: Vec<ValidationError>,
}
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
fn into_field_error(self) -> FieldError<S> {
let errors = self
.errors
.into_iter()
.map(|err| {
let mut obj = Object::with_capacity(2);
obj.add_field("path", Value::scalar(err.code.to_string()));
obj.add_field(
"message",
Value::scalar(err.message.unwrap_or_default().to_string()),
);
obj.into()
})
.collect::<Vec<_>>();
let mut ext = Object::with_capacity(2);
ext.add_field("code", Value::scalar("validation-error".to_string()));
ext.add_field("errors", Value::list(errors));
FieldError::new("Invalid input parameters", ext.into())
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new())
}

View File

@ -0,0 +1,51 @@
use anyhow::Result;
use async_trait::async_trait;
use axum::middleware::Next;
use hyper::{Body, Request};
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum RegisterWorkerError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
}
#[async_trait]
pub trait WorkerService: Send + Sync {
async fn read_registration_token(&self) -> Result<String>;
async fn reset_registration_token(&self) -> Result<String>;
async fn list_workers(&self) -> Vec<Worker>;
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError>;
async fn unregister_worker(&self, worker_addr: &str);
async fn dispatch_request(
&self,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response;
}

View File

@ -1,48 +1,35 @@
use std::env;
use argon2::{
password_hash,
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher, PasswordVerifier,
};
use async_trait::async_trait;
use jsonwebtoken as jwt;
use juniper::{FieldResult, IntoFieldError};
use lazy_static::lazy_static;
use validator::Validate;
use crate::{
db::DbConn,
schema::auth::{
Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
ValidationErrors, VerifyTokenResponse,
use super::db::DbConn;
use crate::schema::{
auth::{
generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse,
RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
},
ValidationErrors,
};
lazy_static! {
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
jwt_token_secret().as_bytes()
);
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
jwt_token_secret().as_bytes()
);
pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
}
/// Input parameters for register mutation
/// `validate` attribute is used to validate the input parameters
/// - `code` argument specifies which parameter causes the failure
/// - `message` argument provides client friendly error message
///
#[derive(Validate)]
pub struct RegisterInput {
struct RegisterInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
email: String,
#[validate(length(
min = 8,
code = "password1",
@ -58,7 +45,7 @@ pub struct RegisterInput {
message = "Passwords do not match",
other = "password2"
))]
pub password1: String,
password1: String,
#[validate(length(
min = 8,
code = "password2",
@ -69,7 +56,7 @@ pub struct RegisterInput {
code = "password2",
message = "Password must be at most 20 characters"
))]
pub password2: String,
password2: String,
}
impl std::fmt::Debug for RegisterInput {
@ -85,14 +72,14 @@ impl std::fmt::Debug for RegisterInput {
/// Input parameters for token_auth mutation
/// See `RegisterInput` for `validate` attribute usage
#[derive(Validate)]
pub struct TokenAuthInput {
struct TokenAuthInput {
#[validate(email(code = "email", message = "Email is invalid"))]
#[validate(length(
max = 128,
code = "email",
message = "Email must be at most 128 characters"
))]
pub email: String,
email: String,
#[validate(length(
min = 8,
code = "password",
@ -103,7 +90,7 @@ pub struct TokenAuthInput {
code = "password",
message = "Password must be at most 20 characters"
))]
pub password: String,
password: String,
}
impl std::fmt::Debug for TokenAuthInput {
@ -115,17 +102,19 @@ impl std::fmt::Debug for TokenAuthInput {
}
}
#[async_trait]
pub trait AuthenticationService {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse>;
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
}
#[async_trait]
impl AuthenticationService for DbConn {
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse> {
async fn register(
&self,
email: String,
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
let input = RegisterInput {
email,
password1,
password2,
};
input.validate().map_err(|err| {
let errors = err
.field_errors()
@ -138,7 +127,7 @@ impl AuthenticationService for DbConn {
})?;
// check if email exists
if let Some(_) = self.get_user_by_email(&input.email).await? {
if self.get_user_by_email(&input.email).await?.is_some() {
return Err("Email already exists".into());
}
@ -157,7 +146,8 @@ impl AuthenticationService for DbConn {
Ok(resp)
}
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse> {
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password };
input.validate().map_err(|err| {
let errors = err
.field_errors()
@ -217,22 +207,6 @@ fn password_verify(raw: &str, hash: &str) -> bool {
}
}
fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
let header = jwt::Header::default();
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
Ok(token)
}
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
let validation = jwt::Validation::default();
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
Ok(data.claims)
}
fn jwt_token_secret() -> String {
env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
@ -254,23 +228,4 @@ mod tests {
assert!(password_verify(raw, &hash));
assert!(!password_verify(raw, "invalid hash"));
}
#[test]
fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
assert!(!token.is_empty())
}
#[test]
fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap();
assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
}
}

View File

@ -1,63 +1,65 @@
pub mod auth;
mod auth;
mod db;
mod proxy;
mod worker;
use std::{net::SocketAddr, sync::Arc};
use anyhow::Result;
use async_trait::async_trait;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn};
use crate::{
api::{RegisterWorkerError, Worker, WorkerKind},
db::DbConn,
server::auth::AuthenticationService,
use self::db::DbConn;
use crate::schema::{
auth::AuthenticationService,
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
ServiceLocator,
};
pub struct ServerContext {
struct ServerContext {
client: Client<HttpConnector>,
completion: worker::WorkerGroup,
chat: worker::WorkerGroup,
db_conn: DbConn,
pub logger: Arc<dyn RawEventLogger>,
pub code: Arc<dyn CodeSearch>,
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
}
impl ServerContext {
pub fn new(
db_conn: DbConn,
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> Self {
pub async fn new(logger: Arc<dyn RawEventLogger>, code: Arc<dyn CodeSearch>) -> Self {
Self {
client: Client::default(),
completion: worker::WorkerGroup::default(),
chat: worker::WorkerGroup::default(),
db_conn,
db_conn: DbConn::new().await.unwrap(),
logger,
code,
}
}
}
pub fn auth(&self) -> impl AuthenticationService {
self.db_conn.clone()
}
#[async_trait]
impl WorkerService for ServerContext {
/// Query current token from the database.
pub async fn read_registration_token(&self) -> Result<String> {
async fn read_registration_token(&self) -> Result<String> {
self.db_conn.read_registration_token().await
}
/// Generate new token, and update it in the database.
/// Return new token after update is done
pub async fn reset_registration_token(&self) -> Result<String> {
async fn reset_registration_token(&self) -> Result<String> {
self.db_conn.reset_registration_token().await
}
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
async fn list_workers(&self) -> Vec<Worker> {
[self.completion.list().await, self.chat.list().await].concat()
}
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await,
@ -74,7 +76,7 @@ impl ServerContext {
}
}
pub async fn unregister_worker(&self, worker_addr: &str) {
async fn unregister_worker(&self, worker_addr: &str) {
let kind = if self.chat.unregister(worker_addr).await {
WorkerKind::Chat
} else if self.completion.unregister(worker_addr).await {
@ -87,11 +89,7 @@ impl ServerContext {
info!("unregistering <{:?}> worker at {}", kind, worker_addr);
}
pub async fn list_workers(&self) -> Vec<Worker> {
[self.completion.list().await, self.chat.list().await].concat()
}
pub async fn dispatch_request(
async fn dispatch_request(
&self,
request: Request<Body>,
next: Next<Body>,
@ -129,3 +127,28 @@ impl ServerContext {
}
}
}
impl ServiceLocator for ServerContext {
fn auth(&self) -> &dyn AuthenticationService {
&self.db_conn
}
fn worker(&self) -> &dyn WorkerService {
self
}
fn code(&self) -> &dyn CodeSearch {
&*self.code
}
fn logger(&self) -> &dyn RawEventLogger {
&*self.logger
}
}
pub async fn create_service_locator(
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> Arc<dyn ServiceLocator> {
Arc::new(ServerContext::new(logger, code).await)
}

View File

@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tracing::error;
use crate::api::Worker;
use crate::schema::worker::Worker;
#[derive(Default)]
pub struct WorkerGroup {
@ -61,7 +61,7 @@ fn random_index(size: usize) -> usize {
mod tests {
use super::*;
use crate::api::WorkerKind;
use crate::schema::worker::WorkerKind;
#[tokio::test]
async fn test_worker_group() {