refactor: use tarpc for easier worker <-> hub communication (#781)

* temp

* generic

* adapt client

* rename to api

* Revert "rename to api"

This reverts commit 8a51b24fecd76a78e6df576ec51605b8d8418975.

* refactor: remove useless mutation

* remove useless connection

* cleanup api structure

* restructure

* add webserver api error

* webserver.rs -> server.rs

* rename service to Hub

* update schema

* update naming

* shrink features

* update

* mv worker.rs -> server/worker.rs
release-fix-intellij-update-support-version-range
Meng Zhang 2023-11-14 12:48:20 -08:00 committed by GitHub
parent e521f0637c
commit 618009373b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 397 additions and 278 deletions

135
Cargo.lock generated
View File

@ -442,6 +442,7 @@ checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core",
"base64 0.21.2",
"bitflags 1.3.2",
"bytes",
"futures-util",
@ -459,8 +460,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@ -552,6 +555,15 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@ -1711,15 +1723,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "graphql-introspection-query"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d"
dependencies = [
"serde",
]
[[package]]
name = "graphql-parser"
version = "0.3.0"
@ -1730,56 +1733,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "graphql-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474"
dependencies = [
"combine",
"thiserror",
]
[[package]]
name = "graphql_client"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa"
dependencies = [
"graphql_query_derive",
"reqwest",
"serde",
"serde_json",
]
[[package]]
name = "graphql_client_codegen"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506"
dependencies = [
"graphql-introspection-query",
"graphql-parser 0.4.0",
"heck 0.4.1",
"lazy_static",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 1.0.109",
]
[[package]]
name = "graphql_query_derive"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c"
dependencies = [
"graphql_client_codegen",
"proc-macro2",
"syn 1.0.109",
]
[[package]]
name = "h2"
version = "0.3.19"
@ -2207,7 +2160,7 @@ dependencies = [
"fnv",
"futures",
"futures-enum",
"graphql-parser 0.3.0",
"graphql-parser",
"indexmap 1.9.3",
"juniper_codegen",
"serde",
@ -3056,18 +3009,18 @@ dependencies = [
[[package]]
name = "pin-project"
version = "1.1.0"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead"
checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.0"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07"
checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [
"proc-macro2",
"quote",
@ -4344,7 +4297,6 @@ dependencies = [
"chrono",
"clap 4.4.7",
"futures",
"graphql_client",
"http-api-bindings",
"hyper",
"lazy_static",
@ -4453,14 +4405,20 @@ version = "0.6.0-dev"
dependencies = [
"anyhow",
"axum",
"bincode",
"futures",
"hyper",
"juniper",
"juniper-axum",
"lazy_static",
"mime_guess",
"pin-project",
"rust-embed 8.0.0",
"serde",
"tarpc",
"thiserror",
"tokio",
"tokio-tungstenite",
"tracing",
"unicase",
]
@ -4616,6 +4574,41 @@ dependencies = [
"xattr",
]
[[package]]
name = "tarpc"
version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f41bce44d290df0598ae4b9cd6ea7f58f651fd3aa4af1b26060c4fa32b08af7"
dependencies = [
"anyhow",
"fnv",
"futures",
"humantime",
"opentelemetry",
"pin-project",
"rand 0.8.5",
"serde",
"static_assertions",
"tarpc-plugins",
"thiserror",
"tokio",
"tokio-serde",
"tokio-util",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "tarpc-plugins"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ee42b4e559f17bce0385ebf511a7beb67d5cc33c12c96b7f4e9789919d9c10f"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "temp_testdir"
version = "0.2.3"
@ -4841,6 +4834,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "911a61637386b789af998ee23f50aa30d5fd7edcec8d6d3dedae5e5815205466"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project",
]
[[package]]
name = "tokio-stream"
version = "0.1.14"

View File

@ -1,34 +1,26 @@
pub mod extract;
pub mod response;
use std::{future, net::SocketAddr};
use std::{future, sync::Arc};
use axum::{
extract::{ConnectInfo, Extension, State},
extract::{Extension, State},
response::{Html, IntoResponse},
};
use juniper_graphql_ws::Schema;
use self::{extract::JuniperRequest, response::JuniperResponse};
pub trait FromStateAndClientAddr<C, S> {
fn build(state: S, client_addr: SocketAddr) -> C;
}
#[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S, C>(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<C>,
pub async fn graphql<S>(
State(state): State<Arc<S::Context>>,
Extension(schema): Extension<S>,
JuniperRequest(req): JuniperRequest<S::ScalarValue>,
) -> impl IntoResponse
where
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
S::Context: FromStateAndClientAddr<S::Context, C>,
C: Clone,
{
let context = S::Context::build(state.clone(), addr);
JuniperResponse(req.execute(schema.root_node(), &context).await).into_response()
JuniperResponse(req.execute(schema.root_node(), &state).await).into_response()
}
/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].

View File

@ -47,8 +47,6 @@ async-trait.workspace = true
tabby-webserver = { path = "../../ee/tabby-webserver" }
thiserror.workspace = true
chrono = "0.4.31"
graphql_client = { version = "0.13.0", features = ["reqwest"] }
reqwest.workspace = true
[dependencies.uuid]
version = "1.3.3"

View File

@ -1,23 +0,0 @@
mutation RegisterWorker(
$port: Int!
$kind: WorkerKind!
$name: String!
$device: String!
$arch: String!
$cpuInfo: String!
$cpuCount: Int!
$cudaDevices: [String!]!
) {
worker: registerWorker(
port: $port
kind: $kind
name: $name
device: $device
arch: $arch
cpuInfo: $cpuInfo
cpuCount: $cpuCount
cudaDevices: $cudaDevices
) {
addr
}
}

View File

@ -14,6 +14,7 @@ use opentelemetry::{
};
use opentelemetry_otlp::WithExportConfig;
use tabby_common::config::Config;
use tabby_webserver::api::WorkerKind;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
#[derive(Parser)]
@ -105,10 +106,8 @@ async fn main() {
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
.await
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
Commands::WorkerCompletion(args) => {
worker::main(worker::WorkerKind::Completion, args).await
}
Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await,
Commands::WorkerCompletion(args) => worker::main(WorkerKind::Completion, args).await,
Commands::WorkerChat(args) => worker::main(WorkerKind::Chat, args).await,
}
opentelemetry::global::shutdown_tracer_provider();

View File

@ -109,7 +109,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.merge(api_router(args, config).await)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc));
let app = attach_webserver(app);
let app = attach_webserver(app).await;
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
info!("Listening at {}", address);

View File

@ -4,10 +4,11 @@ use std::{
sync::Arc,
};
use anyhow::Result;
use axum::{routing, Router};
use clap::Args;
use graphql_client::{reqwest::post_graphql, GraphQLQuery};
use hyper::Server;
use tabby_webserver::api::WorkerKind;
use tracing::{info, warn};
use crate::{
@ -23,13 +24,6 @@ use crate::{
Device,
};
#[derive(GraphQLQuery)]
#[graphql(
schema_path = "../../ee/tabby-webserver/graphql/schema.graphql",
query_path = "./graphql/worker.query.graphql"
)]
struct RegisterWorker;
#[derive(Args)]
pub struct WorkerArgs {
/// URL to register this worker.
@ -56,7 +50,7 @@ pub struct WorkerArgs {
async fn make_chat_route(args: &WorkerArgs) -> Router {
let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);
request_register(register_worker::WorkerKind::CHAT, args).await;
request_register(WorkerKind::Chat, args).await;
Router::new().route(
"/v1beta/chat/completions",
@ -71,7 +65,7 @@ async fn make_completion_route(args: &WorkerArgs) -> Router {
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
);
request_register(register_worker::WorkerKind::COMPLETION, args).await;
request_register(WorkerKind::Completion, args).await;
Router::new().route(
"/v1/completions",
@ -79,11 +73,6 @@ async fn make_completion_route(args: &WorkerArgs) -> Router {
)
}
pub enum WorkerKind {
Chat,
Completion,
}
pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
download_model_if_needed(&args.model).await;
@ -103,44 +92,45 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}
async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) {
request_register_impl(
async fn request_register(kind: WorkerKind, args: &WorkerArgs) {
if let Err(err) = request_register_impl(
kind,
args.url.clone(),
args.port as i64,
args.port,
args.model.to_owned(),
args.device.to_string(),
)
.await;
.await
{
warn!("Failed to register worker: {}", err)
}
}
async fn request_register_impl(
kind: register_worker::WorkerKind,
kind: WorkerKind,
url: String,
port: i64,
port: u16,
name: String,
device: String,
) {
let client = reqwest::Client::new();
) -> Result<()> {
let client = tabby_webserver::api::create_client(url).await;
let (cpu_info, cpu_count) = read_cpu_info();
let cuda_devices = read_cuda_devices().unwrap_or_default();
let variables = register_worker::Variables {
port,
kind,
name,
device,
arch: ARCH.to_string(),
cpu_info,
cpu_count: cpu_count as i64,
cuda_devices,
};
let worker = client
.register_worker(
tabby_webserver::api::tracing_context(),
kind,
port as i32,
name,
device,
ARCH.to_string(),
cpu_info,
cpu_count as i32,
cuda_devices,
)
.await??;
let url = format!("{}/graphql", url);
match post_graphql::<RegisterWorker, _>(&client, &url, variables).await {
Ok(x) => {
let addr = x.data.unwrap().worker.addr;
info!("Worker alive at {}", addr);
}
Err(err) => warn!("Failed to register worker: {}", err),
}
info!("Worker alive at {}", worker.addr);
Ok(())
}

View File

@ -7,15 +7,21 @@ homepage.workspace = true
[dependencies]
anyhow.workspace = true
axum.workspace = true
axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
futures.workspace = true
hyper = { workspace = true, features=["client"]}
juniper.workspace = true
juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0"
mime_guess = "2.0.4"
pin-project = "1.1.3"
rust-embed = "8.0.0"
serde.workspace = true
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
tokio.workspace = true
tokio-tungstenite = "0.20.1"
tracing.workspace = true
unicase = "2.7.0"

View File

@ -1,10 +1,9 @@
use std::fs::write;
use juniper::EmptySubscription;
use tabby_webserver::schema::{Mutation, Query, Schema};
use tabby_webserver::create_schema;
fn main() {
let schema = Schema::new(Query, Mutation, EmptySubscription::new());
let schema = create_schema();
write(
"ee/tabby-webserver/graphql/schema.graphql",
schema.as_schema_language(),

View File

@ -3,10 +3,6 @@ enum WorkerKind {
CHAT
}
type Mutation {
registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker!
}
type Query {
workers: [Worker!]!
}
@ -24,5 +20,4 @@ type Worker {
schema {
query: Query
mutation: Mutation
}

View File

@ -0,0 +1,57 @@
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio_tungstenite::connect_async;
use crate::websocket::WebSocketTransport;
// Kind of inference service a worker offers. The hub keeps one worker group
// per kind and routes a registration to the matching group.
// NOTE(review): plain `//` comments are used on purpose; juniper's derive is
// understood to turn `///` doc comments into GraphQL descriptions, which would
// change the generated schema. Confirm before converting these to doc comments.
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
// Worker serving code completion requests.
Completion,
// Worker serving chat completion requests.
Chat,
}
// Description of a registered worker. It is exposed through the GraphQL
// `workers` query and also sent over the hub RPC, hence both the juniper and
// the serde derives.
// NOTE(review): plain `//` comments are used on purpose; juniper's derive is
// understood to turn `///` doc comments into GraphQL descriptions, which would
// change the generated schema. Confirm before converting these to doc comments.
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
// Which service this worker provides.
pub kind: WorkerKind,
// Name reported by the worker (the worker binary passes its model name here).
pub name: String,
// Address of the worker; the hub fills this in as `http://<peer ip>:<port>`.
pub addr: String,
// Device string reported by the worker.
pub device: String,
// CPU architecture string reported by the worker.
pub arch: String,
// CPU description reported by the worker.
pub cpu_info: String,
// Number of CPUs reported by the worker.
pub cpu_count: i32,
// CUDA devices reported by the worker; may be empty.
pub cuda_devices: Vec<String>,
}
/// Errors returned by the hub RPC service.
///
/// The type is serializable so it can travel back to the worker as the error
/// half of an RPC result.
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError {
/// The worker token was rejected.
/// NOTE(review): the payload is presumably the offending token; nothing in
/// the visible code constructs this variant, so confirm.
#[error("Invalid worker token")]
InvalidToken(String),
/// Returned by the hub when a worker group declines a registration.
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
}
/// RPC interface that workers use to talk to the hub.
///
/// `#[tarpc::service]` generates the client and server plumbing from this
/// trait, including the `HubClient` type returned by `create_client`.
#[tarpc::service]
pub trait Hub {
/// Registers the calling worker with the hub.
///
/// `port` is combined by the hub with the peer IP of the connection to form
/// the worker's address; the remaining arguments are copied into the
/// resulting [`Worker`] record. Returns that record on success, or a
/// [`HubError`] when the hub refuses the registration.
async fn register_worker(
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> Result<Worker, HubError>;
}
/// Returns tarpc's current request context.
///
/// Callers pass the result as the first argument of `HubClient` RPC calls, so
/// they do not need to depend on `tarpc` directly.
pub fn tracing_context() -> tarpc::context::Context {
tarpc::context::current()
}
/// Connects to the hub's `/hub` WebSocket endpoint and returns an RPC client.
///
/// `addr` may be a bare `host:port` (as before) or a full URL. Workers used to
/// register through `{url}/graphql`, so existing `--url` values carry an
/// `http://` scheme; those are translated instead of being turned into an
/// invalid `ws://http://...` address.
///
/// # Panics
///
/// Panics if the WebSocket connection cannot be established. The signature
/// returns a bare `HubClient`, so the failure cannot be reported as an error;
/// the panic message names the URL that was tried.
pub async fn create_client(addr: String) -> HubClient {
    let url = hub_url(&addr);
    let (socket, _) = connect_async(&url)
        .await
        .unwrap_or_else(|err| panic!("Failed to connect to hub at {}: {}", url, err));
    HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
}

/// Builds the WebSocket URL of the hub endpoint from a user supplied address.
///
/// `http://` maps to `ws://`, `https://` maps to `wss://`, existing `ws://` or
/// `wss://` schemes are kept, and a scheme-less address gets `ws://`. Trailing
/// slashes are dropped so the result never contains `//hub`.
fn hub_url(addr: &str) -> String {
    let addr = addr.trim_end_matches('/');
    if addr.starts_with("ws://") || addr.starts_with("wss://") {
        format!("{}/hub", addr)
    } else if let Some(rest) = addr.strip_prefix("https://") {
        // NOTE(review): `wss://` needs a TLS feature of tokio-tungstenite to be
        // enabled; without it the connection attempt fails with a clear error.
        format!("wss://{}/hub", rest)
    } else {
        format!("ws://{}/hub", addr.strip_prefix("http://").unwrap_or(addr))
    }
}

View File

@ -1,45 +1,107 @@
pub mod schema;
pub mod api;
mod schema;
pub use schema::create_schema;
use websocket::WebSocketTransport;
mod server;
mod ui;
mod webserver;
mod worker;
mod websocket;
use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};
use api::{Hub, HubError, Worker, WorkerKind};
use axum::{
extract::State,
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
http::Request,
middleware::{from_fn_with_state, Next},
response::IntoResponse,
routing, Extension, Router,
};
use hyper::Body;
use juniper::EmptySubscription;
use juniper_axum::{graphiql, graphql, playground};
use schema::{Mutation, Query, Schema};
use webserver::Webserver;
use schema::Schema;
use server::ServerContext;
use tarpc::server::{BaseChannel, Channel};
pub fn attach_webserver(router: Router) -> Router {
let ws = Arc::new(Webserver::default());
let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new()));
pub async fn attach_webserver(router: Router) -> Router {
let ctx = Arc::new(ServerContext::default());
let schema = Arc::new(create_schema());
let app = Router::new()
.route("/graphql", routing::get(playground("/graphql", None)))
.route("/graphiql", routing::get(graphiql("/graphql", None)))
.route(
"/graphql",
routing::post(graphql::<Arc<Schema>, Arc<Webserver>>).with_state(ws.clone()),
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()),
)
.layer(Extension(schema));
router
.merge(app)
.route("/hub", routing::get(ws_handler).with_state(ctx.clone()))
.fallback(ui::handler)
.layer(from_fn_with_state(ws, distributed_tabby_layer))
.layer(from_fn_with_state(ctx, distributed_tabby_layer))
}
async fn distributed_tabby_layer(
State(ws): State<Arc<Webserver>>,
State(ws): State<Arc<ServerContext>>,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
ws.dispatch_request(request, next).await
}
async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<ServerContext>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(state, socket, addr))
}
async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: SocketAddr) {
let transport = WebSocketTransport::from(socket);
let server = BaseChannel::with_defaults(transport);
let imp = Arc::new(HubImpl::new(state.clone(), addr));
tokio::spawn(server.execute(imp.serve())).await.unwrap()
}
/// Per-connection implementation of the `Hub` RPC service.
pub struct HubImpl {
/// Shared hub state that registrations are forwarded to.
ctx: Arc<ServerContext>,
/// Peer address of the connection; its IP is used to build the worker address.
conn: SocketAddr,
}
impl HubImpl {
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
Self { ctx, conn }
}
}
#[tarpc::server]
impl Hub for Arc<HubImpl> {
    /// Builds a `Worker` record from the reported properties and registers it
    /// with the shared server context.
    ///
    /// The worker address is derived from the connection's peer IP and the
    /// reported `port`; the worker does not supply its own address.
    async fn register_worker(
        self,
        _context: tarpc::context::Context,
        kind: WorkerKind,
        port: i32,
        name: String,
        device: String,
        arch: String,
        cpu_info: String,
        cpu_count: i32,
        cuda_devices: Vec<String>,
    ) -> Result<Worker, HubError> {
        let addr = format!("http://{}:{}", self.conn.ip(), port);
        let registration = Worker {
            kind,
            name,
            addr,
            device,
            arch,
            cpu_info,
            cpu_count,
            cuda_devices,
        };
        self.ctx.register_worker(registration).await
    }
}

