2023-11-14 20:48:20 +00:00
|
|
|
pub mod api;
|
|
|
|
|
|
|
|
|
|
mod schema;
|
|
|
|
|
pub use schema::create_schema;
|
|
|
|
|
use websocket::WebSocketTransport;
|
|
|
|
|
|
|
|
|
|
mod server;
|
2023-11-09 18:51:07 +00:00
|
|
|
mod ui;
|
2023-11-14 20:48:20 +00:00
|
|
|
mod websocket;
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-14 20:48:20 +00:00
|
|
|
use std::{net::SocketAddr, sync::Arc};
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-14 20:48:20 +00:00
|
|
|
use api::{Hub, HubError, Worker, WorkerKind};
|
2023-11-09 18:51:07 +00:00
|
|
|
use axum::{
|
2023-11-14 20:48:20 +00:00
|
|
|
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
2023-11-09 18:51:07 +00:00
|
|
|
http::Request,
|
|
|
|
|
middleware::{from_fn_with_state, Next},
|
2023-11-14 20:48:20 +00:00
|
|
|
response::IntoResponse,
|
2023-11-12 22:52:28 +00:00
|
|
|
routing, Extension, Router,
|
2023-11-09 18:51:07 +00:00
|
|
|
};
|
2023-11-12 22:52:28 +00:00
|
|
|
use hyper::Body;
|
|
|
|
|
use juniper_axum::{graphiql, graphql, playground};
|
2023-11-14 20:48:20 +00:00
|
|
|
use schema::Schema;
|
|
|
|
|
use server::ServerContext;
|
|
|
|
|
use tarpc::server::{BaseChannel, Channel};
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-14 20:48:20 +00:00
|
|
|
pub async fn attach_webserver(router: Router) -> Router {
|
|
|
|
|
let ctx = Arc::new(ServerContext::default());
|
|
|
|
|
let schema = Arc::new(create_schema());
|
2023-11-12 22:52:28 +00:00
|
|
|
|
|
|
|
|
let app = Router::new()
|
|
|
|
|
.route("/graphql", routing::get(playground("/graphql", None)))
|
|
|
|
|
.route("/graphiql", routing::get(graphiql("/graphql", None)))
|
|
|
|
|
.route(
|
|
|
|
|
"/graphql",
|
2023-11-14 20:48:20 +00:00
|
|
|
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
|
2023-11-12 22:52:28 +00:00
|
|
|
)
|
|
|
|
|
.layer(Extension(schema));
|
2023-11-09 18:51:07 +00:00
|
|
|
|
|
|
|
|
router
|
2023-11-12 22:52:28 +00:00
|
|
|
.merge(app)
|
2023-11-14 20:48:20 +00:00
|
|
|
.route("/hub", routing::get(ws_handler).with_state(ctx.clone()))
|
2023-11-09 18:51:07 +00:00
|
|
|
.fallback(ui::handler)
|
2023-11-14 20:48:20 +00:00
|
|
|
.layer(from_fn_with_state(ctx, distributed_tabby_layer))
|
2023-11-09 18:51:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn distributed_tabby_layer(
|
2023-11-14 20:48:20 +00:00
|
|
|
State(ws): State<Arc<ServerContext>>,
|
2023-11-09 18:51:07 +00:00
|
|
|
request: Request<Body>,
|
|
|
|
|
next: Next<Body>,
|
|
|
|
|
) -> axum::response::Response {
|
|
|
|
|
ws.dispatch_request(request, next).await
|
|
|
|
|
}
|
2023-11-14 20:48:20 +00:00
|
|
|
|
|
|
|
|
async fn ws_handler(
|
|
|
|
|
ws: WebSocketUpgrade,
|
|
|
|
|
State(state): State<Arc<ServerContext>>,
|
|
|
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
|
|
|
) -> impl IntoResponse {
|
|
|
|
|
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
|
|
|
|
|
let transport = WebSocketTransport::from(socket);
|
|
|
|
|
let server = BaseChannel::with_defaults(transport);
|
|
|
|
|
let imp = Arc::new(HubImpl::new(state.clone(), addr));
|
|
|
|
|
tokio::spawn(server.execute(imp.serve())).await.unwrap()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct HubImpl {
|
|
|
|
|
ctx: Arc<ServerContext>,
|
|
|
|
|
conn: SocketAddr,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl HubImpl {
|
|
|
|
|
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
|
|
|
|
Self { ctx, conn }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tarpc::server]
|
|
|
|
|
impl Hub for Arc<HubImpl> {
|
|
|
|
|
async fn register_worker(
|
|
|
|
|
self,
|
|
|
|
|
_context: tarpc::context::Context,
|
|
|
|
|
kind: WorkerKind,
|
|
|
|
|
port: i32,
|
|
|
|
|
name: String,
|
|
|
|
|
device: String,
|
|
|
|
|
arch: String,
|
|
|
|
|
cpu_info: String,
|
|
|
|
|
cpu_count: i32,
|
|
|
|
|
cuda_devices: Vec<String>,
|
|
|
|
|
) -> Result<Worker, HubError> {
|
|
|
|
|
let worker = Worker {
|
|
|
|
|
name,
|
|
|
|
|
kind,
|
|
|
|
|
addr: format!("http://{}:{}", self.conn.ip(), port),
|
|
|
|
|
device,
|
|
|
|
|
arch,
|
|
|
|
|
cpu_info,
|
|
|
|
|
cpu_count,
|
|
|
|
|
cuda_devices,
|
|
|
|
|
};
|
|
|
|
|
self.ctx.register_worker(worker).await
|
|
|
|
|
}
|
|
|
|
|
}
|