127 lines
3.1 KiB
Rust
127 lines
3.1 KiB
Rust
use async_trait::async_trait;
|
|
use juniper::{GraphQLEnum, GraphQLObject};
|
|
use serde::{Deserialize, Serialize};
|
|
use tabby_common::api::{
|
|
code::{CodeSearch, CodeSearchError, SearchResponse},
|
|
event::RawEventLogger,
|
|
};
|
|
use thiserror::Error;
|
|
use tokio_tungstenite::connect_async;
|
|
|
|
use crate::websocket::WebSocketTransport;
|
|
|
|
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
|
|
pub enum WorkerKind {
|
|
Completion,
|
|
Chat,
|
|
}
|
|
|
|
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
|
|
pub struct Worker {
|
|
pub kind: WorkerKind,
|
|
pub name: String,
|
|
pub addr: String,
|
|
pub device: String,
|
|
pub arch: String,
|
|
pub cpu_info: String,
|
|
pub cpu_count: i32,
|
|
pub cuda_devices: Vec<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Error, Debug)]
|
|
pub enum RegisterWorkerError {
|
|
#[error("Invalid token")]
|
|
InvalidToken(String),
|
|
|
|
#[error("Feature requires enterprise license")]
|
|
RequiresEnterpriseLicense,
|
|
|
|
#[error("Each hub client should only calls register_worker once")]
|
|
RegisterWorkerOnce,
|
|
}
|
|
|
|
#[tarpc::service]
|
|
pub trait Hub {
|
|
async fn register_worker(
|
|
kind: WorkerKind,
|
|
port: i32,
|
|
name: String,
|
|
device: String,
|
|
arch: String,
|
|
cpu_info: String,
|
|
cpu_count: i32,
|
|
cuda_devices: Vec<String>,
|
|
token: String,
|
|
) -> Result<Worker, RegisterWorkerError>;
|
|
|
|
async fn log_event(content: String);
|
|
|
|
async fn search(q: String, limit: usize, offset: usize) -> SearchResponse;
|
|
|
|
async fn search_in_language(
|
|
language: String,
|
|
tokens: Vec<String>,
|
|
limit: usize,
|
|
offset: usize,
|
|
) -> SearchResponse;
|
|
}
|
|
|
|
pub fn tracing_context() -> tarpc::context::Context {
|
|
tarpc::context::current()
|
|
}
|
|
|
|
pub async fn create_client(addr: &str) -> HubClient {
|
|
let addr = format!("ws://{}/hub", addr);
|
|
let (socket, _) = connect_async(&addr).await.unwrap();
|
|
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
|
|
}
|
|
|
|
impl RawEventLogger for HubClient {
|
|
fn log(&self, content: String) {
|
|
let context = tarpc::context::current();
|
|
let client = self.clone();
|
|
|
|
tokio::spawn(async move { client.log_event(context, content).await });
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl CodeSearch for HubClient {
|
|
async fn search(
|
|
&self,
|
|
q: &str,
|
|
limit: usize,
|
|
offset: usize,
|
|
) -> Result<SearchResponse, CodeSearchError> {
|
|
match self
|
|
.search(tracing_context(), q.to_owned(), limit, offset)
|
|
.await
|
|
{
|
|
Ok(serp) => Ok(serp),
|
|
Err(_) => Err(CodeSearchError::NotReady),
|
|
}
|
|
}
|
|
|
|
async fn search_in_language(
|
|
&self,
|
|
language: &str,
|
|
tokens: &[String],
|
|
limit: usize,
|
|
offset: usize,
|
|
) -> Result<SearchResponse, CodeSearchError> {
|
|
match self
|
|
.search_in_language(
|
|
tracing_context(),
|
|
language.to_owned(),
|
|
tokens.to_owned(),
|
|
limit,
|
|
offset,
|
|
)
|
|
.await
|
|
{
|
|
Ok(serp) => Ok(serp),
|
|
Err(_) => Err(CodeSearchError::NotReady),
|
|
}
|
|
}
|
|
}
|