View File

@ -1,98 +1,23 @@
use std::{net::SocketAddr, sync::Arc};
use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode};
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject,
IntoFieldError, RootNode, ScalarValue, Value,
};
use juniper_axum::FromStateAndClientAddr;
use crate::webserver::{Webserver, WebserverError};
pub struct Request {
ws: Arc<Webserver>,
client_addr: SocketAddr,
}
impl FromStateAndClientAddr<Request, Arc<Webserver>> for Request {
fn build(ws: Arc<Webserver>, client_addr: SocketAddr) -> Request {
Request { ws, client_addr }
}
}
use crate::{api::Worker, server::ServerContext};
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Request {}
#[derive(GraphQLEnum, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
impl juniper::Context for ServerContext {}
#[derive(Default)]
pub struct Query;
#[graphql_object(context = Request)]
#[graphql_object(context = ServerContext)]
impl Query {
async fn workers(request: &Request) -> Vec<Worker> {
request.ws.list_workers().await
async fn workers(ctx: &ServerContext) -> Vec<Worker> {
ctx.list_workers().await
}
}
pub struct Mutation;
pub type Schema =
RootNode<'static, Query, EmptyMutation<ServerContext>, EmptySubscription<ServerContext>>;
#[graphql_object(context = Request)]
impl Mutation {
async fn register_worker(
request: &Request,
port: i32,
kind: WorkerKind,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> Result<Worker, WebserverError> {
let ws = &request.ws;
let worker = Worker {
name,
kind,
addr: format!("http://{}:{}", request.client_addr.ip(), port),
device,
arch,
cpu_info,
cpu_count,
cuda_devices,
};
ws.register_worker(worker).await
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Request>>;
impl<S: ScalarValue> IntoFieldError<S> for WebserverError {
fn into_field_error(self) -> FieldError<S> {
let msg = format!("{}", &self);
match self {
WebserverError::InvalidToken(token) => FieldError::new(
msg,
graphql_value!({
"token": token
}),
),
_ => FieldError::new(msg, Value::Null),
}
}
pub fn create_schema() -> Schema {
Schema::new(Query, EmptyMutation::new(), EmptySubscription::new())
}

View File

@ -1,35 +1,22 @@
mod proxy;
mod worker;
use std::net::SocketAddr;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use thiserror::Error;
use tracing::{info, warn};
use crate::{
schema::{Worker, WorkerKind},
worker,
};
#[derive(Error, Debug)]
pub enum WebserverError {
#[error("Invalid worker token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
}
use crate::api::{HubError, Worker, WorkerKind};
#[derive(Default)]
pub struct Webserver {
pub struct ServerContext {
client: Client<HttpConnector>,
completion: worker::WorkerGroup,
chat: worker::WorkerGroup,
}
impl Webserver {
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, WebserverError> {
impl ServerContext {
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await,
@ -42,7 +29,7 @@ impl Webserver {
);
Ok(worker)
} else {
Err(WebserverError::RequiresEnterpriseLicense)
Err(HubError::RequiresEnterpriseLicense)
}
}

View File

@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tracing::error;
use crate::schema::Worker;
use crate::api::Worker;
#[derive(Default)]
pub struct WorkerGroup {
@ -51,7 +51,7 @@ fn random_index(size: usize) -> usize {
mod tests {
use super::*;
use crate::schema::WorkerKind;
use crate::api::WorkerKind;
#[tokio::test]
async fn test_worker_group() {

View File

@ -0,0 +1,127 @@
use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use axum::extract::ws;
use futures::{Sink, Stream};
use pin_project::pin_project;
use tokio::net::TcpStream;
use tokio_tungstenite as tt;
use tokio_tungstenite::tungstenite as ts;
/// Abstraction over WebSocket message types that may carry a binary payload.
pub trait IntoData {
/// Consumes the message and returns its bytes. The implementations in this
/// file return `Some` only for binary frames and `None` for everything else.
fn into_data(self) -> Option<Vec<u8>>;
}
impl IntoData for ws::Message {
    /// Returns the payload of an axum binary frame; any other frame gives `None`.
    fn into_data(self) -> Option<Vec<u8>> {
        if let ws::Message::Binary(payload) = self {
            Some(payload)
        } else {
            None
        }
    }
}
impl IntoData for ts::Message {
    /// Returns the payload of a tungstenite binary frame; any other frame gives `None`.
    fn into_data(self) -> Option<Vec<u8>> {
        if let ts::Message::Binary(payload) = self {
            Some(payload)
        } else {
            None
        }
    }
}
/// Adapter that exposes a WebSocket as a typed message channel.
///
/// Its `Stream` impl decodes incoming binary frames into `Req` values and its
/// `Sink` impl encodes outgoing `Resp` values into binary frames, both with
/// bincode. `Message`, `Transport` and `Error` name the concrete WebSocket
/// message, socket and error types being wrapped.
#[pin_project]
pub struct WebSocketTransport<Req, Resp, Message, Transport, Error>
where
Message: IntoData + From<Vec<u8>>,
Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
{
// The wrapped socket; pinned structurally so its poll methods can be called.
#[pin]
inner: Transport,
// Ties the otherwise unused `Req` / `Resp` type parameters to the struct.
ghost: PhantomData<(Req, Resp)>,
}
impl<Req, Resp> From<ws::WebSocket>
for WebSocketTransport<Req, Resp, ws::Message, ws::WebSocket, axum::Error>
{
fn from(inner: ws::WebSocket) -> Self {
Self {
inner,
ghost: PhantomData,
}
}
}
impl<Req, Resp> From<tt::WebSocketStream<tt::MaybeTlsStream<TcpStream>>>
for WebSocketTransport<
Req,
Resp,
ts::Message,
tt::WebSocketStream<tt::MaybeTlsStream<TcpStream>>,
tt::tungstenite::Error,
>
{
fn from(inner: tokio_tungstenite::WebSocketStream<tt::MaybeTlsStream<TcpStream>>) -> Self {
Self {
inner,
ghost: PhantomData,
}
}
}
/// Receiving half of the transport: yields every binary frame decoded as a
/// `Req` with bincode.
impl<Req, Resp, Message, Transport, Error> Stream
    for WebSocketTransport<Req, Resp, Message, Transport, Error>
where
    Req: for<'de> serde::Deserialize<'de>,
    Message: IntoData + From<Vec<u8>> + std::fmt::Debug,
    Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
{
    type Item = Result<Req, Error>;

    /// Polls the underlying socket for the next decodable request.
    ///
    /// * Frames without a binary payload (ping, pong, text, ...) are skipped
    ///   rather than treated as end of stream, so a keep-alive frame no longer
    ///   tears down the connection.
    /// * Transport errors are forwarded unchanged.
    /// * A binary frame that does not decode as `Req` ends the stream. The
    ///   bytes come from the remote peer, so a malformed frame must not panic
    ///   the task; the decode error cannot be converted into the generic
    ///   transport `Error`, so closing the channel is the only available signal.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let msg = match futures::ready!(self.as_mut().project().inner.poll_next(cx)) {
                Some(Ok(msg)) => msg,
                Some(Err(err)) => return Poll::Ready(Some(Err(err))),
                None => return Poll::Ready(None),
            };

            // Each iteration consumes one frame, so this loop cannot spin
            // without progress.
            let bin = match msg.into_data() {
                Some(bin) => bin,
                None => continue,
            };

            return match bincode::deserialize_from::<&[u8], Req>(bin.as_ref()) {
                Ok(req) => Poll::Ready(Some(Ok(req))),
                Err(_) => Poll::Ready(None),
            };
        }
    }
}
/// Sending half of the transport: each `Resp` is serialized with bincode and
/// sent as one binary frame. All readiness, flush and close handling is
/// delegated to the wrapped socket.
impl<Req, Resp, Message, Transport, Error> Sink<Resp>
    for WebSocketTransport<Req, Resp, Message, Transport, Error>
where
    Resp: serde::Serialize,
    Message: IntoData + From<Vec<u8>>,
    Transport: Stream<Item = Result<Message, Error>> + Sink<Message, Error = Error>,
{
    type Error = Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().inner;
        socket.poll_ready(cx)
    }

    /// Panics if `item` cannot be serialized with bincode.
    fn start_send(self: Pin<&mut Self>, item: Resp) -> Result<(), Self::Error> {
        let bytes = bincode::serialize(&item).unwrap();
        let socket = self.project().inner;
        socket.start_send(Message::from(bytes))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().inner;
        socket.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().inner;
        socket.poll_close(cx)
    }
}