refactor: use tarpc for easier worker <-> hub communication (#781)

* temp

* generic

* adapt client

* rename to api

* Revert "rename to api"

This reverts commit 8a51b24fecd76a78e6df576ec51605b8d8418975.

* refactor: remove useless mutation

* remove useless connection

* cleanup api structure

* restructure

* add webserver api error

* webserver.rs -> server.rs

* rename service to Hub

* update schema

* update naming

* shrink features

* update

* mv worker.rs -> server/worker.rs
release-fix-intellij-update-support-version-range
Meng Zhang 2023-11-14 12:48:20 -08:00 committed by GitHub
parent e521f0637c
commit 618009373b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 397 additions and 278 deletions

135
Cargo.lock generated
View File

@ -442,6 +442,7 @@ checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"base64 0.21.2",
"bitflags 1.3.2", "bitflags 1.3.2",
"bytes", "bytes",
"futures-util", "futures-util",
@ -459,8 +460,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@ -552,6 +555,15 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -1711,15 +1723,6 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "graphql-introspection-query"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "graphql-parser" name = "graphql-parser"
version = "0.3.0" version = "0.3.0"
@ -1730,56 +1733,6 @@ dependencies = [
"thiserror", "thiserror",
] ]
[[package]]
name = "graphql-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474"
dependencies = [
"combine",
"thiserror",
]
[[package]]
name = "graphql_client"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa"
dependencies = [
"graphql_query_derive",
"reqwest",
"serde",
"serde_json",
]
[[package]]
name = "graphql_client_codegen"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506"
dependencies = [
"graphql-introspection-query",
"graphql-parser 0.4.0",
"heck 0.4.1",
"lazy_static",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 1.0.109",
]
[[package]]
name = "graphql_query_derive"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c"
dependencies = [
"graphql_client_codegen",
"proc-macro2",
"syn 1.0.109",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.19" version = "0.3.19"
@ -2207,7 +2160,7 @@ dependencies = [
"fnv", "fnv",
"futures", "futures",
"futures-enum", "futures-enum",
"graphql-parser 0.3.0", "graphql-parser",
"indexmap 1.9.3", "indexmap 1.9.3",
"juniper_codegen", "juniper_codegen",
"serde", "serde",
@ -3056,18 +3009,18 @@ dependencies = [
[[package]] [[package]]
name = "pin-project" name = "pin-project"
version = "1.1.0" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
dependencies = [ dependencies = [
"pin-project-internal", "pin-project-internal",
] ]
[[package]] [[package]]
name = "pin-project-internal" name = "pin-project-internal"
version = "1.1.0" version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -4344,7 +4297,6 @@ dependencies = [
"chrono", "chrono",
"clap 4.4.7", "clap 4.4.7",
"futures", "futures",
"graphql_client",
"http-api-bindings", "http-api-bindings",
"hyper", "hyper",
"lazy_static", "lazy_static",
@ -4453,14 +4405,20 @@ version = "0.6.0-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"axum", "axum",
"bincode",
"futures",
"hyper", "hyper",
"juniper", "juniper",
"juniper-axum", "juniper-axum",
"lazy_static", "lazy_static",
"mime_guess", "mime_guess",
"pin-project",
"rust-embed 8.0.0", "rust-embed 8.0.0",
"serde",
"tarpc",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-tungstenite",
"tracing", "tracing",
"unicase", "unicase",
] ]
@ -4616,6 +4574,41 @@ dependencies = [
"xattr", "xattr",
] ]
[[package]]
name = "tarpc"
version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f41bce44d290df0598ae4b9cd6ea7f58f651fd3aa4af1b26060c4fa32b08af7"
dependencies = [
"anyhow",
"fnv",
"futures",
"humantime",
"opentelemetry",
"pin-project",
"rand 0.8.5",
"serde",
"static_assertions",
"tarpc-plugins",
"thiserror",
"tokio",
"tokio-serde",
"tokio-util",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "tarpc-plugins"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee42b4e559f17bce0385ebf511a7beb67d5cc33c12c96b7f4e9789919d9c10f"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "temp_testdir" name = "temp_testdir"
version = "0.2.3" version = "0.2.3"
@ -4841,6 +4834,18 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "911a61637386b789af998ee23f50aa30d5fd7edcec8d6d3dedae5e5815205466"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.14" version = "0.1.14"

View File

@ -1,34 +1,26 @@
pub mod extract; pub mod extract;
pub mod response; pub mod response;
use std::{future, net::SocketAddr}; use std::{future, sync::Arc};
use axum::{ use axum::{
extract::{ConnectInfo, Extension, State}, extract::{Extension, State},
response::{Html, IntoResponse}, response::{Html, IntoResponse},
}; };
use juniper_graphql_ws::Schema; use juniper_graphql_ws::Schema;
use self::{extract::JuniperRequest, response::JuniperResponse}; use self::{extract::JuniperRequest, response::JuniperResponse};
pub trait FromStateAndClientAddr<C, S> {
fn build(state: S, client_addr: SocketAddr) -> C;
}
#[cfg_attr(text, axum::debug_handler)] #[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S, C>( pub async fn graphql<S>(
ConnectInfo(addr): ConnectInfo<SocketAddr>, State(state): State<Arc<S::Context>>,
State(state): State<C>,
Extension(schema): Extension<S>, Extension(schema): Extension<S>,
JuniperRequest(req): JuniperRequest<S::ScalarValue>, JuniperRequest(req): JuniperRequest<S::ScalarValue>,
) -> impl IntoResponse ) -> impl IntoResponse
where where
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here. S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
S::Context: FromStateAndClientAddr<S::Context, C>,
C: Clone,
{ {
let context = S::Context::build(state.clone(), addr); JuniperResponse(req.execute(schema.root_node(), &state).await).into_response()
JuniperResponse(req.execute(schema.root_node(), &context).await).into_response()
} }
/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL]. /// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].

