tabby/ee/tabby-webserver/src/api.rs

71 lines
1.7 KiB
Rust

use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use tabby_common::api::event::RawEventLogger;
use thiserror::Error;
use tokio_tungstenite::connect_async;
use crate::websocket::WebSocketTransport;
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
}
#[tarpc::service]
pub trait Hub {
async fn register_worker(
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, HubError>;
async fn log_event(content: String);
}
pub fn tracing_context() -> tarpc::context::Context {
tarpc::context::current()
}
pub async fn create_client(addr: &str) -> HubClient {
let addr = format!("ws://{}/hub", addr);
let (socket, _) = connect_async(&addr).await.unwrap();
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
}
impl RawEventLogger for HubClient {
fn log(&self, content: String) {
let context = tarpc::context::current();
let client = self.clone();
tokio::spawn(async move { client.log_event(context, content).await });
}
}