feat: implement worker unregistration (#872)
* feat: implement worker unregisteration logic * refactor: rename HubError -> RegisterWorkerErrorwsxiaoys-patch-3
parent
a067366b5d
commit
3bd3a3304d
|
|
@ -29,12 +29,15 @@ pub struct Worker {
|
|||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Error, Debug)]
|
||||
pub enum HubError {
|
||||
pub enum RegisterWorkerError {
|
||||
#[error("Invalid token")]
|
||||
InvalidToken(String),
|
||||
|
||||
#[error("Feature requires enterprise license")]
|
||||
RequiresEnterpriseLicense,
|
||||
|
||||
#[error("Each hub client should only calls register_worker once")]
|
||||
RegisterWorkerOnce,
|
||||
}
|
||||
|
||||
#[tarpc::service]
|
||||
|
|
@ -49,7 +52,7 @@ pub trait Hub {
|
|||
cpu_count: i32,
|
||||
cuda_devices: Vec<String>,
|
||||
token: String,
|
||||
) -> Result<Worker, HubError>;
|
||||
) -> Result<Worker, RegisterWorkerError>;
|
||||
|
||||
async fn log_event(content: String);
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ use tabby_common::api::{
|
|||
code::{CodeSearch, SearchResponse},
|
||||
event::RawEventLogger,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, warn};
|
||||
use websocket::WebSocketTransport;
|
||||
|
||||
|
|
@ -16,7 +17,7 @@ mod websocket;
|
|||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use api::{Hub, HubError, Worker, WorkerKind};
|
||||
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
|
||||
use axum::{
|
||||
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
||||
http::Request,
|
||||
|
|
@ -83,11 +84,31 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
|
|||
pub struct HubImpl {
|
||||
ctx: Arc<ServerContext>,
|
||||
conn: SocketAddr,
|
||||
|
||||
worker_addr: Arc<Mutex<String>>,
|
||||
}
|
||||
|
||||
impl HubImpl {
|
||||
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
||||
Self { ctx, conn }
|
||||
Self {
|
||||
ctx,
|
||||
conn,
|
||||
worker_addr: Arc::new(Mutex::new("".to_owned())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HubImpl {
|
||||
fn drop(&mut self) {
|
||||
let ctx = self.ctx.clone();
|
||||
let worker_addr = self.worker_addr.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let worker_addr = worker_addr.lock().await;
|
||||
if !worker_addr.is_empty() {
|
||||
ctx.unregister_worker(worker_addr.as_str()).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -105,27 +126,39 @@ impl Hub for Arc<HubImpl> {
|
|||
cpu_count: i32,
|
||||
cuda_devices: Vec<String>,
|
||||
token: String,
|
||||
) -> Result<Worker, HubError> {
|
||||
) -> Result<Worker, RegisterWorkerError> {
|
||||
if token.is_empty() {
|
||||
return Err(HubError::InvalidToken("Empty worker token".to_string()));
|
||||
return Err(RegisterWorkerError::InvalidToken(
|
||||
"Empty worker token".to_string(),
|
||||
));
|
||||
}
|
||||
let server_token = match self.ctx.read_registration_token().await {
|
||||
Ok(t) => t,
|
||||
Err(err) => {
|
||||
error!("fetch server token: {}", err.to_string());
|
||||
return Err(HubError::InvalidToken(
|
||||
return Err(RegisterWorkerError::InvalidToken(
|
||||
"Failed to fetch server token".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
if server_token != token {
|
||||
return Err(HubError::InvalidToken("Token mismatch".to_string()));
|
||||
return Err(RegisterWorkerError::InvalidToken(
|
||||
"Token mismatch".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut worker_addr = self.worker_addr.lock().await;
|
||||
if !worker_addr.is_empty() {
|
||||
return Err(RegisterWorkerError::RegisterWorkerOnce);
|
||||
}
|
||||
|
||||
let addr = format!("http://{}:{}", self.conn.ip(), port);
|
||||
*worker_addr = addr.clone();
|
||||
|
||||
let worker = Worker {
|
||||
name,
|
||||
kind,
|
||||
addr: format!("http://{}:{}", self.conn.ip(), port),
|
||||
addr,
|
||||
device,
|
||||
arch,
|
||||
cpu_info,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
|||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
api::{HubError, Worker, WorkerKind},
|
||||
api::{RegisterWorkerError, Worker, WorkerKind},
|
||||
db::DbConn,
|
||||
};
|
||||
|
||||
|
|
@ -51,7 +51,7 @@ impl ServerContext {
|
|||
self.db_conn.reset_registration_token().await
|
||||
}
|
||||
|
||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
|
||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
|
||||
let worker = match worker.kind {
|
||||
WorkerKind::Completion => self.completion.register(worker).await,
|
||||
WorkerKind::Chat => self.chat.register(worker).await,
|
||||
|
|
@ -64,10 +64,23 @@ impl ServerContext {
|
|||
);
|
||||
Ok(worker)
|
||||
} else {
|
||||
Err(HubError::RequiresEnterpriseLicense)
|
||||
Err(RegisterWorkerError::RequiresEnterpriseLicense)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn unregister_worker(&self, worker_addr: &str) {
|
||||
let kind = if self.chat.unregister(worker_addr).await {
|
||||
WorkerKind::Chat
|
||||
} else if self.completion.unregister(worker_addr).await {
|
||||
WorkerKind::Completion
|
||||
} else {
|
||||
warn!("Trying to unregister a worker missing in registry");
|
||||
return;
|
||||
};
|
||||
|
||||
info!("unregistering <{:?}> worker at {}", kind, worker_addr);
|
||||
}
|
||||
|
||||
pub async fn list_workers(&self) -> Vec<Worker> {
|
||||
[self.completion.list().await, self.chat.list().await].concat()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,16 @@ impl WorkerGroup {
|
|||
|
||||
Some(worker)
|
||||
}
|
||||
|
||||
pub async fn unregister(&self, worker_addr: &str) -> bool {
|
||||
let mut workers = self.workers.write().await;
|
||||
if let Some(index) = workers.iter().position(|x| x.addr == worker_addr) {
|
||||
workers.remove(index);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn random_index(size: usize) -> usize {
|
||||
|
|
|
|||
Loading…
Reference in New Issue