View File

@ -47,8 +47,6 @@ async-trait.workspace = true
tabby-webserver = { path = "../../ee/tabby-webserver" } tabby-webserver = { path = "../../ee/tabby-webserver" }
thiserror.workspace = true thiserror.workspace = true
chrono = "0.4.31" chrono = "0.4.31"
graphql_client = { version = "0.13.0", features = ["reqwest"] }
reqwest.workspace = true
[dependencies.uuid] [dependencies.uuid]
version = "1.3.3" version = "1.3.3"

View File

@ -1,23 +0,0 @@
mutation RegisterWorker(
$port: Int!
$kind: WorkerKind!
$name: String!
$device: String!
$arch: String!
$cpuInfo: String!
$cpuCount: Int!
$cudaDevices: [String!]!
) {
worker: registerWorker(
port: $port
kind: $kind
name: $name
device: $device
arch: $arch
cpuInfo: $cpuInfo
cpuCount: $cpuCount
cudaDevices: $cudaDevices
) {
addr
}
}

View File

@ -14,6 +14,7 @@ use opentelemetry::{
}; };
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use tabby_common::config::Config; use tabby_common::config::Config;
use tabby_webserver::api::WorkerKind;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
#[derive(Parser)] #[derive(Parser)]
@ -105,10 +106,8 @@ async fn main() {
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now) Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
.await .await
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)), .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
Commands::WorkerCompletion(args) => { Commands::WorkerCompletion(args) => worker::main(WorkerKind::Completion, args).await,
worker::main(worker::WorkerKind::Completion, args).await Commands::WorkerChat(args) => worker::main(WorkerKind::Chat, args).await,
}
Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await,
} }
opentelemetry::global::shutdown_tracer_provider(); opentelemetry::global::shutdown_tracer_provider();

View File

@ -109,7 +109,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.merge(api_router(args, config).await) .merge(api_router(args, config).await)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)); .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc));
let app = attach_webserver(app); let app = attach_webserver(app).await;
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
info!("Listening at {}", address); info!("Listening at {}", address);

View File

