From 618009373b4290892db499f93cf7fb057e180fa9 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Tue, 14 Nov 2023 12:48:20 -0800 Subject: [PATCH] refactor: use tarpc for easier worker <-> hub communication (#781) * temp * generic * adapt client * rename to api * Revert "rename to api" This reverts commit 8a51b24fecd76a78e6df576ec51605b8d8418975. * refactor: remove uselss mutation * remove useless connection * cleanup api structure * restructure * add webserver api error * webserver.rs -> server.rs * rename service to Hub * update schema * update naming * shrink features * update * mv worker.rs -> server/worker.rs --- Cargo.lock | 135 +++++++++--------- crates/juniper-axum/src/lib.rs | 18 +-- crates/tabby/Cargo.toml | 2 - crates/tabby/graphql/worker.query.graphql | 23 --- crates/tabby/src/main.rs | 7 +- crates/tabby/src/serve.rs | 2 +- crates/tabby/src/worker.rs | 72 ++++------ ee/tabby-webserver/Cargo.toml | 8 +- ee/tabby-webserver/examples/update-schema.rs | 5 +- ee/tabby-webserver/graphql/schema.graphql | 5 - ee/tabby-webserver/src/api.rs | 57 ++++++++ ee/tabby-webserver/src/lib.rs | 90 ++++++++++-- ee/tabby-webserver/src/schema.rs | 95 ++---------- .../src/{webserver.rs => server.rs} | 25 +--- .../src/{webserver => server}/proxy.rs | 0 ee/tabby-webserver/src/{ => server}/worker.rs | 4 +- ee/tabby-webserver/src/websocket.rs | 127 ++++++++++++++++ 17 files changed, 397 insertions(+), 278 deletions(-) delete mode 100644 crates/tabby/graphql/worker.query.graphql create mode 100644 ee/tabby-webserver/src/api.rs rename ee/tabby-webserver/src/{webserver.rs => server.rs} (84%) rename ee/tabby-webserver/src/{webserver => server}/proxy.rs (100%) rename ee/tabby-webserver/src/{ => server}/worker.rs (97%) create mode 100644 ee/tabby-webserver/src/websocket.rs diff --git a/Cargo.lock b/Cargo.lock index c037737..dbb183b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -442,6 +442,7 @@ checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" dependencies = [ "async-trait", "axum-core", + "base64 0.21.2", "bitflags 1.3.2", "bytes", "futures-util", @@ -459,8 +460,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -552,6 +555,15 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -1711,15 +1723,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "graphql-introspection-query" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d" -dependencies = [ - "serde", -] - [[package]] name = "graphql-parser" version = "0.3.0" @@ -1730,56 +1733,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "graphql-parser" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474" -dependencies = [ - "combine", - "thiserror", -] - -[[package]] -name = "graphql_client" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa" -dependencies = [ - "graphql_query_derive", - "reqwest", - "serde", - "serde_json", -] - -[[package]] -name = "graphql_client_codegen" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506" -dependencies = [ - "graphql-introspection-query", - "graphql-parser 0.4.0", - "heck 0.4.1", - "lazy_static", - "proc-macro2", - "quote", - "serde", - "serde_json", - "syn 1.0.109", -] - -[[package]] -name = "graphql_query_derive" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c" -dependencies = [ - "graphql_client_codegen", - "proc-macro2", - "syn 1.0.109", -] - [[package]] name = "h2" version = "0.3.19" @@ -2207,7 +2160,7 @@ dependencies = [ "fnv", "futures", "futures-enum", - "graphql-parser 0.3.0", + "graphql-parser", "indexmap 1.9.3", "juniper_codegen", "serde", @@ -3056,18 +3009,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.0" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.0" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", @@ -4344,7 +4297,6 @@ dependencies = [ "chrono", "clap 4.4.7", "futures", - "graphql_client", "http-api-bindings", "hyper", "lazy_static", @@ -4453,14 +4405,20 @@ version = "0.6.0-dev" dependencies = [ "anyhow", "axum", + "bincode", + "futures", "hyper", "juniper", "juniper-axum", "lazy_static", "mime_guess", + "pin-project", "rust-embed 8.0.0", + "serde", + "tarpc", "thiserror", "tokio", + "tokio-tungstenite", "tracing", "unicase", ] @@ -4616,6 +4574,41 @@ dependencies = [ "xattr", ] +[[package]] +name = "tarpc" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f41bce44d290df0598ae4b9cd6ea7f58f651fd3aa4af1b26060c4fa32b08af7" +dependencies = [ + "anyhow", + "fnv", + "futures", + "humantime", + "opentelemetry", + "pin-project", + "rand 0.8.5", + "serde", + "static_assertions", + "tarpc-plugins", + "thiserror", + "tokio", + "tokio-serde", + "tokio-util", + "tracing", + "tracing-opentelemetry", +] + +[[package]] +name = "tarpc-plugins" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ee42b4e559f17bce0385ebf511a7beb67d5cc33c12c96b7f4e9789919d9c10f" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "temp_testdir" version = "0.2.3" @@ -4841,6 +4834,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-serde" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911a61637386b789af998ee23f50aa30d5fd7edcec8d6d3dedae5e5815205466" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project", +] + [[package]] name = "tokio-stream" version = "0.1.14" diff --git a/crates/juniper-axum/src/lib.rs b/crates/juniper-axum/src/lib.rs index ae0b52a..e80b900 100644 --- a/crates/juniper-axum/src/lib.rs +++ b/crates/juniper-axum/src/lib.rs @@ -1,34 +1,26 @@ pub mod extract; pub mod response; -use std::{future, net::SocketAddr}; +use std::{future, sync::Arc}; use axum::{ - extract::{ConnectInfo, Extension, State}, + extract::{Extension, State}, response::{Html, IntoResponse}, }; use juniper_graphql_ws::Schema; use self::{extract::JuniperRequest, response::JuniperResponse}; -pub trait FromStateAndClientAddr { - fn build(state: S, client_addr: SocketAddr) -> C; -} - #[cfg_attr(text, axum::debug_handler)] -pub async fn graphql( - ConnectInfo(addr): ConnectInfo, - State(state): State, +pub async fn graphql( + State(state): State>, Extension(schema): Extension, JuniperRequest(req): JuniperRequest, ) -> impl IntoResponse where S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here. - S::Context: FromStateAndClientAddr, - C: Clone, { - let context = S::Context::build(state.clone(), addr); - JuniperResponse(req.execute(schema.root_node(), &context).await).into_response() + JuniperResponse(req.execute(schema.root_node(), &state).await).into_response() } /// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL]. diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 3178fc8..66d9e04 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -47,8 +47,6 @@ async-trait.workspace = true tabby-webserver = { path = "../../ee/tabby-webserver" } thiserror.workspace = true chrono = "0.4.31" -graphql_client = { version = "0.13.0", features = ["reqwest"] } -reqwest.workspace = true [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/graphql/worker.query.graphql b/crates/tabby/graphql/worker.query.graphql deleted file mode 100644 index 157299e..0000000 --- a/crates/tabby/graphql/worker.query.graphql +++ /dev/null @@ -1,23 +0,0 @@ -mutation RegisterWorker( - $port: Int! - $kind: WorkerKind! - $name: String! - $device: String! - $arch: String! - $cpuInfo: String! - $cpuCount: Int! - $cudaDevices: [String!]! -) { - worker: registerWorker( - port: $port - kind: $kind - name: $name - device: $device - arch: $arch - cpuInfo: $cpuInfo - cpuCount: $cpuCount - cudaDevices: $cudaDevices - ) { - addr - } -} diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 7617627..79e5bf9 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -14,6 +14,7 @@ use opentelemetry::{ }; use opentelemetry_otlp::WithExportConfig; use tabby_common::config::Config; +use tabby_webserver::api::WorkerKind; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; #[derive(Parser)] @@ -105,10 +106,8 @@ async fn main() { Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now) .await .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)), - Commands::WorkerCompletion(args) => { - worker::main(worker::WorkerKind::Completion, args).await - } - Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await, + Commands::WorkerCompletion(args) => worker::main(WorkerKind::Completion, args).await, + Commands::WorkerChat(args) => worker::main(WorkerKind::Chat, args).await, } opentelemetry::global::shutdown_tracer_provider(); diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index ff8e9e5..361714b 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -109,7 +109,7 @@ pub async fn main(config: &Config, args: &ServeArgs) { .merge(api_router(args, config).await) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)); - let app = attach_webserver(app); + let app = attach_webserver(app).await; let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); info!("Listening at {}", address); diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs index d034434..71fe68d 100644 --- a/crates/tabby/src/worker.rs +++ b/crates/tabby/src/worker.rs @@ -4,10 +4,11 @@ use std::{ sync::Arc, }; +use anyhow::Result; use axum::{routing, Router}; use clap::Args; -use graphql_client::{reqwest::post_graphql, GraphQLQuery}; use hyper::Server; +use tabby_webserver::api::WorkerKind; use tracing::{info, warn}; use crate::{ @@ -23,13 +24,6 @@ use crate::{ Device, }; -#[derive(GraphQLQuery)] -#[graphql( - schema_path = "../../ee/tabby-webserver/graphql/schema.graphql", - query_path = "./graphql/worker.query.graphql" -)] -struct RegisterWorker; - #[derive(Args)] pub struct WorkerArgs { /// URL to register this worker. @@ -56,7 +50,7 @@ pub struct WorkerArgs { async fn make_chat_route(args: &WorkerArgs) -> Router { let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await); - request_register(register_worker::WorkerKind::CHAT, args).await; + request_register(WorkerKind::Chat, args).await; Router::new().route( "/v1beta/chat/completions", @@ -71,7 +65,7 @@ async fn make_completion_route(args: &WorkerArgs) -> Router { create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await, ); - request_register(register_worker::WorkerKind::COMPLETION, args).await; + request_register(WorkerKind::Completion, args).await; Router::new().route( "/v1/completions", @@ -79,11 +73,6 @@ async fn make_completion_route(args: &WorkerArgs) -> Router { ) } -pub enum WorkerKind { - Chat, - Completion, -} - pub async fn main(kind: WorkerKind, args: &WorkerArgs) { download_model_if_needed(&args.model).await; @@ -103,44 +92,45 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) { .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } -async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) { - request_register_impl( +async fn request_register(kind: WorkerKind, args: &WorkerArgs) { + if let Err(err) = request_register_impl( kind, args.url.clone(), - args.port as i64, + args.port, args.model.to_owned(), args.device.to_string(), ) - .await; + .await + { + warn!("Failed to register worker: {}", err) + } } async fn request_register_impl( - kind: register_worker::WorkerKind, + kind: WorkerKind, url: String, - port: i64, + port: u16, name: String, device: String, -) { - let client = reqwest::Client::new(); +) -> Result<()> { + let client = tabby_webserver::api::create_client(url).await; let (cpu_info, cpu_count) = read_cpu_info(); let cuda_devices = read_cuda_devices().unwrap_or_default(); - let variables = register_worker::Variables { - port, - kind, - name, - device, - arch: ARCH.to_string(), - cpu_info, - cpu_count: cpu_count as i64, - cuda_devices, - }; + let worker = client + .register_worker( + tabby_webserver::api::tracing_context(), + kind, + port as i32, + name, + device, + ARCH.to_string(), + cpu_info, + cpu_count as i32, + cuda_devices, + ) + .await??; - let url = format!("{}/graphql", url); - match post_graphql::(&client, &url, variables).await { - Ok(x) => { - let addr = x.data.unwrap().worker.addr; - info!("Worker alive at {}", addr); - } - Err(err) => warn!("Failed to register worker: {}", err), - } + info!("Worker alive at {}", worker.addr); + + Ok(()) } diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 8fa01c0..ed0811f 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -7,15 +7,21 @@ homepage.workspace = true [dependencies] anyhow.workspace = true -axum.workspace = true +axum = { workspace = true, features = ["ws"] } +bincode = "1.3.3" +futures.workspace = true hyper = { workspace = true, features=["client"]} juniper.workspace = true juniper-axum = { path = "../../crates/juniper-axum" } lazy_static = "1.4.0" mime_guess = "2.0.4" +pin-project = "1.1.3" rust-embed = "8.0.0" +serde.workspace = true +tarpc = { version = "0.33.0", features = ["serde-transport"] } thiserror.workspace = true tokio.workspace = true +tokio-tungstenite = "0.20.1" tracing.workspace = true unicase = "2.7.0" diff --git a/ee/tabby-webserver/examples/update-schema.rs b/ee/tabby-webserver/examples/update-schema.rs index 5ecce9f..40c340d 100644 --- a/ee/tabby-webserver/examples/update-schema.rs +++ b/ee/tabby-webserver/examples/update-schema.rs @@ -1,10 +1,9 @@ use std::fs::write; -use juniper::EmptySubscription; -use tabby_webserver::schema::{Mutation, Query, Schema}; +use tabby_webserver::create_schema; fn main() { - let schema = Schema::new(Query, Mutation, EmptySubscription::new()); + let schema = create_schema(); write( "ee/tabby-webserver/graphql/schema.graphql", schema.as_schema_language(), diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index aa836eb..4a14aac 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -3,10 +3,6 @@ enum WorkerKind { CHAT } -type Mutation { - registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker! -} - type Query { workers: [Worker!]! } @@ -24,5 +20,4 @@ type Worker { schema { query: Query - mutation: Mutation } diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs new file mode 100644 index 0000000..39aabfe --- /dev/null +++ b/ee/tabby-webserver/src/api.rs @@ -0,0 +1,57 @@ +use juniper::{GraphQLEnum, GraphQLObject}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio_tungstenite::connect_async; + +use crate::websocket::WebSocketTransport; + +#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)] +pub enum WorkerKind { + Completion, + Chat, +} + +#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)] +pub struct Worker { + pub kind: WorkerKind, + pub name: String, + pub addr: String, + pub device: String, + pub arch: String, + pub cpu_info: String, + pub cpu_count: i32, + pub cuda_devices: Vec, +} + +#[derive(Serialize, Deserialize, Error, Debug)] +pub enum HubError { + #[error("Invalid worker token")] + InvalidToken(String), + + #[error("Feature requires enterprise license")] + RequiresEnterpriseLicense, +} + +#[tarpc::service] +pub trait Hub { + async fn register_worker( + kind: WorkerKind, + port: i32, + name: String, + device: String, + arch: String, + cpu_info: String, + cpu_count: i32, + cuda_devices: Vec, + ) -> Result; +} + +pub fn tracing_context() -> tarpc::context::Context { + tarpc::context::current() +} + +pub async fn create_client(addr: String) -> HubClient { + let addr = format!("ws://{}/hub", addr); + let (socket, _) = connect_async(&addr).await.unwrap(); + HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn() +} diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 7f5dd63..ad24042 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -1,45 +1,107 @@ -pub mod schema; +pub mod api; + +mod schema; +pub use schema::create_schema; +use websocket::WebSocketTransport; + +mod server; mod ui; -mod webserver; -mod worker; +mod websocket; -use std::sync::Arc; +use std::{net::SocketAddr, sync::Arc}; +use api::{Hub, HubError, Worker, WorkerKind}; use axum::{ - extract::State, + extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, http::Request, middleware::{from_fn_with_state, Next}, + response::IntoResponse, routing, Extension, Router, }; use hyper::Body; -use juniper::EmptySubscription; use juniper_axum::{graphiql, graphql, playground}; -use schema::{Mutation, Query, Schema}; -use webserver::Webserver; +use schema::Schema; +use server::ServerContext; +use tarpc::server::{BaseChannel, Channel}; -pub fn attach_webserver(router: Router) -> Router { - let ws = Arc::new(Webserver::default()); - let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new())); +pub async fn attach_webserver(router: Router) -> Router { + let ctx = Arc::new(ServerContext::default()); + let schema = Arc::new(create_schema()); let app = Router::new() .route("/graphql", routing::get(playground("/graphql", None))) .route("/graphiql", routing::get(graphiql("/graphql", None))) .route( "/graphql", - routing::post(graphql::, Arc>).with_state(ws.clone()), + routing::post(graphql::>).with_state(ctx.clone()), ) .layer(Extension(schema)); router .merge(app) + .route("/hub", routing::get(ws_handler).with_state(ctx.clone())) .fallback(ui::handler) - .layer(from_fn_with_state(ws, distributed_tabby_layer)) + .layer(from_fn_with_state(ctx, distributed_tabby_layer)) } async fn distributed_tabby_layer( - State(ws): State>, + State(ws): State>, request: Request, next: Next, ) -> axum::response::Response { ws.dispatch_request(request, next).await } + +async fn ws_handler( + ws: WebSocketUpgrade, + State(state): State>, + ConnectInfo(addr): ConnectInfo, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| handle_socket(state, socket, addr)) +} + +async fn handle_socket(state: Arc, socket: WebSocket, addr: SocketAddr) { + let transport = WebSocketTransport::from(socket); + let server = BaseChannel::with_defaults(transport); + let imp = Arc::new(HubImpl::new(state.clone(), addr)); + tokio::spawn(server.execute(imp.serve())).await.unwrap() +} + +pub struct HubImpl { + ctx: Arc, + conn: SocketAddr, +} + +impl HubImpl { + pub fn new(ctx: Arc, conn: SocketAddr) -> Self { + Self { ctx, conn } + } +} + +#[tarpc::server] +impl Hub for Arc { + async fn register_worker( + self, + _context: tarpc::context::Context, + kind: WorkerKind, + port: i32, + name: String, + device: String, + arch: String, + cpu_info: String, + cpu_count: i32, + cuda_devices: Vec, + ) -> Result { + let worker = Worker { + name, + kind, + addr: format!("http://{}:{}", self.conn.ip(), port), + device, + arch, + cpu_info, + cpu_count, + cuda_devices, + }; + self.ctx.register_worker(worker).await + } +} diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs index 5f130df..f4364dd 100644 --- a/ee/tabby-webserver/src/schema.rs +++ b/ee/tabby-webserver/src/schema.rs @@ -1,98 +1,23 @@ -use std::{net::SocketAddr, sync::Arc}; +use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode}; -use juniper::{ - graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject, - IntoFieldError, RootNode, ScalarValue, Value, -}; -use juniper_axum::FromStateAndClientAddr; - -use crate::webserver::{Webserver, WebserverError}; - -pub struct Request { - ws: Arc, - client_addr: SocketAddr, -} - -impl FromStateAndClientAddr> for Request { - fn build(ws: Arc, client_addr: SocketAddr) -> Request { - Request { ws, client_addr } - } -} +use crate::{api::Worker, server::ServerContext}; // To make our context usable by Juniper, we have to implement a marker trait. -impl juniper::Context for Request {} - -#[derive(GraphQLEnum, Clone, Debug)] -pub enum WorkerKind { - Completion, - Chat, -} - -#[derive(GraphQLObject, Clone, Debug)] -pub struct Worker { - pub kind: WorkerKind, - pub name: String, - pub addr: String, - pub device: String, - pub arch: String, - pub cpu_info: String, - pub cpu_count: i32, - pub cuda_devices: Vec, -} +impl juniper::Context for ServerContext {} #[derive(Default)] pub struct Query; -#[graphql_object(context = Request)] +#[graphql_object(context = ServerContext)] impl Query { - async fn workers(request: &Request) -> Vec { - request.ws.list_workers().await + async fn workers(ctx: &ServerContext) -> Vec { + ctx.list_workers().await } } -pub struct Mutation; +pub type Schema = + RootNode<'static, Query, EmptyMutation, EmptySubscription>; -#[graphql_object(context = Request)] -impl Mutation { - async fn register_worker( - request: &Request, - port: i32, - kind: WorkerKind, - name: String, - device: String, - arch: String, - cpu_info: String, - cpu_count: i32, - cuda_devices: Vec, - ) -> Result { - let ws = &request.ws; - let worker = Worker { - name, - kind, - addr: format!("http://{}:{}", request.client_addr.ip(), port), - device, - arch, - cpu_info, - cpu_count, - cuda_devices, - }; - ws.register_worker(worker).await - } -} - -pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; - -impl IntoFieldError for WebserverError { - fn into_field_error(self) -> FieldError { - let msg = format!("{}", &self); - match self { - WebserverError::InvalidToken(token) => FieldError::new( - msg, - graphql_value!({ - "token": token - }), - ), - _ => FieldError::new(msg, Value::Null), - } - } +pub fn create_schema() -> Schema { + Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()) } diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/server.rs similarity index 84% rename from ee/tabby-webserver/src/webserver.rs rename to ee/tabby-webserver/src/server.rs index 67576dd..3ab1c4b 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/server.rs @@ -1,35 +1,22 @@ mod proxy; +mod worker; use std::net::SocketAddr; use axum::{http::Request, middleware::Next, response::IntoResponse}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; -use thiserror::Error; use tracing::{info, warn}; -use crate::{ - schema::{Worker, WorkerKind}, - worker, -}; - -#[derive(Error, Debug)] -pub enum WebserverError { - #[error("Invalid worker token")] - InvalidToken(String), - - #[error("Feature requires enterprise license")] - RequiresEnterpriseLicense, -} - +use crate::api::{HubError, Worker, WorkerKind}; #[derive(Default)] -pub struct Webserver { +pub struct ServerContext { client: Client, completion: worker::WorkerGroup, chat: worker::WorkerGroup, } -impl Webserver { - pub async fn register_worker(&self, worker: Worker) -> Result { +impl ServerContext { + pub async fn register_worker(&self, worker: Worker) -> Result { let worker = match worker.kind { WorkerKind::Completion => self.completion.register(worker).await, WorkerKind::Chat => self.chat.register(worker).await, @@ -42,7 +29,7 @@ impl Webserver { ); Ok(worker) } else { - Err(WebserverError::RequiresEnterpriseLicense) + Err(HubError::RequiresEnterpriseLicense) } } diff --git a/ee/tabby-webserver/src/webserver/proxy.rs b/ee/tabby-webserver/src/server/proxy.rs similarity index 100% rename from ee/tabby-webserver/src/webserver/proxy.rs rename to ee/tabby-webserver/src/server/proxy.rs diff --git a/ee/tabby-webserver/src/worker.rs b/ee/tabby-webserver/src/server/worker.rs similarity index 97% rename from ee/tabby-webserver/src/worker.rs rename to ee/tabby-webserver/src/server/worker.rs index 53eb755..9709c02 100644 --- a/ee/tabby-webserver/src/worker.rs +++ b/ee/tabby-webserver/src/server/worker.rs @@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use tracing::error; -use crate::schema::Worker; +use crate::api::Worker; #[derive(Default)] pub struct WorkerGroup { @@ -51,7 +51,7 @@ fn random_index(size: usize) -> usize { mod tests { use super::*; - use crate::schema::WorkerKind; + use crate::api::WorkerKind; #[tokio::test] async fn test_worker_group() { diff --git a/ee/tabby-webserver/src/websocket.rs b/ee/tabby-webserver/src/websocket.rs new file mode 100644 index 0000000..26db6a7 --- /dev/null +++ b/ee/tabby-webserver/src/websocket.rs @@ -0,0 +1,127 @@ +use std::{ + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use axum::extract::ws; +use futures::{Sink, Stream}; +use pin_project::pin_project; +use tokio::net::TcpStream; +use tokio_tungstenite as tt; +use tokio_tungstenite::tungstenite as ts; + +pub trait IntoData { + fn into_data(self) -> Option>; +} + +impl IntoData for ws::Message { + fn into_data(self) -> Option> { + match self { + ws::Message::Binary(x) => Some(x), + _ => None, + } + } +} + +impl IntoData for ts::Message { + fn into_data(self) -> Option> { + match self { + ts::Message::Binary(x) => Some(x), + _ => None, + } + } +} + +#[pin_project] +pub struct WebSocketTransport +where + Message: IntoData + From>, + Transport: Stream> + Sink, +{ + #[pin] + inner: Transport, + ghost: PhantomData<(Req, Resp)>, +} + +impl From + for WebSocketTransport +{ + fn from(inner: ws::WebSocket) -> Self { + Self { + inner, + ghost: PhantomData, + } + } +} + +impl From>> + for WebSocketTransport< + Req, + Resp, + ts::Message, + tt::WebSocketStream>, + tt::tungstenite::Error, + > +{ + fn from(inner: tokio_tungstenite::WebSocketStream>) -> Self { + Self { + inner, + ghost: PhantomData, + } + } +} + +impl Stream + for WebSocketTransport +where + Req: for<'de> serde::Deserialize<'de>, + Message: IntoData + From> + std::fmt::Debug, + Transport: Stream> + Sink, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match futures::ready!(self.as_mut().project().inner.poll_next(cx)) { + Some(Ok(msg)) => { + let bin = msg.into_data(); + match bin { + Some(bin) => Poll::Ready(Some(Ok(bincode::deserialize_from::<&[u8], Req>( + bin.as_ref(), + ) + .unwrap()))), + None => Poll::Ready(None), + } + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } +} + +impl Sink + for WebSocketTransport +where + Resp: serde::Serialize, + Message: IntoData + From>, + Transport: Stream> + Sink, +{ + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().inner.poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Resp) -> Result<(), Self::Error> { + let msg = Message::from(bincode::serialize(&item).unwrap()); + self.as_mut().project().inner.start_send(msg) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().inner.poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().inner.poll_close(cx) + } +}