2023-11-14 20:48:20 +00:00
|
|
|
pub mod api;
|
|
|
|
|
|
|
|
|
|
mod schema;
|
|
|
|
|
pub use schema::create_schema;
|
2023-11-18 23:45:00 +00:00
|
|
|
use tabby_common::api::{
|
|
|
|
|
code::{CodeSearch, SearchResponse},
|
|
|
|
|
event::RawEventLogger,
|
|
|
|
|
};
|
|
|
|
|
use tracing::{error, warn};
|
2023-11-14 20:48:20 +00:00
|
|
|
use websocket::WebSocketTransport;
|
|
|
|
|
|
2023-11-17 07:05:39 +00:00
|
|
|
mod db;
|
2023-11-14 20:48:20 +00:00
|
|
|
mod server;
|
2023-11-09 18:51:07 +00:00
|
|
|
mod ui;
|
2023-11-14 20:48:20 +00:00
|
|
|
mod websocket;
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-14 20:48:20 +00:00
|
|
|
use std::{net::SocketAddr, sync::Arc};
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-14 20:48:20 +00:00
|
|
|
use api::{Hub, HubError, Worker, WorkerKind};
|
2023-11-09 18:51:07 +00:00
|
|
|
use axum::{
|
2023-11-14 20:48:20 +00:00
|
|
|
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
2023-11-09 18:51:07 +00:00
|
|
|
http::Request,
|
|
|
|
|
middleware::{from_fn_with_state, Next},
|
2023-11-14 20:48:20 +00:00
|
|
|
response::IntoResponse,
|
2023-11-12 22:52:28 +00:00
|
|
|
routing, Extension, Router,
|
2023-11-09 18:51:07 +00:00
|
|
|
};
|
2023-11-12 22:52:28 +00:00
|
|
|
use hyper::Body;
|
|
|
|
|
use juniper_axum::{graphiql, graphql, playground};
|
2023-11-14 20:48:20 +00:00
|
|
|
use schema::Schema;
|
|
|
|
|
use server::ServerContext;
|
|
|
|
|
use tarpc::server::{BaseChannel, Channel};
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-18 23:45:00 +00:00
|
|
|
pub async fn attach_webserver(
|
2023-11-20 01:00:35 +00:00
|
|
|
api: Router,
|
|
|
|
|
ui: Router,
|
2023-11-18 23:45:00 +00:00
|
|
|
logger: Arc<dyn RawEventLogger>,
|
|
|
|
|
code: Arc<dyn CodeSearch>,
|
2023-11-20 01:00:35 +00:00
|
|
|
) -> (Router, Router) {
|
2023-11-17 07:05:39 +00:00
|
|
|
let conn = db::DbConn::new().await.unwrap();
|
2023-11-18 23:45:00 +00:00
|
|
|
let ctx = Arc::new(ServerContext::new(conn, logger, code));
|
2023-11-14 20:48:20 +00:00
|
|
|
let schema = Arc::new(create_schema());
|
2023-11-12 22:52:28 +00:00
|
|
|
|
2023-11-20 01:00:35 +00:00
|
|
|
let api = api
|
|
|
|
|
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
|
2023-11-12 22:52:28 +00:00
|
|
|
.route(
|
|
|
|
|
"/graphql",
|
2023-11-14 20:48:20 +00:00
|
|
|
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
|
2023-11-12 22:52:28 +00:00
|
|
|
)
|
2023-11-21 23:08:27 +00:00
|
|
|
.route("/graphql", routing::get(playground("/graphql", None)))
|
2023-11-20 01:00:35 +00:00
|
|
|
.layer(Extension(schema))
|
2023-11-20 21:26:57 +00:00
|
|
|
.route("/hub", routing::get(ws_handler).with_state(ctx));
|
2023-11-20 01:00:35 +00:00
|
|
|
|
|
|
|
|
let ui = ui
|
|
|
|
|
.route("/graphiql", routing::get(graphiql("/graphql", None)))
|
|
|
|
|
.fallback(ui::handler);
|
2023-11-09 18:51:07 +00:00
|
|
|
|
2023-11-20 01:00:35 +00:00
|
|
|
(api, ui)
|
2023-11-09 18:51:07 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn distributed_tabby_layer(
|
2023-11-14 20:48:20 +00:00
|
|
|
State(ws): State<Arc<ServerContext>>,
|
2023-11-09 18:51:07 +00:00
|
|
|
request: Request<Body>,
|
|
|
|
|
next: Next<Body>,
|
|
|
|
|
) -> axum::response::Response {
|
|
|
|
|
ws.dispatch_request(request, next).await
|
|
|
|
|
}
|
2023-11-14 20:48:20 +00:00
|
|
|
|
|
|
|
|
async fn ws_handler(
|
|
|
|
|
ws: WebSocketUpgrade,
|
|
|
|
|
State(state): State<Arc<ServerContext>>,
|
|
|
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
|
|
|
) -> impl IntoResponse {
|
|
|
|
|
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
|
|
|
|
|
let transport = WebSocketTransport::from(socket);
|
|
|
|
|
let server = BaseChannel::with_defaults(transport);
|
|
|
|
|
let imp = Arc::new(HubImpl::new(state.clone(), addr));
|
|
|
|
|
tokio::spawn(server.execute(imp.serve())).await.unwrap()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct HubImpl {
|
|
|
|
|
ctx: Arc<ServerContext>,
|
|
|
|
|
conn: SocketAddr,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl HubImpl {
|
|
|
|
|
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
|
|
|
|
Self { ctx, conn }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tarpc::server]
|
|
|
|
|
impl Hub for Arc<HubImpl> {
|
|
|
|
|
async fn register_worker(
|
|
|
|
|
self,
|
|
|
|
|
_context: tarpc::context::Context,
|
|
|
|
|
kind: WorkerKind,
|
|
|
|
|
port: i32,
|
|
|
|
|
name: String,
|
|
|
|
|
device: String,
|
|
|
|
|
arch: String,
|
|
|
|
|
cpu_info: String,
|
|
|
|
|
cpu_count: i32,
|
|
|
|
|
cuda_devices: Vec<String>,
|
2023-11-17 07:05:39 +00:00
|
|
|
token: String,
|
2023-11-14 20:48:20 +00:00
|
|
|
) -> Result<Worker, HubError> {
|
2023-11-17 07:05:39 +00:00
|
|
|
if token.is_empty() {
|
|
|
|
|
return Err(HubError::InvalidToken("Empty worker token".to_string()));
|
|
|
|
|
}
|
|
|
|
|
let server_token = match self.ctx.read_registration_token().await {
|
|
|
|
|
Ok(t) => t,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
error!("fetch server token: {}", err.to_string());
|
|
|
|
|
return Err(HubError::InvalidToken(
|
|
|
|
|
"Failed to fetch server token".to_string(),
|
|
|
|
|
));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
if server_token != token {
|
|
|
|
|
return Err(HubError::InvalidToken("Token mismatch".to_string()));
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-14 20:48:20 +00:00
|
|
|
let worker = Worker {
|
|
|
|
|
name,
|
|
|
|
|
kind,
|
|
|
|
|
addr: format!("http://{}:{}", self.conn.ip(), port),
|
|
|
|
|
device,
|
|
|
|
|
arch,
|
|
|
|
|
cpu_info,
|
|
|
|
|
cpu_count,
|
|
|
|
|
cuda_devices,
|
|
|
|
|
};
|
|
|
|
|
self.ctx.register_worker(worker).await
|
|
|
|
|
}
|
2023-11-18 23:17:54 +00:00
|
|
|
|
|
|
|
|
async fn log_event(self, _context: tarpc::context::Context, content: String) {
|
|
|
|
|
self.ctx.logger.log(content)
|
|
|
|
|
}
|
2023-11-18 23:45:00 +00:00
|
|
|
|
|
|
|
|
async fn search(
|
|
|
|
|
self,
|
|
|
|
|
_context: tarpc::context::Context,
|
|
|
|
|
q: String,
|
|
|
|
|
limit: usize,
|
|
|
|
|
offset: usize,
|
|
|
|
|
) -> SearchResponse {
|
|
|
|
|
match self.ctx.code.search(&q, limit, offset).await {
|
|
|
|
|
Ok(serp) => serp,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
warn!("Failed to search: {}", err);
|
|
|
|
|
SearchResponse::default()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn search_in_language(
|
|
|
|
|
self,
|
|
|
|
|
_context: tarpc::context::Context,
|
|
|
|
|
language: String,
|
|
|
|
|
tokens: Vec<String>,
|
|
|
|
|
limit: usize,
|
|
|
|
|
offset: usize,
|
|
|
|
|
) -> SearchResponse {
|
|
|
|
|
match self
|
|
|
|
|
.ctx
|
|
|
|
|
.code
|
|
|
|
|
.search_in_language(&language, &tokens, limit, offset)
|
|
|
|
|
.await
|
|
|
|
|
{
|
|
|
|
|
Ok(serp) => serp,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
warn!("Failed to search: {}", err);
|
|
|
|
|
SearchResponse::default()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-11-14 20:48:20 +00:00
|
|
|
}
|