@ -4,10 +4,11 @@ use std::{
sync::Arc, sync::Arc,
}; };
use anyhow::Result;
use axum::{routing, Router}; use axum::{routing, Router};
use clap::Args; use clap::Args;
use graphql_client::{reqwest::post_graphql, GraphQLQuery};
use hyper::Server; use hyper::Server;
use tabby_webserver::api::WorkerKind;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{ use crate::{
@ -23,13 +24,6 @@ use crate::{
Device, Device,
}; };
#[derive(GraphQLQuery)]
#[graphql(
schema_path = "../../ee/tabby-webserver/graphql/schema.graphql",
query_path = "./graphql/worker.query.graphql"
)]
struct RegisterWorker;
#[derive(Args)] #[derive(Args)]
pub struct WorkerArgs { pub struct WorkerArgs {
/// URL to register this worker. /// URL to register this worker.
@ -56,7 +50,7 @@ pub struct WorkerArgs {
async fn make_chat_route(args: &WorkerArgs) -> Router { async fn make_chat_route(args: &WorkerArgs) -> Router {
let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await); let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);
request_register(register_worker::WorkerKind::CHAT, args).await; request_register(WorkerKind::Chat, args).await;
Router::new().route( Router::new().route(
"/v1beta/chat/completions", "/v1beta/chat/completions",
@ -71,7 +65,7 @@ async fn make_completion_route(args: &WorkerArgs) -> Router {
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await, create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
); );
request_register(register_worker::WorkerKind::COMPLETION, args).await; request_register(WorkerKind::Completion, args).await;
Router::new().route( Router::new().route(
"/v1/completions", "/v1/completions",
@ -79,11 +73,6 @@ async fn make_completion_route(args: &WorkerArgs) -> Router {
) )
} }
pub enum WorkerKind {
Chat,
Completion,
}
pub async fn main(kind: WorkerKind, args: &WorkerArgs) { pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
download_model_if_needed(&args.model).await; download_model_if_needed(&args.model).await;
@ -103,44 +92,45 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
} }
async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) { async fn request_register(kind: WorkerKind, args: &WorkerArgs) {
request_register_impl( if let Err(err) = request_register_impl(
kind, kind,
args.url.clone(), args.url.clone(),
args.port as i64, args.port,
args.model.to_owned(), args.model.to_owned(),
args.device.to_string(), args.device.to_string(),
) )
.await; .await
{
warn!("Failed to register worker: {}", err)
}
} }
async fn request_register_impl( async fn request_register_impl(
kind: register_worker::WorkerKind, kind: WorkerKind,
url: String, url: String,
port: i64, port: u16,
name: String, name: String,
device: String, device: String,
) { ) -> Result<()> {
let client = reqwest::Client::new(); let client = tabby_webserver::api::create_client(url).await;
let (cpu_info, cpu_count) = read_cpu_info(); let (cpu_info, cpu_count) = read_cpu_info();
let cuda_devices = read_cuda_devices().unwrap_or_default(); let cuda_devices = read_cuda_devices().unwrap_or_default();
let variables = register_worker::Variables { let worker = client
port, .register_worker(
kind, tabby_webserver::api::tracing_context(),
name, kind,
device, port as i32,
arch: ARCH.to_string(), name,
cpu_info, device,
cpu_count: cpu_count as i64, ARCH.to_string(),
cuda_devices, cpu_info,
}; cpu_count as i32,
cuda_devices,
)
.await??;
let url = format!("{}/graphql", url); info!("Worker alive at {}", worker.addr);
match post_graphql::<RegisterWorker, _>(&client, &url, variables).await {
Ok(x) => { Ok(())
let addr = x.data.unwrap().worker.addr;
info!("Worker alive at {}", addr);
}
Err(err) => warn!("Failed to register worker: {}", err),
}
} }

View File

@ -7,15 +7,21 @@ homepage.workspace = true
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
axum.workspace = true axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
futures.workspace = true
hyper = { workspace = true, features=["client"]} hyper = { workspace = true, features=["client"]}
juniper.workspace = true juniper.workspace = true
juniper-axum = { path = "../../crates/juniper-axum" } juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0" lazy_static = "1.4.0"
mime_guess = "2.0.4" mime_guess = "2.0.4"
pin-project = "1.1.3"
rust-embed = "8.0.0" rust-embed = "8.0.0"
serde.workspace = true
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true thiserror.workspace = true
tokio.workspace = true tokio.workspace = true
tokio-tungstenite = "0.20.1"
tracing.workspace = true tracing.workspace = true
unicase = "2.7.0" unicase = "2.7.0"

View File

@ -1,10 +1,9 @@
use std::fs::write; use std::fs::write;
use juniper::EmptySubscription; use tabby_webserver::create_schema;
use tabby_webserver::schema::{Mutation, Query, Schema};
fn main() { fn main() {
let schema = Schema::new(Query, Mutation, EmptySubscription::new()); let schema = create_schema();
write( write(
"ee/tabby-webserver/graphql/schema.graphql", "ee/tabby-webserver/graphql/schema.graphql",
schema.as_schema_language(), schema.as_schema_language(),

View File

@ -3,10 +3,6 @@ enum WorkerKind {
CHAT CHAT
} }
type Mutation {
registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker!
}
type Query { type Query {
workers: [Worker!]! workers: [Worker!]!
} }
@ -24,5 +20,4 @@ type Worker {
schema { schema {
query: Query query: Query
mutation: Mutation
} }

View File

@ -0,0 +1,57 @@
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio_tungstenite::connect_async;
use crate::websocket::WebSocketTransport;
// Kind of worker that can register with the hub. The hub keeps one worker
// group per kind and uses this value to pick the group on registration.
//
// NOTE: plain `//` comments are used on purpose. juniper's derive macros turn
// `///` doc comments into GraphQL descriptions, which would change the
// generated `schema.graphql`.
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
// Worker registered into the completion group.
Completion,
// Worker registered into the chat group.
Chat,
}
// Description of a registered worker. It is exposed through the GraphQL
// `workers` query and is also the value returned by the `Hub::register_worker`
// RPC, hence both the juniper and the serde derives.
//
// NOTE: plain `//` comments are used on purpose. juniper's derive macros turn
// `///` doc comments into GraphQL descriptions, which would change the
// generated `schema.graphql`.
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
// Which worker group this worker belongs to.
pub kind: WorkerKind,
// Name reported by the worker (the worker binary sends its model name).
pub name: String,
// URL of the worker, built by the hub as `http://{peer ip}:{port}`.
pub addr: String,
// Device string reported by the worker.
pub device: String,
// Architecture string reported by the worker.
pub arch: String,
// CPU description reported by the worker.
pub cpu_info: String,
// CPU count reported by the worker.
pub cpu_count: i32,
// CUDA devices reported by the worker; may be empty.
pub cuda_devices: Vec<String>,
}
/// Errors a [`Hub`] RPC can return to the calling worker.
///
/// Derives `Serialize`/`Deserialize` because the value is part of the RPC
/// response (`Result<Worker, HubError>`) and therefore crosses the transport.
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError {
/// Carries a `String`, presumably the rejected token. No code visible in
/// this change constructs this variant. TODO confirm the payload meaning.
#[error("Invalid worker token")]
InvalidToken(String),
/// Returned by the server context when it refuses to register a worker.
/// NOTE(review): the exact condition is outside the visible code.
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
}
/// RPC interface the hub exposes to workers.
///
/// `#[tarpc::service]` generates the client type (`HubClient`, used by
/// [`create_client`]) and the server-side plumbing from this trait.
#[tarpc::service]
pub trait Hub {
/// Registers the calling worker with the hub.
///
/// `port` is the port the worker listens on; the hub combines it with the
/// peer IP of the connection to build the worker's address. The remaining
/// arguments are copied verbatim into the returned [`Worker`].
///
/// Returns the registered [`Worker`], or a [`HubError`] when the hub
/// refuses the registration.
async fn register_worker(
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> Result<Worker, HubError>;
}
/// Returns `tarpc::context::current()`.
///
/// Callers pass the result as the leading context argument of a `HubClient`
/// RPC, so they do not need to depend on `tarpc` directly.
pub fn tracing_context() -> tarpc::context::Context {
tarpc::context::current()
}
/// Connects to the hub at `ws://{addr}/hub` and returns a spawned [`HubClient`]
/// that talks over that WebSocket through [`WebSocketTransport`], using the
/// default client configuration.
///
/// NOTE(review): because `ws://` is prepended unconditionally, `addr`
/// presumably has to be `host[:port]` without a scheme. Confirm against the
/// worker's `--url` argument.
///
/// # Panics
///
/// Panics if the WebSocket connection cannot be established, because the
/// result of `connect_async` is unwrapped.
pub async fn create_client(addr: String) -> HubClient {
let addr = format!("ws://{}/hub", addr);
// The second tuple element (the HTTP handshake response) is not needed.
let (socket, _) = connect_async(&addr).await.unwrap();
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
}

View File

@ -1,45 +1,107 @@
pub mod schema; pub mod api;
mod schema;
pub use schema::create_schema;
use websocket::WebSocketTransport;
mod server;
mod ui; mod ui;
mod webserver; mod websocket;
mod worker;
use std::sync::Arc; use std::{net::SocketAddr, sync::Arc};
use api::{Hub, HubError, Worker, WorkerKind};
use axum::{ use axum::{
extract::State, extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
http::Request, http::Request,
middleware::{from_fn_with_state, Next}, middleware::{from_fn_with_state, Next},
response::IntoResponse,
routing, Extension, Router, routing, Extension, Router,
}; };
use hyper::Body; use hyper::Body;
use juniper::EmptySubscription;
use juniper_axum::{graphiql, graphql, playground}; use juniper_axum::{graphiql, graphql, playground};
use schema::{Mutation, Query, Schema}; use schema::Schema;
use webserver::Webserver; use server::ServerContext;
use tarpc::server::{BaseChannel, Channel};
pub fn attach_webserver(router: Router) -> Router { pub async fn attach_webserver(router: Router) -> Router {
let ws = Arc::new(Webserver::default()); let ctx = Arc::new(ServerContext::default());
let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new())); let schema = Arc::new(create_schema());
let app = Router::new() let app = Router::new()
.route("/graphql", routing::get(playground("/graphql", None))) .route("/graphql", routing::get(playground("/graphql", None)))
.route("/graphiql", routing::get(graphiql("/graphql", None))) .route("/graphiql", routing::get(graphiql("/graphql", None)))
.route( .route(
"/graphql", "/graphql",
routing::post(graphql::<Arc<Schema>, Arc<Webserver>>).with_state(ws.clone()), routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
) )
.layer(Extension(schema)); .layer(Extension(schema));
router router
.merge(app) .merge(app)
.route("/hub", routing::get(ws_handler).with_state(ctx.clone()))
.fallback(ui::handler) .fallback(ui::handler)
.layer(from_fn_with_state(ws, distributed_tabby_layer)) .layer(from_fn_with_state(ctx, distributed_tabby_layer))
} }
async fn distributed_tabby_layer( async fn distributed_tabby_layer(
State(ws): State<Arc<Webserver>>, State(ws): State<Arc<ServerContext>>,
request: Request<Body>, request: Request<Body>,
next: Next<Body>, next: Next<Body>,
) -> axum::response::Response { ) -> axum::response::Response {
ws.dispatch_request(request, next).await ws.dispatch_request(request, next).await
} }
async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<ServerContext>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
}
/// Serves the [`Hub`] RPC service over one upgraded WebSocket connection.
///
/// The RPC channel runs on its own spawned task and this function waits for it
/// to finish.
///
/// # Panics
///
/// Panics if the spawned task fails to join (for example because it panicked).
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
    let hub = Arc::new(HubImpl::new(state.clone(), addr));
    let channel = BaseChannel::with_defaults(WebSocketTransport::from(socket));
    let task = tokio::spawn(channel.execute(hub.serve()));
    task.await.unwrap()
}
/// Per-connection implementation of the [`Hub`] RPC service.
pub struct HubImpl {
/// Shared server state that registrations are forwarded to.
ctx: Arc<ServerContext>,
/// Socket address of the connected peer; its IP is used to build the
/// registered worker's address.
conn: SocketAddr,
}
impl HubImpl {
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
Self { ctx, conn }
}
}
#[tarpc::server]
impl Hub for Arc<HubImpl> {
    /// Registers the connected peer as a worker.
    ///
    /// The worker's address is derived from the peer IP of this connection and
    /// the `port` the worker reports; all other arguments are stored as given.
    /// The assembled [`Worker`] is forwarded to the shared server context,
    /// whose result is returned unchanged.
    async fn register_worker(
        self,
        _context: tarpc::context::Context,
        kind: WorkerKind,
        port: i32,
        name: String,
        device: String,
        arch: String,
        cpu_info: String,
        cpu_count: i32,
        cuda_devices: Vec<String>,
    ) -> Result<Worker, HubError> {
        // An IPv6 literal must be wrapped in brackets inside a URL authority,
        // otherwise its colons are indistinguishable from the port separator.
        let addr = match self.conn.ip() {
            std::net::IpAddr::V4(ip) => format!("http://{}:{}", ip, port),
            std::net::IpAddr::V6(ip) => format!("http://[{}]:{}", ip, port),
        };

        let worker = Worker {
            name,
            kind,
            addr,
            device,
            arch,
            cpu_info,
            cpu_count,
            cuda_devices,
        };
        self.ctx.register_worker(worker).await
    }
}

