use async_trait::async_trait; use juniper::{GraphQLEnum, GraphQLObject}; use serde::{Deserialize, Serialize}; use tabby_common::api::{ code::{CodeSearch, CodeSearchError, SearchResponse}, event::RawEventLogger, }; use thiserror::Error; use tokio_tungstenite::connect_async; use crate::websocket::WebSocketTransport; #[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)] pub enum WorkerKind { Completion, Chat, } #[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)] pub struct Worker { pub kind: WorkerKind, pub name: String, pub addr: String, pub device: String, pub arch: String, pub cpu_info: String, pub cpu_count: i32, pub cuda_devices: Vec, } #[derive(Serialize, Deserialize, Error, Debug)] pub enum RegisterWorkerError { #[error("Invalid token")] InvalidToken(String), #[error("Feature requires enterprise license")] RequiresEnterpriseLicense, #[error("Each hub client should only calls register_worker once")] RegisterWorkerOnce, } #[tarpc::service] pub trait Hub { async fn register_worker( kind: WorkerKind, port: i32, name: String, device: String, arch: String, cpu_info: String, cpu_count: i32, cuda_devices: Vec, token: String, ) -> Result; async fn log_event(content: String); async fn search(q: String, limit: usize, offset: usize) -> SearchResponse; async fn search_in_language( language: String, tokens: Vec, limit: usize, offset: usize, ) -> SearchResponse; } pub fn tracing_context() -> tarpc::context::Context { tarpc::context::current() } pub async fn create_client(addr: &str) -> HubClient { let addr = format!("ws://{}/hub", addr); let (socket, _) = connect_async(&addr).await.unwrap(); HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn() } impl RawEventLogger for HubClient { fn log(&self, content: String) { let context = tarpc::context::current(); let client = self.clone(); tokio::spawn(async move { client.log_event(context, content).await }); } } #[async_trait] impl CodeSearch for HubClient { async fn search( &self, q: &str, limit: usize, offset: usize, ) -> Result { match self .search(tracing_context(), q.to_owned(), limit, offset) .await { Ok(serp) => Ok(serp), Err(_) => Err(CodeSearchError::NotReady), } } async fn search_in_language( &self, language: &str, tokens: &[String], limit: usize, offset: usize, ) -> Result { match self .search_in_language( tracing_context(), language.to_owned(), tokens.to_owned(), limit, offset, ) .await { Ok(serp) => Ok(serp), Err(_) => Err(CodeSearchError::NotReady), } } }