tabby/ee/tabby-webserver/src/api.rs

127 lines
3.1 KiB
Rust
Raw Normal View History

use async_trait::async_trait;
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use tabby_common::api::{
code::{CodeSearch, CodeSearchError, SearchResponse},
event::RawEventLogger,
};
use thiserror::Error;
use tokio_tungstenite::connect_async;
use crate::websocket::WebSocketTransport;
#[derive(GraphQLEnum, Serialize, Deserialize, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Serialize, Deserialize, Clone, Debug)]
pub struct Worker {
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum RegisterWorkerError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
}
#[tarpc::service]
pub trait Hub {
async fn register_worker(
kind: WorkerKind,
port: i32,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, RegisterWorkerError>;
async fn log_event(content: String);
async fn search(q: String, limit: usize, offset: usize) -> SearchResponse;
async fn search_in_language(
language: String,
tokens: Vec<String>,
limit: usize,
offset: usize,
) -> SearchResponse;
}
pub fn tracing_context() -> tarpc::context::Context {
tarpc::context::current()
}
pub async fn create_client(addr: &str) -> HubClient {
let addr = format!("ws://{}/hub", addr);
let (socket, _) = connect_async(&addr).await.unwrap();
HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn()
}
impl RawEventLogger for HubClient {
fn log(&self, content: String) {
let context = tarpc::context::current();
let client = self.clone();
tokio::spawn(async move { client.log_event(context, content).await });
}
}
#[async_trait]
impl CodeSearch for HubClient {
async fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
match self
.search(tracing_context(), q.to_owned(), limit, offset)
.await
{
Ok(serp) => Ok(serp),
Err(_) => Err(CodeSearchError::NotReady),
}
}
async fn search_in_language(
&self,
language: &str,
tokens: &[String],
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
match self
.search_in_language(
tracing_context(),
language.to_owned(),
tokens.to_owned(),
limit,
offset,
)
.await
{
Ok(serp) => Ok(serp),
Err(_) => Err(CodeSearchError::NotReady),
}
}
}