View File

@ -1,98 +1,23 @@
use std::{net::SocketAddr, sync::Arc}; use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode};
use juniper::{ use crate::{api::Worker, server::ServerContext};
graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject,
IntoFieldError, RootNode, ScalarValue, Value,
};
use juniper_axum::FromStateAndClientAddr;
use crate::webserver::{Webserver, WebserverError};
pub struct Request {
ws: Arc<Webserver>,
client_addr: SocketAddr,
}
impl FromStateAndClientAddr<Request, Arc<Webserver>> for Request {
fn build(ws: Arc<Webserver>, client_addr: SocketAddr) -> Request {
Request { ws, client_addr }
}
}
// To make our context usable by Juniper, we have to implement a marker trait. // To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Request {} impl juniper::Context for ServerContext {}
#[derive(GraphQLEnum, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Default)] #[derive(Default)]
pub struct Query; pub struct Query;
#[graphql_object(context = Request)] #[graphql_object(context = ServerContext)]
impl Query { impl Query {
async fn workers(request: &Request) -> Vec<Worker> { async fn workers(ctx: &ServerContext) -> Vec<Worker> {
request.ws.list_workers().await ctx.list_workers().await
} }
} }
pub struct Mutation; pub type Schema =
RootNode<'static, Query, EmptyMutation<ServerContext>, EmptySubscription<ServerContext>>;
#[graphql_object(context = Request)] pub fn create_schema() -> Schema {
impl Mutation { Schema::new(Query, EmptyMutation::new(), EmptySubscription::new())
async fn register_worker(
request: &Request,
port: i32,
kind: WorkerKind,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> Result<Worker, WebserverError> {
let ws = &request.ws;
let worker = Worker {
name,
kind,
addr: format!("http://{}:{}", request.client_addr.ip(), port),
device,
arch,
cpu_info,
cpu_count,
cuda_devices,
};
ws.register_worker(worker).await
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Request>>;
impl<S: ScalarValue> IntoFieldError<S> for WebserverError {
fn into_field_error(self) -> FieldError<S> {
let msg = format!("{}", &self);
match self {
WebserverError::InvalidToken(token) => FieldError::new(
msg,
graphql_value!({
"token": token
}),
),
_ => FieldError::new(msg, Value::Null),
}
}
} }

View File

@ -1,35 +1,22 @@
mod proxy; mod proxy;
mod worker;
use std::net::SocketAddr; use std::net::SocketAddr;
use axum::{http::Request, middleware::Next, response::IntoResponse}; use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode}; use hyper::{client::HttpConnector, Body, Client, StatusCode};
use thiserror::Error;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{ use crate::api::{HubError, Worker, WorkerKind};
schema::{Worker, WorkerKind},
worker,
};
#[derive(Error, Debug)]
pub enum WebserverError {
#[error("Invalid worker token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
}
#[derive(Default)] #[derive(Default)]
pub struct Webserver { pub struct ServerContext {
client: Client<HttpConnector>, client: Client<HttpConnector>,
completion: worker::WorkerGroup, completion: worker::WorkerGroup,
chat: worker::WorkerGroup, chat: worker::WorkerGroup,
} }
impl Webserver { impl ServerContext {
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, WebserverError> { pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
let worker = match worker.kind { let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await, WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await, WorkerKind::Chat => self.chat.register(worker).await,
@ -42,7 +29,7 @@ impl Webserver {
); );
Ok(worker) Ok(worker)
} else { } else {
Err(WebserverError::RequiresEnterpriseLicense) Err(HubError::RequiresEnterpriseLicense)
} }
} }

