2023-11-14 20:48:20 +00:00
|
|
|
use juniper::{GraphQLEnum, GraphQLObject};
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
2023-11-18 23:17:54 +00:00
|
|
|
use tabby_common::api::event::RawEventLogger;
|
2023-11-14 20:48:20 +00:00
|
|
|
use thiserror::Error;
|
|
|
|
|
use tokio_tungstenite::connect_async;
|
|
|
|
|
|
|
|
|
|
use crate::websocket::WebSocketTransport;
|
|
|
|
|
|
|
|
|
|
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
|
|
|
|
pub enum WorkerKind {
|
|
|
|
|
Completion,
|
|
|
|
|
Chat,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
|
|
|
|
pub struct Worker {
|
|
|
|
|
pub kind: WorkerKind,
|
|
|
|
|
pub name: String,
|
|
|
|
|
pub addr: String,
|
|
|
|
|
pub device: String,
|
|
|
|
|
pub arch: String,
|
|
|
|
|
pub cpu_info: String,
|
|
|
|
|
pub cpu_count: i32,
|
|
|
|
|
pub cuda_devices: Vec<String>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Serialize, Deserialize, Error, Debug)]
|
|
|
|
|
pub enum HubError {
|
2023-11-17 07:05:39 +00:00
|
|
|
#[error("Invalid token")]
|
2023-11-14 20:48:20 +00:00
|
|
|
InvalidToken(String),
|
|
|
|
|
|
|
|
|
|
#[error("Feature requires enterprise license")]
|
|
|
|
|
RequiresEnterpriseLicense,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tarpc::service]
|
|
|
|
|
pub trait Hub {
|
|
|
|
|
async fn register_worker(
|
|
|
|
|
kind: WorkerKind,
|
|
|
|
|
port: i32,
|
|
|
|
|
name: String,
|
|
|
|
|
device: String,
|
|
|
|
|
arch: String,
|
|
|
|
|
cpu_info: String,
|
|
|
|
|
cpu_count: i32,
|
|
|
|
|
cuda_devices: Vec<String>,
|
2023-11-17 07:05:39 +00:00
|
|
|
token: String,
|
2023-11-14 20:48:20 +00:00
|
|
|
) -> Result<Worker, HubError>;
|
2023-11-18 23:17:54 +00:00
|
|
|
|
|
|
|
|
async fn log_event(content: String);
|
2023-11-14 20:48:20 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn tracing_context() -> tarpc::context::Context {
|
|
|
|
|
tarpc::context::current()
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-18 23:17:54 +00:00
|
|
|
pub async fn create_client(addr: &str) -> HubClient {
|
2023-11-14 20:48:20 +00:00
|
|
|
let addr = format!("ws://{}/hub", addr);
|
|
|
|
|
let (socket, _) = connect_async(&addr).await.unwrap();
|
|
|
|
|
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
|
|
|
|
|
}
|
2023-11-18 23:17:54 +00:00
|
|
|
|
|
|
|
|
impl RawEventLogger for HubClient {
|
|
|
|
|
fn log(&self, content: String) {
|
|
|
|
|
let context = tarpc::context::current();
|
|
|
|
|
let client = self.clone();
|
|
|
|
|
|
|
|
|
|
tokio::spawn(async move { client.log_event(context, content).await });
|
|
|
|
|
}
|
|
|
|
|
}
|