refactor: extract ServiceLocator interface (#933)
* refactor: cleanup the dependency chain - ServerContext should be the only thing being public in server mod * refactor ServiceLocator * refactor: extract worker.rs * refactor: move db as private repo of server * rename server -> serviceadd-signin-page
parent
1a9cbdcc3c
commit
5c52a71f77
|
|
@ -50,7 +50,7 @@ where
|
||||||
let split = authorization.split_once(' ');
|
let split = authorization.split_once(' ');
|
||||||
match split {
|
match split {
|
||||||
// Found proper bearer
|
// Found proper bearer
|
||||||
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),
|
Some(("Bearer", contents)) => Ok(Self(Some(contents.to_owned()))),
|
||||||
_ => Ok(Self(None)),
|
_ => Ok(Self(None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
pub mod extract;
|
pub mod extract;
|
||||||
pub mod response;
|
pub mod response;
|
||||||
|
|
||||||
use std::{future};
|
use std::future;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, State},
|
extract::{Extension, State},
|
||||||
|
|
|
||||||
|
|
@ -1,45 +1,13 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use juniper::{GraphQLEnum, GraphQLObject};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tabby_common::api::{
|
use tabby_common::api::{
|
||||||
code::{CodeSearch, CodeSearchError, SearchResponse},
|
code::{CodeSearch, CodeSearchError, SearchResponse},
|
||||||
event::RawEventLogger,
|
event::RawEventLogger,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
|
||||||
use tokio_tungstenite::connect_async;
|
use tokio_tungstenite::connect_async;
|
||||||
|
|
||||||
|
pub use crate::schema::worker::{RegisterWorkerError, Worker, WorkerKind};
|
||||||
use crate::websocket::WebSocketTransport;
|
use crate::websocket::WebSocketTransport;
|
||||||
|
|
||||||
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
|
||||||
pub enum WorkerKind {
|
|
||||||
Completion,
|
|
||||||
Chat,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
|
||||||
pub struct Worker {
|
|
||||||
pub kind: WorkerKind,
|
|
||||||
pub name: String,
|
|
||||||
pub addr: String,
|
|
||||||
pub device: String,
|
|
||||||
pub arch: String,
|
|
||||||
pub cpu_info: String,
|
|
||||||
pub cpu_count: i32,
|
|
||||||
pub cuda_devices: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Error, Debug)]
|
|
||||||
pub enum RegisterWorkerError {
|
|
||||||
#[error("Invalid token")]
|
|
||||||
InvalidToken(String),
|
|
||||||
|
|
||||||
#[error("Feature requires enterprise license")]
|
|
||||||
RequiresEnterpriseLicense,
|
|
||||||
|
|
||||||
#[error("Each hub client should only calls register_worker once")]
|
|
||||||
RegisterWorkerOnce,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
pub trait Hub {
|
pub trait Hub {
|
||||||
async fn register_worker(
|
async fn register_worker(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
pub mod api;
|
pub mod api;
|
||||||
|
|
||||||
mod schema;
|
mod schema;
|
||||||
|
use api::Hub;
|
||||||
pub use schema::create_schema;
|
pub use schema::create_schema;
|
||||||
use tabby_common::api::{
|
use tabby_common::api::{
|
||||||
code::{CodeSearch, SearchResponse},
|
code::{CodeSearch, SearchResponse},
|
||||||
|
|
@ -10,15 +11,13 @@ use tokio::sync::Mutex;
|
||||||
use tracing::{error, warn};
|
use tracing::{error, warn};
|
||||||
use websocket::WebSocketTransport;
|
use websocket::WebSocketTransport;
|
||||||
|
|
||||||
mod db;
|
|
||||||
mod repositories;
|
mod repositories;
|
||||||
mod server;
|
mod service;
|
||||||
mod ui;
|
mod ui;
|
||||||
mod websocket;
|
mod websocket;
|
||||||
|
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
||||||
http::Request,
|
http::Request,
|
||||||
|
|
@ -28,8 +27,11 @@ use axum::{
|
||||||
};
|
};
|
||||||
use hyper::Body;
|
use hyper::Body;
|
||||||
use juniper_axum::{graphiql, graphql, playground};
|
use juniper_axum::{graphiql, graphql, playground};
|
||||||
use schema::Schema;
|
use schema::{
|
||||||
use server::ServerContext;
|
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
|
||||||
|
Schema, ServiceLocator,
|
||||||
|
};
|
||||||
|
use service::create_service_locator;
|
||||||
use tarpc::server::{BaseChannel, Channel};
|
use tarpc::server::{BaseChannel, Channel};
|
||||||
|
|
||||||
pub async fn attach_webserver(
|
pub async fn attach_webserver(
|
||||||
|
|
@ -38,15 +40,14 @@ pub async fn attach_webserver(
|
||||||
logger: Arc<dyn RawEventLogger>,
|
logger: Arc<dyn RawEventLogger>,
|
||||||
code: Arc<dyn CodeSearch>,
|
code: Arc<dyn CodeSearch>,
|
||||||
) -> (Router, Router) {
|
) -> (Router, Router) {
|
||||||
let conn = db::DbConn::new().await.unwrap();
|
let ctx = create_service_locator(logger, code).await;
|
||||||
let ctx = Arc::new(ServerContext::new(conn, logger, code));
|
|
||||||
let schema = Arc::new(create_schema());
|
let schema = Arc::new(create_schema());
|
||||||
|
|
||||||
let api = api
|
let api = api
|
||||||
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
|
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
|
||||||
.route(
|
.route(
|
||||||
"/graphql",
|
"/graphql",
|
||||||
routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()),
|
routing::post(graphql::<Arc<Schema>, Arc<dyn ServiceLocator>>).with_state(ctx.clone()),
|
||||||
)
|
)
|
||||||
.route("/graphql", routing::get(playground("/graphql", None)))
|
.route("/graphql", routing::get(playground("/graphql", None)))
|
||||||
.layer(Extension(schema))
|
.layer(Extension(schema))
|
||||||
|
|
@ -61,22 +62,22 @@ pub async fn attach_webserver(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn distributed_tabby_layer(
|
async fn distributed_tabby_layer(
|
||||||
State(ws): State<Arc<ServerContext>>,
|
State(ws): State<Arc<dyn ServiceLocator>>,
|
||||||
request: Request<Body>,
|
request: Request<Body>,
|
||||||
next: Next<Body>,
|
next: Next<Body>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
ws.dispatch_request(request, next).await
|
ws.worker().dispatch_request(request, next).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ws_handler(
|
async fn ws_handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
State(state): State<Arc<ServerContext>>,
|
State(state): State<Arc<dyn ServiceLocator>>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
|
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
|
async fn handle_socket(state: Arc<dyn ServiceLocator>, socket: WebSocket, addr: SocketAddr) {
|
||||||
let transport = WebSocketTransport::from(socket);
|
let transport = WebSocketTransport::from(socket);
|
||||||
let server = BaseChannel::with_defaults(transport);
|
let server = BaseChannel::with_defaults(transport);
|
||||||
let imp = Arc::new(HubImpl::new(state.clone(), addr));
|
let imp = Arc::new(HubImpl::new(state.clone(), addr));
|
||||||
|
|
@ -84,14 +85,14 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct HubImpl {
|
pub struct HubImpl {
|
||||||
ctx: Arc<ServerContext>,
|
ctx: Arc<dyn ServiceLocator>,
|
||||||
conn: SocketAddr,
|
conn: SocketAddr,
|
||||||
|
|
||||||
worker_addr: Arc<Mutex<String>>,
|
worker_addr: Arc<Mutex<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubImpl {
|
impl HubImpl {
|
||||||
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
pub fn new(ctx: Arc<dyn ServiceLocator>, conn: SocketAddr) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ctx,
|
ctx,
|
||||||
conn,
|
conn,
|
||||||
|
|
@ -108,7 +109,7 @@ impl Drop for HubImpl {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let worker_addr = worker_addr.lock().await;
|
let worker_addr = worker_addr.lock().await;
|
||||||
if !worker_addr.is_empty() {
|
if !worker_addr.is_empty() {
|
||||||
ctx.unregister_worker(worker_addr.as_str()).await;
|
ctx.worker().unregister_worker(worker_addr.as_str()).await;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -134,7 +135,7 @@ impl Hub for Arc<HubImpl> {
|
||||||
"Empty worker token".to_string(),
|
"Empty worker token".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let server_token = match self.ctx.read_registration_token().await {
|
let server_token = match self.ctx.worker().read_registration_token().await {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("fetch server token: {}", err.to_string());
|
error!("fetch server token: {}", err.to_string());
|
||||||
|
|
@ -167,11 +168,11 @@ impl Hub for Arc<HubImpl> {
|
||||||
cpu_count,
|
cpu_count,
|
||||||
cuda_devices,
|
cuda_devices,
|
||||||
};
|
};
|
||||||
self.ctx.register_worker(worker).await
|
self.ctx.worker().register_worker(worker).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn log_event(self, _context: tarpc::context::Context, content: String) {
|
async fn log_event(self, _context: tarpc::context::Context, content: String) {
|
||||||
self.ctx.logger.log(content)
|
self.ctx.logger().log(content)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn search(
|
async fn search(
|
||||||
|
|
@ -181,7 +182,7 @@ impl Hub for Arc<HubImpl> {
|
||||||
limit: usize,
|
limit: usize,
|
||||||
offset: usize,
|
offset: usize,
|
||||||
) -> SearchResponse {
|
) -> SearchResponse {
|
||||||
match self.ctx.code.search(&q, limit, offset).await {
|
match self.ctx.code().search(&q, limit, offset).await {
|
||||||
Ok(serp) => serp,
|
Ok(serp) => serp,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Failed to search: {}", err);
|
warn!("Failed to search: {}", err);
|
||||||
|
|
@ -200,7 +201,7 @@ impl Hub for Arc<HubImpl> {
|
||||||
) -> SearchResponse {
|
) -> SearchResponse {
|
||||||
match self
|
match self
|
||||||
.ctx
|
.ctx
|
||||||
.code
|
.code()
|
||||||
.search_in_language(&language, &tokens, limit, offset)
|
.search_in_language(&language, &tokens, limit, offset)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,100 +0,0 @@
|
||||||
pub mod auth;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
|
|
||||||
use juniper::{
|
|
||||||
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
|
|
||||||
};
|
|
||||||
use juniper_axum::FromAuth;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
api::Worker,
|
|
||||||
schema::auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
|
|
||||||
server::{
|
|
||||||
auth::{validate_jwt, AuthenticationService, RegisterInput, TokenAuthInput},
|
|
||||||
ServerContext,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub struct Context {
|
|
||||||
claims: Option<auth::Claims>,
|
|
||||||
server: Arc<ServerContext>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromAuth<Arc<ServerContext>> for Context {
|
|
||||||
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
|
|
||||||
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
|
|
||||||
Self { claims, server }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// To make our context usable by Juniper, we have to implement a marker trait.
|
|
||||||
impl juniper::Context for Context {}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct Query;
|
|
||||||
|
|
||||||
#[graphql_object(context = Context)]
|
|
||||||
impl Query {
|
|
||||||
async fn workers(ctx: &Context) -> Vec<Worker> {
|
|
||||||
ctx.server.list_workers().await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn registration_token(ctx: &Context) -> FieldResult<String> {
|
|
||||||
let token = ctx.server.read_registration_token().await?;
|
|
||||||
Ok(token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct Mutation;
|
|
||||||
|
|
||||||
#[graphql_object(context = Context)]
|
|
||||||
impl Mutation {
|
|
||||||
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
|
|
||||||
if let Some(claims) = &ctx.claims {
|
|
||||||
if claims.user_info().is_admin() {
|
|
||||||
let reg_token = ctx.server.reset_registration_token().await?;
|
|
||||||
return Ok(reg_token);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(FieldError::new(
|
|
||||||
"Only admin is able to reset registration token",
|
|
||||||
graphql_value!("Unauthorized"),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn register(
|
|
||||||
ctx: &Context,
|
|
||||||
email: String,
|
|
||||||
password1: String,
|
|
||||||
password2: String,
|
|
||||||
) -> FieldResult<RegisterResponse> {
|
|
||||||
let input = RegisterInput {
|
|
||||||
email,
|
|
||||||
password1,
|
|
||||||
password2,
|
|
||||||
};
|
|
||||||
ctx.server.auth().register(input).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn token_auth(
|
|
||||||
ctx: &Context,
|
|
||||||
email: String,
|
|
||||||
password: String,
|
|
||||||
) -> FieldResult<TokenAuthResponse> {
|
|
||||||
let input = TokenAuthInput { email, password };
|
|
||||||
ctx.server.auth().token_auth(input).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
|
|
||||||
ctx.server.auth().verify_token(token).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
|
|
||||||
|
|
||||||
pub fn create_schema() -> Schema {
|
|
||||||
Schema::new(Query, Mutation, EmptySubscription::new())
|
|
||||||
}
|
|
||||||
|
|
@ -1,38 +1,35 @@
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use jsonwebtoken as jwt;
|
use jsonwebtoken as jwt;
|
||||||
use juniper::{FieldError, GraphQLObject, IntoFieldError, Object, ScalarValue, Value};
|
use juniper::{FieldResult, GraphQLObject};
|
||||||
|
use lazy_static::lazy_static;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use validator::ValidationError;
|
|
||||||
|
|
||||||
use crate::server::auth::JWT_DEFAULT_EXP;
|
lazy_static! {
|
||||||
|
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
||||||
#[derive(Debug)]
|
jwt_token_secret().as_bytes()
|
||||||
pub struct ValidationErrors {
|
);
|
||||||
pub errors: Vec<ValidationError>,
|
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
|
||||||
|
jwt_token_secret().as_bytes()
|
||||||
|
);
|
||||||
|
static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
|
pub fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
|
||||||
fn into_field_error(self) -> FieldError<S> {
|
let header = jwt::Header::default();
|
||||||
let errors = self
|
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
|
||||||
.errors
|
Ok(token)
|
||||||
.into_iter()
|
}
|
||||||
.map(|err| {
|
|
||||||
let mut obj = Object::with_capacity(2);
|
|
||||||
obj.add_field("path", Value::scalar(err.code.to_string()));
|
|
||||||
obj.add_field(
|
|
||||||
"message",
|
|
||||||
Value::scalar(err.message.unwrap_or_default().to_string()),
|
|
||||||
);
|
|
||||||
obj.into()
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let mut ext = Object::with_capacity(2);
|
|
||||||
ext.add_field("code", Value::scalar("validation-error".to_string()));
|
|
||||||
ext.add_field("errors", Value::list(errors));
|
|
||||||
|
|
||||||
FieldError::new("Invalid input parameters", ext.into())
|
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
|
||||||
}
|
let validation = jwt::Validation::default();
|
||||||
|
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
|
||||||
|
Ok(data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn jwt_token_secret() -> String {
|
||||||
|
std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, GraphQLObject)]
|
#[derive(Debug, GraphQLObject)]
|
||||||
|
|
@ -127,3 +124,39 @@ impl Claims {
|
||||||
&self.user
|
&self.user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait AuthenticationService: Send + Sync {
|
||||||
|
async fn register(
|
||||||
|
&self,
|
||||||
|
email: String,
|
||||||
|
password1: String,
|
||||||
|
password2: String,
|
||||||
|
) -> FieldResult<RegisterResponse>;
|
||||||
|
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
|
||||||
|
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
|
||||||
|
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
#[test]
|
||||||
|
fn test_generate_jwt() {
|
||||||
|
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
||||||
|
let token = generate_jwt(claims).unwrap();
|
||||||
|
|
||||||
|
assert!(!token.is_empty())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_jwt() {
|
||||||
|
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
||||||
|
let token = generate_jwt(claims).unwrap();
|
||||||
|
let claims = validate_jwt(&token).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
claims.user_info(),
|
||||||
|
&UserInfo::new("test".to_string(), false)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
pub mod auth;
|
||||||
|
pub mod worker;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use auth::AuthenticationService;
|
||||||
|
use juniper::{
|
||||||
|
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, IntoFieldError,
|
||||||
|
Object, RootNode, ScalarValue, Value,
|
||||||
|
};
|
||||||
|
use juniper_axum::FromAuth;
|
||||||
|
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||||
|
use validator::ValidationError;
|
||||||
|
|
||||||
|
use self::{auth::validate_jwt, worker::WorkerService};
|
||||||
|
use crate::schema::{
|
||||||
|
auth::{RegisterResponse, TokenAuthResponse, VerifyTokenResponse},
|
||||||
|
worker::Worker,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub trait ServiceLocator: Send + Sync {
|
||||||
|
fn auth(&self) -> &dyn AuthenticationService;
|
||||||
|
fn worker(&self) -> &dyn WorkerService;
|
||||||
|
fn code(&self) -> &dyn CodeSearch;
|
||||||
|
fn logger(&self) -> &dyn RawEventLogger;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Context {
|
||||||
|
claims: Option<auth::Claims>,
|
||||||
|
server: Arc<dyn ServiceLocator>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromAuth<Arc<dyn ServiceLocator>> for Context {
|
||||||
|
fn build(server: Arc<dyn ServiceLocator>, bearer: Option<String>) -> Self {
|
||||||
|
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
|
||||||
|
Self { claims, server }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// To make our context usable by Juniper, we have to implement a marker trait.
|
||||||
|
impl juniper::Context for Context {}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct Query;
|
||||||
|
|
||||||
|
#[graphql_object(context = Context)]
|
||||||
|
impl Query {
|
||||||
|
async fn workers(ctx: &Context) -> Vec<Worker> {
|
||||||
|
ctx.server.worker().list_workers().await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn registration_token(ctx: &Context) -> FieldResult<String> {
|
||||||
|
let token = ctx.server.worker().read_registration_token().await?;
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct Mutation;
|
||||||
|
|
||||||
|
#[graphql_object(context = Context)]
|
||||||
|
impl Mutation {
|
||||||
|
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
|
||||||
|
if let Some(claims) = &ctx.claims {
|
||||||
|
if claims.user_info().is_admin() {
|
||||||
|
let reg_token = ctx.server.worker().reset_registration_token().await?;
|
||||||
|
return Ok(reg_token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(FieldError::new(
|
||||||
|
"Only admin is able to reset registration token",
|
||||||
|
graphql_value!("Unauthorized"),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register(
|
||||||
|
ctx: &Context,
|
||||||
|
email: String,
|
||||||
|
password1: String,
|
||||||
|
password2: String,
|
||||||
|
) -> FieldResult<RegisterResponse> {
|
||||||
|
ctx.server
|
||||||
|
.auth()
|
||||||
|
.register(email, password1, password2)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn token_auth(
|
||||||
|
ctx: &Context,
|
||||||
|
email: String,
|
||||||
|
password: String,
|
||||||
|
) -> FieldResult<TokenAuthResponse> {
|
||||||
|
ctx.server.auth().token_auth(email, password).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
|
||||||
|
ctx.server.auth().verify_token(token).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ValidationErrors {
|
||||||
|
pub errors: Vec<ValidationError>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: ScalarValue> IntoFieldError<S> for ValidationErrors {
|
||||||
|
fn into_field_error(self) -> FieldError<S> {
|
||||||
|
let errors = self
|
||||||
|
.errors
|
||||||
|
.into_iter()
|
||||||
|
.map(|err| {
|
||||||
|
let mut obj = Object::with_capacity(2);
|
||||||
|
obj.add_field("path", Value::scalar(err.code.to_string()));
|
||||||
|
obj.add_field(
|
||||||
|
"message",
|
||||||
|
Value::scalar(err.message.unwrap_or_default().to_string()),
|
||||||
|
);
|
||||||
|
obj.into()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let mut ext = Object::with_capacity(2);
|
||||||
|
ext.add_field("code", Value::scalar("validation-error".to_string()));
|
||||||
|
ext.add_field("errors", Value::list(errors));
|
||||||
|
|
||||||
|
FieldError::new("Invalid input parameters", ext.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
|
||||||
|
|
||||||
|
pub fn create_schema() -> Schema {
|
||||||
|
Schema::new(Query, Mutation, EmptySubscription::new())
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use hyper::{Body, Request};
|
||||||
|
use juniper::{GraphQLEnum, GraphQLObject};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
||||||
|
pub enum WorkerKind {
|
||||||
|
Completion,
|
||||||
|
Chat,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
||||||
|
pub struct Worker {
|
||||||
|
pub kind: WorkerKind,
|
||||||
|
pub name: String,
|
||||||
|
pub addr: String,
|
||||||
|
pub device: String,
|
||||||
|
pub arch: String,
|
||||||
|
pub cpu_info: String,
|
||||||
|
pub cpu_count: i32,
|
||||||
|
pub cuda_devices: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Error, Debug)]
|
||||||
|
pub enum RegisterWorkerError {
|
||||||
|
#[error("Invalid token")]
|
||||||
|
InvalidToken(String),
|
||||||
|
|
||||||
|
#[error("Feature requires enterprise license")]
|
||||||
|
RequiresEnterpriseLicense,
|
||||||
|
|
||||||
|
#[error("Each hub client should only calls register_worker once")]
|
||||||
|
RegisterWorkerOnce,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait WorkerService: Send + Sync {
|
||||||
|
async fn read_registration_token(&self) -> Result<String>;
|
||||||
|
async fn reset_registration_token(&self) -> Result<String>;
|
||||||
|
async fn list_workers(&self) -> Vec<Worker>;
|
||||||
|
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError>;
|
||||||
|
async fn unregister_worker(&self, worker_addr: &str);
|
||||||
|
async fn dispatch_request(
|
||||||
|
&self,
|
||||||
|
request: Request<Body>,
|
||||||
|
next: Next<Body>,
|
||||||
|
) -> axum::response::Response;
|
||||||
|
}
|
||||||
|
|
@ -1,48 +1,35 @@
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use argon2::{
|
use argon2::{
|
||||||
password_hash,
|
password_hash,
|
||||||
password_hash::{rand_core::OsRng, SaltString},
|
password_hash::{rand_core::OsRng, SaltString},
|
||||||
Argon2, PasswordHasher, PasswordVerifier,
|
Argon2, PasswordHasher, PasswordVerifier,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use jsonwebtoken as jwt;
|
|
||||||
use juniper::{FieldResult, IntoFieldError};
|
use juniper::{FieldResult, IntoFieldError};
|
||||||
use lazy_static::lazy_static;
|
|
||||||
use validator::Validate;
|
use validator::Validate;
|
||||||
|
|
||||||
use crate::{
|
use super::db::DbConn;
|
||||||
db::DbConn,
|
use crate::schema::{
|
||||||
schema::auth::{
|
auth::{
|
||||||
Claims, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UserInfo,
|
generate_jwt, validate_jwt, AuthenticationService, Claims, RefreshTokenResponse,
|
||||||
ValidationErrors, VerifyTokenResponse,
|
RegisterResponse, TokenAuthResponse, UserInfo, VerifyTokenResponse,
|
||||||
},
|
},
|
||||||
|
ValidationErrors,
|
||||||
};
|
};
|
||||||
|
|
||||||
lazy_static! {
|
|
||||||
static ref JWT_ENCODING_KEY: jwt::EncodingKey = jwt::EncodingKey::from_secret(
|
|
||||||
jwt_token_secret().as_bytes()
|
|
||||||
);
|
|
||||||
static ref JWT_DECODING_KEY: jwt::DecodingKey = jwt::DecodingKey::from_secret(
|
|
||||||
jwt_token_secret().as_bytes()
|
|
||||||
);
|
|
||||||
pub static ref JWT_DEFAULT_EXP: u64 = 30 * 60; // 30 minutes
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Input parameters for register mutation
|
/// Input parameters for register mutation
|
||||||
/// `validate` attribute is used to validate the input parameters
|
/// `validate` attribute is used to validate the input parameters
|
||||||
/// - `code` argument specifies which parameter causes the failure
|
/// - `code` argument specifies which parameter causes the failure
|
||||||
/// - `message` argument provides client friendly error message
|
/// - `message` argument provides client friendly error message
|
||||||
///
|
///
|
||||||
#[derive(Validate)]
|
#[derive(Validate)]
|
||||||
pub struct RegisterInput {
|
struct RegisterInput {
|
||||||
#[validate(email(code = "email", message = "Email is invalid"))]
|
#[validate(email(code = "email", message = "Email is invalid"))]
|
||||||
#[validate(length(
|
#[validate(length(
|
||||||
max = 128,
|
max = 128,
|
||||||
code = "email",
|
code = "email",
|
||||||
message = "Email must be at most 128 characters"
|
message = "Email must be at most 128 characters"
|
||||||
))]
|
))]
|
||||||
pub email: String,
|
email: String,
|
||||||
#[validate(length(
|
#[validate(length(
|
||||||
min = 8,
|
min = 8,
|
||||||
code = "password1",
|
code = "password1",
|
||||||
|
|
@ -58,7 +45,7 @@ pub struct RegisterInput {
|
||||||
message = "Passwords do not match",
|
message = "Passwords do not match",
|
||||||
other = "password2"
|
other = "password2"
|
||||||
))]
|
))]
|
||||||
pub password1: String,
|
password1: String,
|
||||||
#[validate(length(
|
#[validate(length(
|
||||||
min = 8,
|
min = 8,
|
||||||
code = "password2",
|
code = "password2",
|
||||||
|
|
@ -69,7 +56,7 @@ pub struct RegisterInput {
|
||||||
code = "password2",
|
code = "password2",
|
||||||
message = "Password must be at most 20 characters"
|
message = "Password must be at most 20 characters"
|
||||||
))]
|
))]
|
||||||
pub password2: String,
|
password2: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for RegisterInput {
|
impl std::fmt::Debug for RegisterInput {
|
||||||
|
|
@ -85,14 +72,14 @@ impl std::fmt::Debug for RegisterInput {
|
||||||
/// Input parameters for token_auth mutation
|
/// Input parameters for token_auth mutation
|
||||||
/// See `RegisterInput` for `validate` attribute usage
|
/// See `RegisterInput` for `validate` attribute usage
|
||||||
#[derive(Validate)]
|
#[derive(Validate)]
|
||||||
pub struct TokenAuthInput {
|
struct TokenAuthInput {
|
||||||
#[validate(email(code = "email", message = "Email is invalid"))]
|
#[validate(email(code = "email", message = "Email is invalid"))]
|
||||||
#[validate(length(
|
#[validate(length(
|
||||||
max = 128,
|
max = 128,
|
||||||
code = "email",
|
code = "email",
|
||||||
message = "Email must be at most 128 characters"
|
message = "Email must be at most 128 characters"
|
||||||
))]
|
))]
|
||||||
pub email: String,
|
email: String,
|
||||||
#[validate(length(
|
#[validate(length(
|
||||||
min = 8,
|
min = 8,
|
||||||
code = "password",
|
code = "password",
|
||||||
|
|
@ -103,7 +90,7 @@ pub struct TokenAuthInput {
|
||||||
code = "password",
|
code = "password",
|
||||||
message = "Password must be at most 20 characters"
|
message = "Password must be at most 20 characters"
|
||||||
))]
|
))]
|
||||||
pub password: String,
|
password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for TokenAuthInput {
|
impl std::fmt::Debug for TokenAuthInput {
|
||||||
|
|
@ -115,17 +102,19 @@ impl std::fmt::Debug for TokenAuthInput {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait AuthenticationService {
|
|
||||||
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse>;
|
|
||||||
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse>;
|
|
||||||
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
|
|
||||||
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl AuthenticationService for DbConn {
|
impl AuthenticationService for DbConn {
|
||||||
async fn register(&self, input: RegisterInput) -> FieldResult<RegisterResponse> {
|
async fn register(
|
||||||
|
&self,
|
||||||
|
email: String,
|
||||||
|
password1: String,
|
||||||
|
password2: String,
|
||||||
|
) -> FieldResult<RegisterResponse> {
|
||||||
|
let input = RegisterInput {
|
||||||
|
email,
|
||||||
|
password1,
|
||||||
|
password2,
|
||||||
|
};
|
||||||
input.validate().map_err(|err| {
|
input.validate().map_err(|err| {
|
||||||
let errors = err
|
let errors = err
|
||||||
.field_errors()
|
.field_errors()
|
||||||
|
|
@ -138,7 +127,7 @@ impl AuthenticationService for DbConn {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// check if email exists
|
// check if email exists
|
||||||
if let Some(_) = self.get_user_by_email(&input.email).await? {
|
if self.get_user_by_email(&input.email).await?.is_some() {
|
||||||
return Err("Email already exists".into());
|
return Err("Email already exists".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -157,7 +146,8 @@ impl AuthenticationService for DbConn {
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn token_auth(&self, input: TokenAuthInput) -> FieldResult<TokenAuthResponse> {
|
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse> {
|
||||||
|
let input = TokenAuthInput { email, password };
|
||||||
input.validate().map_err(|err| {
|
input.validate().map_err(|err| {
|
||||||
let errors = err
|
let errors = err
|
||||||
.field_errors()
|
.field_errors()
|
||||||
|
|
@ -217,22 +207,6 @@ fn password_verify(raw: &str, hash: &str) -> bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_jwt(claims: Claims) -> jwt::errors::Result<String> {
|
|
||||||
let header = jwt::Header::default();
|
|
||||||
let token = jwt::encode(&header, &claims, &JWT_ENCODING_KEY)?;
|
|
||||||
Ok(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate_jwt(token: &str) -> jwt::errors::Result<Claims> {
|
|
||||||
let validation = jwt::Validation::default();
|
|
||||||
let data = jwt::decode::<Claims>(token, &JWT_DECODING_KEY, &validation)?;
|
|
||||||
Ok(data.claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn jwt_token_secret() -> String {
|
|
||||||
env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET").unwrap_or("default_secret".to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -254,23 +228,4 @@ mod tests {
|
||||||
assert!(password_verify(raw, &hash));
|
assert!(password_verify(raw, &hash));
|
||||||
assert!(!password_verify(raw, "invalid hash"));
|
assert!(!password_verify(raw, "invalid hash"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_generate_jwt() {
|
|
||||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
|
||||||
let token = generate_jwt(claims).unwrap();
|
|
||||||
|
|
||||||
assert!(!token.is_empty())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_jwt() {
|
|
||||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
|
||||||
let token = generate_jwt(claims).unwrap();
|
|
||||||
let claims = validate_jwt(&token).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
claims.user_info(),
|
|
||||||
&UserInfo::new("test".to_string(), false)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
@ -1,63 +1,65 @@
|
||||||
pub mod auth;
|
mod auth;
|
||||||
|
mod db;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
mod worker;
|
mod worker;
|
||||||
|
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
||||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||||
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::{
|
use self::db::DbConn;
|
||||||
api::{RegisterWorkerError, Worker, WorkerKind},
|
use crate::schema::{
|
||||||
db::DbConn,
|
auth::AuthenticationService,
|
||||||
server::auth::AuthenticationService,
|
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
|
||||||
|
ServiceLocator,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct ServerContext {
|
struct ServerContext {
|
||||||
client: Client<HttpConnector>,
|
client: Client<HttpConnector>,
|
||||||
completion: worker::WorkerGroup,
|
completion: worker::WorkerGroup,
|
||||||
chat: worker::WorkerGroup,
|
chat: worker::WorkerGroup,
|
||||||
db_conn: DbConn,
|
db_conn: DbConn,
|
||||||
|
|
||||||
pub logger: Arc<dyn RawEventLogger>,
|
logger: Arc<dyn RawEventLogger>,
|
||||||
pub code: Arc<dyn CodeSearch>,
|
code: Arc<dyn CodeSearch>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerContext {
|
impl ServerContext {
|
||||||
pub fn new(
|
pub async fn new(logger: Arc<dyn RawEventLogger>, code: Arc<dyn CodeSearch>) -> Self {
|
||||||
db_conn: DbConn,
|
|
||||||
logger: Arc<dyn RawEventLogger>,
|
|
||||||
code: Arc<dyn CodeSearch>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
Self {
|
||||||
client: Client::default(),
|
client: Client::default(),
|
||||||
completion: worker::WorkerGroup::default(),
|
completion: worker::WorkerGroup::default(),
|
||||||
chat: worker::WorkerGroup::default(),
|
chat: worker::WorkerGroup::default(),
|
||||||
db_conn,
|
db_conn: DbConn::new().await.unwrap(),
|
||||||
logger,
|
logger,
|
||||||
code,
|
code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn auth(&self) -> impl AuthenticationService {
|
#[async_trait]
|
||||||
self.db_conn.clone()
|
impl WorkerService for ServerContext {
|
||||||
}
|
|
||||||
|
|
||||||
/// Query current token from the database.
|
/// Query current token from the database.
|
||||||
pub async fn read_registration_token(&self) -> Result<String> {
|
async fn read_registration_token(&self) -> Result<String> {
|
||||||
self.db_conn.read_registration_token().await
|
self.db_conn.read_registration_token().await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate new token, and update it in the database.
|
/// Generate new token, and update it in the database.
|
||||||
/// Return new token after update is done
|
/// Return new token after update is done
|
||||||
pub async fn reset_registration_token(&self) -> Result<String> {
|
async fn reset_registration_token(&self) -> Result<String> {
|
||||||
self.db_conn.reset_registration_token().await
|
self.db_conn.reset_registration_token().await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
|
async fn list_workers(&self) -> Vec<Worker> {
|
||||||
|
[self.completion.list().await, self.chat.list().await].concat()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
|
||||||
let worker = match worker.kind {
|
let worker = match worker.kind {
|
||||||
WorkerKind::Completion => self.completion.register(worker).await,
|
WorkerKind::Completion => self.completion.register(worker).await,
|
||||||
WorkerKind::Chat => self.chat.register(worker).await,
|
WorkerKind::Chat => self.chat.register(worker).await,
|
||||||
|
|
@ -74,7 +76,7 @@ impl ServerContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn unregister_worker(&self, worker_addr: &str) {
|
async fn unregister_worker(&self, worker_addr: &str) {
|
||||||
let kind = if self.chat.unregister(worker_addr).await {
|
let kind = if self.chat.unregister(worker_addr).await {
|
||||||
WorkerKind::Chat
|
WorkerKind::Chat
|
||||||
} else if self.completion.unregister(worker_addr).await {
|
} else if self.completion.unregister(worker_addr).await {
|
||||||
|
|
@ -87,11 +89,7 @@ impl ServerContext {
|
||||||
info!("unregistering <{:?}> worker at {}", kind, worker_addr);
|
info!("unregistering <{:?}> worker at {}", kind, worker_addr);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_workers(&self) -> Vec<Worker> {
|
async fn dispatch_request(
|
||||||
[self.completion.list().await, self.chat.list().await].concat()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn dispatch_request(
|
|
||||||
&self,
|
&self,
|
||||||
request: Request<Body>,
|
request: Request<Body>,
|
||||||
next: Next<Body>,
|
next: Next<Body>,
|
||||||
|
|
@ -129,3 +127,28 @@ impl ServerContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ServiceLocator for ServerContext {
|
||||||
|
fn auth(&self) -> &dyn AuthenticationService {
|
||||||
|
&self.db_conn
|
||||||
|
}
|
||||||
|
|
||||||
|
fn worker(&self) -> &dyn WorkerService {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn code(&self) -> &dyn CodeSearch {
|
||||||
|
&*self.code
|
||||||
|
}
|
||||||
|
|
||||||
|
fn logger(&self) -> &dyn RawEventLogger {
|
||||||
|
&*self.logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_service_locator(
|
||||||
|
logger: Arc<dyn RawEventLogger>,
|
||||||
|
code: Arc<dyn CodeSearch>,
|
||||||
|
) -> Arc<dyn ServiceLocator> {
|
||||||
|
Arc::new(ServerContext::new(logger, code).await)
|
||||||
|
}
|
||||||
|
|
@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::api::Worker;
|
use crate::schema::worker::Worker;
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct WorkerGroup {
|
pub struct WorkerGroup {
|
||||||
|
|
@ -61,7 +61,7 @@ fn random_index(size: usize) -> usize {
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::api::WorkerKind;
|
use crate::schema::worker::WorkerKind;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_worker_group() {
|
async fn test_worker_group() {
|
||||||
Loading…
Reference in New Issue