View File

@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::error; use tracing::error;
use crate::schema::Worker; use crate::api::Worker;
#[derive(Default)] #[derive(Default)]
pub struct WorkerGroup { pub struct WorkerGroup {
@ -51,7 +51,7 @@ fn random_index(size: usize) -> usize {
mod tests { mod tests {
use super::*; use super::*;
use crate::schema::WorkerKind; use crate::api::WorkerKind;
#[tokio::test] #[tokio::test]
async fn test_worker_group() { async fn test_worker_group() {

View File

@ -0,0 +1,127 @@
use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use axum::extract::ws;
use futures::{Sink, Stream};
use pin_project::pin_project;
use tokio::net::TcpStream;
use tokio_tungstenite as tt;
use tokio_tungstenite::tungstenite as ts;
/// Abstraction over WebSocket message types that may carry a binary payload.
pub trait IntoData {
/// Consumes the message and returns its bytes, or `None` when the message
/// is not a binary frame.
fn into_data(self) -> Option<Vec<u8>>;
}
impl IntoData for ws::Message {
    /// Yields the payload of a binary frame; every other frame kind maps to
    /// `None`.
    fn into_data(self) -> Option<Vec<u8>> {
        if let ws::Message::Binary(payload) = self {
            Some(payload)
        } else {
            None
        }
    }
}
impl IntoData for ts::Message {
    /// Yields the payload of a binary frame; every other frame kind maps to
    /// `None`.
    fn into_data(self) -> Option<Vec<u8>> {
        if let ts::Message::Binary(payload) = self {
            Some(payload)
        } else {
            None
        }
    }
}
/// Adapter that exposes a WebSocket-like `Transport` of `Message` frames as a
/// typed `Stream` of `Req` and `Sink` of `Resp`.
///
/// Values are encoded with `bincode` and carried in binary frames (see the
/// `Stream` and `Sink` impls in this file).
#[pin_project]
pub struct WebSocketTransport<Req, Resp, Message, Transport, Error>
where
Message: IntoData + From<Vec<u8>>,
Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
{
/// Underlying WebSocket; structurally pinned so it can be polled in place.
#[pin]
inner: Transport,
/// Ties the request/response types to the transport without storing them.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp> From<ws::WebSocket>
    for WebSocketTransport<Req, Resp, ws::Message, ws::WebSocket, axum::Error>
{
    /// Wraps a server-side (axum) WebSocket.
    fn from(socket: ws::WebSocket) -> Self {
        let ghost = PhantomData;
        Self {
            inner: socket,
            ghost,
        }
    }
}
impl<Req, Resp> From<tt::WebSocketStream<tt::MaybeTlsStream<TcpStream>>>
for WebSocketTransport<
Req,
Resp,
ts::Message,
tt::WebSocketStream<tt::MaybeTlsStream<TcpStream>>,
tt::tungstenite::Error,
>
{
fn from(inner: tokio_tungstenite::WebSocketStream<tt::MaybeTlsStream<TcpStream>>) -> Self {
Self {
inner,
ghost: PhantomData,
}
}
}
impl<Req, Resp, Message, Transport, Error> Stream
    for WebSocketTransport<Req, Resp, Message, Transport, Error>
where
    Req: for<'de> serde::Deserialize<'de>,
    Message: IntoData + From<Vec<u8>> + std::fmt::Debug,
    Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
{
    type Item = Result<Req, Error>;

    /// Yields the next `Req` decoded from a binary frame of the underlying
    /// transport.
    ///
    /// * Frames without a binary payload (ping, pong, text, close) are
    ///   skipped. The stream ends when the underlying transport ends, not
    ///   when a control frame happens to arrive.
    /// * Transport errors are passed through unchanged.
    /// * A binary frame that cannot be decoded as `Req` ends the stream. The
    ///   bytes come from the remote peer, so a malformed frame must not
    ///   panic the task, and the generic `Error` type cannot represent a
    ///   decoding failure.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut inner = self.project().inner;
        loop {
            let msg = match futures::ready!(inner.as_mut().poll_next(cx)) {
                Some(Ok(msg)) => msg,
                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
                None => return Poll::Ready(None),
            };

            let bin = match msg.into_data() {
                Some(bin) => bin,
                // Not a binary frame: ignore it and poll for the next one.
                None => continue,
            };

            return match bincode::deserialize_from::<&[u8], Req>(bin.as_ref()) {
                Ok(req) => Poll::Ready(Some(Ok(req))),
                // Protocol violation by the peer: terminate instead of panicking.
                Err(_) => Poll::Ready(None),
            };
        }
    }
}
impl<Req, Resp, Message, Transport, Error> Sink<Resp>
    for WebSocketTransport<Req, Resp, Message, Transport, Error>
where
    Resp: serde::Serialize,
    Message: IntoData + From<Vec<u8>>,
    Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
{
    type Error = Error;

    /// Delegates readiness to the underlying transport.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_ready(cx)
    }

    /// Encodes `item` with `bincode` and queues it on the underlying
    /// transport as a message built from the encoded bytes.
    ///
    /// # Panics
    ///
    /// Panics if `item` cannot be serialized.
    fn start_send(self: Pin<&mut Self>, item: Resp) -> Result<(), Self::Error> {
        let encoded = bincode::serialize(&item).unwrap();
        let this = self.project();
        this.inner.start_send(Message::from(encoded))
    }

    /// Delegates flushing to the underlying transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    /// Delegates closing to the underlying transport.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_close(cx)
    }
}