refactor: use tarpc for easier worker <-> hub communication (#781)
* temp * generic * adapt client * rename to api * Revert "rename to api" This reverts commit 8a51b24fecd76a78e6df576ec51605b8d8418975. * refactor: remove useless mutation * remove useless connection * cleanup api structure * restructure * add webserver api error * webserver.rs -> server.rs * rename service to Hub * update schema * update naming * shrink features * update * mv worker.rs -> server/worker.rs
release-fix-intellij-update-support-version-range
parent
e521f0637c
commit
618009373b
|
|
@ -442,6 +442,7 @@ checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"base64 0.21.2",
|
||||
"bitflags 1.3.2",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
|
|
@ -459,8 +460,10 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
|
|
@ -552,6 +555,15 @@ version = "0.5.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1"
|
||||
|
||||
[[package]]
|
||||
name = "bincode"
|
||||
version = "1.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
|
|
@ -1711,15 +1723,6 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-introspection-query"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-parser"
|
||||
version = "0.3.0"
|
||||
|
|
@ -1730,56 +1733,6 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-parser"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474"
|
||||
dependencies = [
|
||||
"combine",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql_client"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa"
|
||||
dependencies = [
|
||||
"graphql_query_derive",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql_client_codegen"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506"
|
||||
dependencies = [
|
||||
"graphql-introspection-query",
|
||||
"graphql-parser 0.4.0",
|
||||
"heck 0.4.1",
|
||||
"lazy_static",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql_query_derive"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c"
|
||||
dependencies = [
|
||||
"graphql_client_codegen",
|
||||
"proc-macro2",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.19"
|
||||
|
|
@ -2207,7 +2160,7 @@ dependencies = [
|
|||
"fnv",
|
||||
"futures",
|
||||
"futures-enum",
|
||||
"graphql-parser 0.3.0",
|
||||
"graphql-parser",
|
||||
"indexmap 1.9.3",
|
||||
"juniper_codegen",
|
||||
"serde",
|
||||
|
|
@ -3056,18 +3009,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.1.0"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead"
|
||||
checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.1.0"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07"
|
||||
checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
@ -4344,7 +4297,6 @@ dependencies = [
|
|||
"chrono",
|
||||
"clap 4.4.7",
|
||||
"futures",
|
||||
"graphql_client",
|
||||
"http-api-bindings",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
|
|
@ -4453,14 +4405,20 @@ version = "0.6.0-dev"
|
|||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"bincode",
|
||||
"futures",
|
||||
"hyper",
|
||||
"juniper",
|
||||
"juniper-axum",
|
||||
"lazy_static",
|
||||
"mime_guess",
|
||||
"pin-project",
|
||||
"rust-embed 8.0.0",
|
||||
"serde",
|
||||
"tarpc",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tracing",
|
||||
"unicase",
|
||||
]
|
||||
|
|
@ -4616,6 +4574,41 @@ dependencies = [
|
|||
"xattr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tarpc"
|
||||
version = "0.33.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f41bce44d290df0598ae4b9cd6ea7f58f651fd3aa4af1b26060c4fa32b08af7"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"fnv",
|
||||
"futures",
|
||||
"humantime",
|
||||
"opentelemetry",
|
||||
"pin-project",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"static_assertions",
|
||||
"tarpc-plugins",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-serde",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tarpc-plugins"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ee42b4e559f17bce0385ebf511a7beb67d5cc33c12c96b7f4e9789919d9c10f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "temp_testdir"
|
||||
version = "0.2.3"
|
||||
|
|
@ -4841,6 +4834,18 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-serde"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "911a61637386b789af998ee23f50aa30d5fd7edcec8d6d3dedae5e5815205466"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"pin-project",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.14"
|
||||
|
|
|
|||
|
|
@ -1,34 +1,26 @@
|
|||
pub mod extract;
|
||||
pub mod response;
|
||||
|
||||
use std::{future, net::SocketAddr};
|
||||
use std::{future, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::{ConnectInfo, Extension, State},
|
||||
extract::{Extension, State},
|
||||
response::{Html, IntoResponse},
|
||||
};
|
||||
use juniper_graphql_ws::Schema;
|
||||
|
||||
use self::{extract::JuniperRequest, response::JuniperResponse};
|
||||
|
||||
pub trait FromStateAndClientAddr<C, S> {
|
||||
fn build(state: S, client_addr: SocketAddr) -> C;
|
||||
}
|
||||
|
||||
#[cfg_attr(text, axum::debug_handler)]
|
||||
pub async fn graphql<S, C>(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<C>,
|
||||
pub async fn graphql<S>(
|
||||
State(state): State<Arc<S::Context>>,
|
||||
Extension(schema): Extension<S>,
|
||||
JuniperRequest(req): JuniperRequest<S::ScalarValue>,
|
||||
) -> impl IntoResponse
|
||||
where
|
||||
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
|
||||
S::Context: FromStateAndClientAddr<S::Context, C>,
|
||||
C: Clone,
|
||||
{
|
||||
let context = S::Context::build(state.clone(), addr);
|
||||
JuniperResponse(req.execute(schema.root_node(), &context).await).into_response()
|
||||
JuniperResponse(req.execute(schema.root_node(), &state).await).into_response()
|
||||
}
|
||||
|
||||
/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].
|
||||
|
|
|
|||
|
|
@ -47,8 +47,6 @@ async-trait.workspace = true
|
|||
tabby-webserver = { path = "../../ee/tabby-webserver" }
|
||||
thiserror.workspace = true
|
||||
chrono = "0.4.31"
|
||||
graphql_client = { version = "0.13.0", features = ["reqwest"] }
|
||||
reqwest.workspace = true
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
mutation RegisterWorker(
|
||||
$port: Int!
|
||||
$kind: WorkerKind!
|
||||
$name: String!
|
||||
$device: String!
|
||||
$arch: String!
|
||||
$cpuInfo: String!
|
||||
$cpuCount: Int!
|
||||
$cudaDevices: [String!]!
|
||||
) {
|
||||
worker: registerWorker(
|
||||
port: $port
|
||||
kind: $kind
|
||||
name: $name
|
||||
device: $device
|
||||
arch: $arch
|
||||
cpuInfo: $cpuInfo
|
||||
cpuCount: $cpuCount
|
||||
cudaDevices: $cudaDevices
|
||||
) {
|
||||
addr
|
||||
}
|
||||
}
|
||||
|
|
@ -14,6 +14,7 @@ use opentelemetry::{
|
|||
};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tabby_common::config::Config;
|
||||
use tabby_webserver::api::WorkerKind;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
|
||||
|
||||
#[derive(Parser)]
|
||||
|
|
@ -105,10 +106,8 @@ async fn main() {
|
|||
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
||||
.await
|
||||
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
|
||||
Commands::WorkerCompletion(args) => {
|
||||
worker::main(worker::WorkerKind::Completion, args).await
|
||||
}
|
||||
Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await,
|
||||
Commands::WorkerCompletion(args) => worker::main(WorkerKind::Completion, args).await,
|
||||
Commands::WorkerChat(args) => worker::main(WorkerKind::Chat, args).await,
|
||||
}
|
||||
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
.merge(api_router(args, config).await)
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc));
|
||||
|
||||
let app = attach_webserver(app);
|
||||
let app = attach_webserver(app).await;
|
||||
|
||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||
info!("Listening at {}", address);
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ use std::{
|
|||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{routing, Router};
|
||||
use clap::Args;
|
||||
use graphql_client::{reqwest::post_graphql, GraphQLQuery};
|
||||
use hyper::Server;
|
||||
use tabby_webserver::api::WorkerKind;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
|
|
@ -23,13 +24,6 @@ use crate::{
|
|||
Device,
|
||||
};
|
||||
|
||||
#[derive(GraphQLQuery)]
|
||||
#[graphql(
|
||||
schema_path = "../../ee/tabby-webserver/graphql/schema.graphql",
|
||||
query_path = "./graphql/worker.query.graphql"
|
||||
)]
|
||||
struct RegisterWorker;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct WorkerArgs {
|
||||
/// URL to register this worker.
|
||||
|
|
@ -56,7 +50,7 @@ pub struct WorkerArgs {
|
|||
async fn make_chat_route(args: &WorkerArgs) -> Router {
|
||||
let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);
|
||||
|
||||
request_register(register_worker::WorkerKind::CHAT, args).await;
|
||||
request_register(WorkerKind::Chat, args).await;
|
||||
|
||||
Router::new().route(
|
||||
"/v1beta/chat/completions",
|
||||
|
|
@ -71,7 +65,7 @@ async fn make_completion_route(args: &WorkerArgs) -> Router {
|
|||
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
|
||||
);
|
||||
|
||||
request_register(register_worker::WorkerKind::COMPLETION, args).await;
|
||||
request_register(WorkerKind::Completion, args).await;
|
||||
|
||||
Router::new().route(
|
||||
"/v1/completions",
|
||||
|
|
@ -79,11 +73,6 @@ async fn make_completion_route(args: &WorkerArgs) -> Router {
|
|||
)
|
||||
}
|
||||
|
||||
pub enum WorkerKind {
|
||||
Chat,
|
||||
Completion,
|
||||
}
|
||||
|
||||
pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
|
||||
download_model_if_needed(&args.model).await;
|
||||
|
||||
|
|
@ -103,44 +92,45 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
|
|||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||
}
|
||||
|
||||
async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) {
|
||||
request_register_impl(
|
||||
async fn request_register(kind: WorkerKind, args: &WorkerArgs) {
|
||||
if let Err(err) = request_register_impl(
|
||||
kind,
|
||||
args.url.clone(),
|
||||
args.port as i64,
|
||||
args.port,
|
||||
args.model.to_owned(),
|
||||
args.device.to_string(),
|
||||
)
|
||||
.await;
|
||||
.await
|
||||
{
|
||||
warn!("Failed to register worker: {}", err)
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_register_impl(
|
||||
kind: register_worker::WorkerKind,
|
||||
kind: WorkerKind,
|
||||
url: String,
|
||||
port: i64,
|
||||
port: u16,
|
||||
name: String,
|
||||
device: String,
|
||||
) {
|
||||
let client = reqwest::Client::new();
|
||||
) -> Result<()> {
|
||||
let client = tabby_webserver::api::create_client(url).await;
|
||||
let (cpu_info, cpu_count) = read_cpu_info();
|
||||
let cuda_devices = read_cuda_devices().unwrap_or_default();
|
||||
let variables = register_worker::Variables {
|
||||
port,
|
||||
kind,
|
||||
name,
|
||||
device,
|
||||
arch: ARCH.to_string(),
|
||||
cpu_info,
|
||||
cpu_count: cpu_count as i64,
|
||||
cuda_devices,
|
||||
};
|
||||
let worker = client
|
||||
.register_worker(
|
||||
tabby_webserver::api::tracing_context(),
|
||||
kind,
|
||||
port as i32,
|
||||
name,
|
||||
device,
|
||||
ARCH.to_string(),
|
||||
cpu_info,
|
||||
cpu_count as i32,
|
||||
cuda_devices,
|
||||
)
|
||||
.await??;
|
||||
|
||||
let url = format!("{}/graphql", url);
|
||||
match post_graphql::<RegisterWorker, _>(&client, &url, variables).await {
|
||||
Ok(x) => {
|
||||
let addr = x.data.unwrap().worker.addr;
|
||||
info!("Worker alive at {}", addr);
|
||||
}
|
||||
Err(err) => warn!("Failed to register worker: {}", err),
|
||||
}
|
||||
info!("Worker alive at {}", worker.addr);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,15 +7,21 @@ homepage.workspace = true
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum.workspace = true
|
||||
axum = { workspace = true, features = ["ws"] }
|
||||
bincode = "1.3.3"
|
||||
futures.workspace = true
|
||||
hyper = { workspace = true, features=["client"]}
|
||||
juniper.workspace = true
|
||||
juniper-axum = { path = "../../crates/juniper-axum" }
|
||||
lazy_static = "1.4.0"
|
||||
mime_guess = "2.0.4"
|
||||
pin-project = "1.1.3"
|
||||
rust-embed = "8.0.0"
|
||||
serde.workspace = true
|
||||
tarpc = { version = "0.33.0", features = ["serde-transport"] }
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-tungstenite = "0.20.1"
|
||||
tracing.workspace = true
|
||||
unicase = "2.7.0"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
use std::fs::write;
|
||||
|
||||
use juniper::EmptySubscription;
|
||||
use tabby_webserver::schema::{Mutation, Query, Schema};
|
||||
use tabby_webserver::create_schema;
|
||||
|
||||
fn main() {
|
||||
let schema = Schema::new(Query, Mutation, EmptySubscription::new());
|
||||
let schema = create_schema();
|
||||
write(
|
||||
"ee/tabby-webserver/graphql/schema.graphql",
|
||||
schema.as_schema_language(),
|
||||
|
|
|
|||
|
|
@ -3,10 +3,6 @@ enum WorkerKind {
|
|||
CHAT
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker!
|
||||
}
|
||||
|
||||
type Query {
|
||||
workers: [Worker!]!
|
||||
}
|
||||
|
|
@ -24,5 +20,4 @@ type Worker {
|
|||
|
||||
schema {
|
||||
query: Query
|
||||
mutation: Mutation
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
use juniper::{GraphQLEnum, GraphQLObject};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio_tungstenite::connect_async;
|
||||
|
||||
use crate::websocket::WebSocketTransport;
|
||||
|
||||
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
||||
pub enum WorkerKind {
|
||||
Completion,
|
||||
Chat,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Worker {
|
||||
pub kind: WorkerKind,
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub device: String,
|
||||
pub arch: String,
|
||||
pub cpu_info: String,
|
||||
pub cpu_count: i32,
|
||||
pub cuda_devices: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Error, Debug)]
|
||||
pub enum HubError {
|
||||
#[error("Invalid worker token")]
|
||||
InvalidToken(String),
|
||||
|
||||
#[error("Feature requires enterprise license")]
|
||||
RequiresEnterpriseLicense,
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
pub trait Hub {
|
||||
async fn register_worker(
|
||||
kind: WorkerKind,
|
||||
port: i32,
|
||||
name: String,
|
||||
device: String,
|
||||
arch: String,
|
||||
cpu_info: String,
|
||||
cpu_count: i32,
|
||||
cuda_devices: Vec<String>,
|
||||
) -> Result<Worker, HubError>;
|
||||
}
|
||||
|
||||
pub fn tracing_context() -> tarpc::context::Context {
|
||||
tarpc::context::current()
|
||||
}
|
||||
|
||||
pub async fn create_client(addr: String) -> HubClient {
|
||||
let addr = format!("ws://{}/hub", addr);
|
||||
let (socket, _) = connect_async(&addr).await.unwrap();
|
||||
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
|
||||
}
|
||||
|
|
@ -1,45 +1,107 @@
|
|||
pub mod schema;
|
||||
pub mod api;
|
||||
|
||||
mod schema;
|
||||
pub use schema::create_schema;
|
||||
use websocket::WebSocketTransport;
|
||||
|
||||
mod server;
|
||||
mod ui;
|
||||
mod webserver;
|
||||
mod worker;
|
||||
mod websocket;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use api::{Hub, HubError, Worker, WorkerKind};
|
||||
use axum::{
|
||||
extract::State,
|
||||
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
||||
http::Request,
|
||||
middleware::{from_fn_with_state, Next},
|
||||
response::IntoResponse,
|
||||
routing, Extension, Router,
|
||||
};
|
||||
use hyper::Body;
|
||||
use juniper::EmptySubscription;
|
||||
use juniper_axum::{graphiql, graphql, playground};
|
||||
use schema::{Mutation, Query, Schema};
|
||||
use webserver::Webserver;
|
||||
use schema::Schema;
|
||||
use server::ServerContext;
|
||||
use tarpc::server::{BaseChannel, Channel};
|
||||
|
||||
pub fn attach_webserver(router: Router) -> Router {
|
||||
let ws = Arc::new(Webserver::default());
|
||||
let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new()));
|
||||
pub async fn attach_webserver(router: Router) -> Router {
|
||||
let ctx = Arc::new(ServerContext::default());
|
||||
let schema = Arc::new(create_schema());
|
||||
|
||||
let app = Router::new()
|
||||
.route("/graphql", routing::get(playground("/graphql", None)))
|
||||
.route("/graphiql", routing::get(graphiql("/graphql", None)))
|
||||
.route(
|
||||
"/graphql",
|
||||
routing::post(graphql::<Arc<Schema>, Arc<Webserver>>).with_state(ws.clone()),
|
||||
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
|
||||
)
|
||||
.layer(Extension(schema));
|
||||
|
||||
router
|
||||
.merge(app)
|
||||
.route("/hub", routing::get(ws_handler).with_state(ctx.clone()))
|
||||
.fallback(ui::handler)
|
||||
.layer(from_fn_with_state(ws, distributed_tabby_layer))
|
||||
.layer(from_fn_with_state(ctx, distributed_tabby_layer))
|
||||
}
|
||||
|
||||
async fn distributed_tabby_layer(
|
||||
State(ws): State<Arc<Webserver>>,
|
||||
State(ws): State<Arc<ServerContext>>,
|
||||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> axum::response::Response {
|
||||
ws.dispatch_request(request, next).await
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<ServerContext>>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
|
||||
}
|
||||
|
||||
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
|
||||
let transport = WebSocketTransport::from(socket);
|
||||
let server = BaseChannel::with_defaults(transport);
|
||||
let imp = Arc::new(HubImpl::new(state.clone(), addr));
|
||||
tokio::spawn(server.execute(imp.serve())).await.unwrap()
|
||||
}
|
||||
|
||||
pub struct HubImpl {
|
||||
ctx: Arc<ServerContext>,
|
||||
conn: SocketAddr,
|
||||
}
|
||||
|
||||
impl HubImpl {
|
||||
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
||||
Self { ctx, conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[tarpc::server]
|
||||
impl Hub for Arc<HubImpl> {
|
||||
async fn register_worker(
|
||||
self,
|
||||
_context: tarpc::context::Context,
|
||||
kind: WorkerKind,
|
||||
port: i32,
|
||||
name: String,
|
||||
device: String,
|
||||
arch: String,
|
||||
cpu_info: String,
|
||||
cpu_count: i32,
|
||||
cuda_devices: Vec<String>,
|
||||
) -> Result<Worker, HubError> {
|
||||
let worker = Worker {
|
||||
name,
|
||||
kind,
|
||||
addr: format!("http://{}:{}", self.conn.ip(), port),
|
||||
device,
|
||||
arch,
|
||||
cpu_info,
|
||||
cpu_count,
|
||||
cuda_devices,
|
||||
};
|
||||
self.ctx.register_worker(worker).await
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,98 +1,23 @@
|
|||
use std::{net::SocketAddr, sync::Arc};
|
||||
use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode};
|
||||
|
||||
use juniper::{
|
||||
graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject,
|
||||
IntoFieldError, RootNode, ScalarValue, Value,
|
||||
};
|
||||
use juniper_axum::FromStateAndClientAddr;
|
||||
|
||||
use crate::webserver::{Webserver, WebserverError};
|
||||
|
||||
pub struct Request {
|
||||
ws: Arc<Webserver>,
|
||||
client_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl FromStateAndClientAddr<Request, Arc<Webserver>> for Request {
|
||||
fn build(ws: Arc<Webserver>, client_addr: SocketAddr) -> Request {
|
||||
Request { ws, client_addr }
|
||||
}
|
||||
}
|
||||
use crate::{api::Worker, server::ServerContext};
|
||||
|
||||
// To make our context usable by Juniper, we have to implement a marker trait.
|
||||
impl juniper::Context for Request {}
|
||||
|
||||
#[derive(GraphQLEnum, Clone, Debug)]
|
||||
pub enum WorkerKind {
|
||||
Completion,
|
||||
Chat,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Clone, Debug)]
|
||||
pub struct Worker {
|
||||
pub kind: WorkerKind,
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub device: String,
|
||||
pub arch: String,
|
||||
pub cpu_info: String,
|
||||
pub cpu_count: i32,
|
||||
pub cuda_devices: Vec<String>,
|
||||
}
|
||||
impl juniper::Context for ServerContext {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Query;
|
||||
|
||||
#[graphql_object(context = Request)]
|
||||
#[graphql_object(context = ServerContext)]
|
||||
impl Query {
|
||||
async fn workers(request: &Request) -> Vec<Worker> {
|
||||
request.ws.list_workers().await
|
||||
async fn workers(ctx: &ServerContext) -> Vec<Worker> {
|
||||
ctx.list_workers().await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mutation;
|
||||
pub type Schema =
|
||||
RootNode<'static, Query, EmptyMutation<ServerContext>, EmptySubscription<ServerContext>>;
|
||||
|
||||
#[graphql_object(context = Request)]
|
||||
impl Mutation {
|
||||
async fn register_worker(
|
||||
request: &Request,
|
||||
port: i32,
|
||||
kind: WorkerKind,
|
||||
name: String,
|
||||
device: String,
|
||||
arch: String,
|
||||
cpu_info: String,
|
||||
cpu_count: i32,
|
||||
cuda_devices: Vec<String>,
|
||||
) -> Result<Worker, WebserverError> {
|
||||
let ws = &request.ws;
|
||||
let worker = Worker {
|
||||
name,
|
||||
kind,
|
||||
addr: format!("http://{}:{}", request.client_addr.ip(), port),
|
||||
device,
|
||||
arch,
|
||||
cpu_info,
|
||||
cpu_count,
|
||||
cuda_devices,
|
||||
};
|
||||
ws.register_worker(worker).await
|
||||
}
|
||||
}
|
||||
|
||||
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Request>>;
|
||||
|
||||
impl<S: ScalarValue> IntoFieldError<S> for WebserverError {
|
||||
fn into_field_error(self) -> FieldError<S> {
|
||||
let msg = format!("{}", &self);
|
||||
match self {
|
||||
WebserverError::InvalidToken(token) => FieldError::new(
|
||||
msg,
|
||||
graphql_value!({
|
||||
"token": token
|
||||
}),
|
||||
),
|
||||
_ => FieldError::new(msg, Value::Null),
|
||||
}
|
||||
}
|
||||
pub fn create_schema() -> Schema {
|
||||
Schema::new(Query, EmptyMutation::new(), EmptySubscription::new())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,35 +1,22 @@
|
|||
mod proxy;
|
||||
mod worker;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
schema::{Worker, WorkerKind},
|
||||
worker,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum WebserverError {
|
||||
#[error("Invalid worker token")]
|
||||
InvalidToken(String),
|
||||
|
||||
#[error("Feature requires enterprise license")]
|
||||
RequiresEnterpriseLicense,
|
||||
}
|
||||
|
||||
use crate::api::{HubError, Worker, WorkerKind};
|
||||
#[derive(Default)]
|
||||
pub struct Webserver {
|
||||
pub struct ServerContext {
|
||||
client: Client<HttpConnector>,
|
||||
completion: worker::WorkerGroup,
|
||||
chat: worker::WorkerGroup,
|
||||
}
|
||||
|
||||
impl Webserver {
|
||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, WebserverError> {
|
||||
impl ServerContext {
|
||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
|
||||
let worker = match worker.kind {
|
||||
WorkerKind::Completion => self.completion.register(worker).await,
|
||||
WorkerKind::Chat => self.chat.register(worker).await,
|
||||
|
|
@ -42,7 +29,7 @@ impl Webserver {
|
|||
);
|
||||
Ok(worker)
|
||||
} else {
|
||||
Err(WebserverError::RequiresEnterpriseLicense)
|
||||
Err(HubError::RequiresEnterpriseLicense)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::error;
|
||||
|
||||
use crate::schema::Worker;
|
||||
use crate::api::Worker;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct WorkerGroup {
|
||||
|
|
@ -51,7 +51,7 @@ fn random_index(size: usize) -> usize {
|
|||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::schema::WorkerKind;
|
||||
use crate::api::WorkerKind;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_worker_group() {
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
use std::{
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use axum::extract::ws;
|
||||
use futures::{Sink, Stream};
|
||||
use pin_project::pin_project;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite as tt;
|
||||
use tokio_tungstenite::tungstenite as ts;
|
||||
|
||||
pub trait IntoData {
|
||||
fn into_data(self) -> Option<Vec<u8>>;
|
||||
}
|
||||
|
||||
impl IntoData for ws::Message {
|
||||
fn into_data(self) -> Option<Vec<u8>> {
|
||||
match self {
|
||||
ws::Message::Binary(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoData for ts::Message {
|
||||
fn into_data(self) -> Option<Vec<u8>> {
|
||||
match self {
|
||||
ts::Message::Binary(x) => Some(x),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub struct WebSocketTransport<Req, Resp, Message, Transport, Error>
|
||||
where
|
||||
Message: IntoData + From<Vec<u8>>,
|
||||
Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
|
||||
{
|
||||
#[pin]
|
||||
inner: Transport,
|
||||
ghost: PhantomData<(Req, Resp)>,
|
||||
}
|
||||
|
||||
impl<Req, Resp> From<ws::WebSocket>
|
||||
for WebSocketTransport<Req, Resp, ws::Message, ws::WebSocket, axum::Error>
|
||||
{
|
||||
fn from(inner: ws::WebSocket) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp> From<tt::WebSocketStream<tt::MaybeTlsStream<TcpStream>>>
|
||||
for WebSocketTransport<
|
||||
Req,
|
||||
Resp,
|
||||
ts::Message,
|
||||
tt::WebSocketStream<tt::MaybeTlsStream<TcpStream>>,
|
||||
tt::tungstenite::Error,
|
||||
>
|
||||
{
|
||||
fn from(inner: tokio_tungstenite::WebSocketStream<tt::MaybeTlsStream<TcpStream>>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
ghost: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, Message, Transport, Error> Stream
|
||||
for WebSocketTransport<Req, Resp, Message, Transport, Error>
|
||||
where
|
||||
Req: for<'de> serde::Deserialize<'de>,
|
||||
Message: IntoData + From<Vec<u8>> + std::fmt::Debug,
|
||||
Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
|
||||
{
|
||||
type Item = Result<Req, Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match futures::ready!(self.as_mut().project().inner.poll_next(cx)) {
|
||||
Some(Ok(msg)) => {
|
||||
let bin = msg.into_data();
|
||||
match bin {
|
||||
Some(bin) => Poll::Ready(Some(Ok(bincode::deserialize_from::<&[u8], Req>(
|
||||
bin.as_ref(),
|
||||
)
|
||||
.unwrap()))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
Some(Err(err)) => Poll::Ready(Some(Err(err))),
|
||||
None => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Req, Resp, Message, Transport, Error> Sink<Resp>
|
||||
for WebSocketTransport<Req, Resp, Message, Transport, Error>
|
||||
where
|
||||
Resp: serde::Serialize,
|
||||
Message: IntoData + From<Vec<u8>>,
|
||||
Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.as_mut().project().inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Resp) -> Result<(), Self::Error> {
|
||||
let msg = Message::from(bincode::serialize(&item).unwrap());
|
||||
self.as_mut().project().inner.start_send(msg)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.as_mut().project().inner.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.as_mut().project().inner.poll_close(cx)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue