mod proxy; mod ui; mod worker; use std::{net::SocketAddr, sync::Arc}; use axum::{ extract::State, http::Request, middleware::{from_fn_with_state, Next}, response::IntoResponse, Router, }; use hyper::{client::HttpConnector, Body, Client, StatusCode}; use tracing::warn; #[derive(Default)] pub struct Webserver { client: Client, completion: worker::WorkerGroup, chat: worker::WorkerGroup, } impl Webserver { async fn dispatch_request( &self, request: Request, next: Next, ) -> axum::response::Response { let path = request.uri().path(); let remote_addr = request .extensions() .get::>() .map(|ci| ci.0) .expect("Unable to extract remote addr"); let worker = if path.starts_with("/v1/completions") { self.completion.select().await } else if path.starts_with("/v1beta/chat/completions") { self.chat.select().await } else { None }; if let Some(worker) = worker { match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await { Ok(res) => res.into_response(), Err(err) => { warn!("Failed to proxy request {}", err); axum::response::Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::empty()) .unwrap() .into_response() } } } else { next.run(request).await } } } pub fn attach_webserver(router: Router) -> Router { let ws = Arc::new(Webserver::default()); router .fallback(ui::handler) .layer(from_fn_with_state(ws, distributed_tabby_layer)) } async fn distributed_tabby_layer( State(ws): State>, request: Request, next: Next, ) -> axum::response::Response { ws.dispatch_request(request, next).await }