From 2296b01d2f969df01cca21445950723ea408360b Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 23 Nov 2023 14:22:04 +0800 Subject: [PATCH] feat: implement worker unregistration (#872) * feat: implement worker unregisteration logic * refactor: rename HubError -> RegisterWorkerError --- ee/tabby-webserver/src/api.rs | 7 ++-- ee/tabby-webserver/src/lib.rs | 47 +++++++++++++++++++++---- ee/tabby-webserver/src/server.rs | 19 ++++++++-- ee/tabby-webserver/src/server/worker.rs | 10 ++++++ 4 files changed, 71 insertions(+), 12 deletions(-) diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs index 512238a..fa5998f 100644 --- a/ee/tabby-webserver/src/api.rs +++ b/ee/tabby-webserver/src/api.rs @@ -29,12 +29,15 @@ pub struct Worker { } #[derive(Serialize, Deserialize, Error, Debug)] -pub enum HubError { +pub enum RegisterWorkerError { #[error("Invalid token")] InvalidToken(String), #[error("Feature requires enterprise license")] RequiresEnterpriseLicense, + + #[error("Each hub client should only calls register_worker once")] + RegisterWorkerOnce, } #[tarpc::service] @@ -49,7 +52,7 @@ pub trait Hub { cpu_count: i32, cuda_devices: Vec, token: String, - ) -> Result; + ) -> Result; async fn log_event(content: String); diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index f0da4ab..e93f338 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -6,6 +6,7 @@ use tabby_common::api::{ code::{CodeSearch, SearchResponse}, event::RawEventLogger, }; +use tokio::sync::Mutex; use tracing::{error, warn}; use websocket::WebSocketTransport; @@ -16,7 +17,7 @@ mod websocket; use std::{net::SocketAddr, sync::Arc}; -use api::{Hub, HubError, Worker, WorkerKind}; +use api::{Hub, RegisterWorkerError, Worker, WorkerKind}; use axum::{ extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, http::Request, @@ -83,11 +84,31 @@ async fn handle_socket(state: Arc, socket: WebSocket, addr: Socke pub struct HubImpl { ctx: Arc, conn: SocketAddr, + + worker_addr: Arc>, } impl HubImpl { pub fn new(ctx: Arc, conn: SocketAddr) -> Self { - Self { ctx, conn } + Self { + ctx, + conn, + worker_addr: Arc::new(Mutex::new("".to_owned())), + } + } +} + +impl Drop for HubImpl { + fn drop(&mut self) { + let ctx = self.ctx.clone(); + let worker_addr = self.worker_addr.clone(); + + tokio::spawn(async move { + let worker_addr = worker_addr.lock().await; + if !worker_addr.is_empty() { + ctx.unregister_worker(worker_addr.as_str()).await; + } + }); } } @@ -105,27 +126,39 @@ impl Hub for Arc { cpu_count: i32, cuda_devices: Vec, token: String, - ) -> Result { + ) -> Result { if token.is_empty() { - return Err(HubError::InvalidToken("Empty worker token".to_string())); + return Err(RegisterWorkerError::InvalidToken( + "Empty worker token".to_string(), + )); } let server_token = match self.ctx.read_registration_token().await { Ok(t) => t, Err(err) => { error!("fetch server token: {}", err.to_string()); - return Err(HubError::InvalidToken( + return Err(RegisterWorkerError::InvalidToken( "Failed to fetch server token".to_string(), )); } }; if server_token != token { - return Err(HubError::InvalidToken("Token mismatch".to_string())); + return Err(RegisterWorkerError::InvalidToken( + "Token mismatch".to_string(), + )); } + let mut worker_addr = self.worker_addr.lock().await; + if !worker_addr.is_empty() { + return Err(RegisterWorkerError::RegisterWorkerOnce); + } + + let addr = format!("http://{}:{}", self.conn.ip(), port); + *worker_addr = addr.clone(); + let worker = Worker { name, kind, - addr: format!("http://{}:{}", self.conn.ip(), port), + addr, device, arch, cpu_info, diff --git a/ee/tabby-webserver/src/server.rs b/ee/tabby-webserver/src/server.rs index 435539f..dc40534 100644 --- a/ee/tabby-webserver/src/server.rs +++ b/ee/tabby-webserver/src/server.rs @@ -10,7 +10,7 @@ use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tracing::{info, warn}; use crate::{ - api::{HubError, Worker, WorkerKind}, + api::{RegisterWorkerError, Worker, WorkerKind}, db::DbConn, }; @@ -51,7 +51,7 @@ impl ServerContext { self.db_conn.reset_registration_token().await } - pub async fn register_worker(&self, worker: Worker) -> Result { + pub async fn register_worker(&self, worker: Worker) -> Result { let worker = match worker.kind { WorkerKind::Completion => self.completion.register(worker).await, WorkerKind::Chat => self.chat.register(worker).await, @@ -64,10 +64,23 @@ impl ServerContext { ); Ok(worker) } else { - Err(HubError::RequiresEnterpriseLicense) + Err(RegisterWorkerError::RequiresEnterpriseLicense) } } + pub async fn unregister_worker(&self, worker_addr: &str) { + let kind = if self.chat.unregister(worker_addr).await { + WorkerKind::Chat + } else if self.completion.unregister(worker_addr).await { + WorkerKind::Completion + } else { + warn!("Trying to unregister a worker missing in registry"); + return; + }; + + info!("unregistering <{:?}> worker at {}", kind, worker_addr); + } + pub async fn list_workers(&self) -> Vec { [self.completion.list().await, self.chat.list().await].concat() } diff --git a/ee/tabby-webserver/src/server/worker.rs b/ee/tabby-webserver/src/server/worker.rs index 9709c02..128da9c 100644 --- a/ee/tabby-webserver/src/server/worker.rs +++ b/ee/tabby-webserver/src/server/worker.rs @@ -37,6 +37,16 @@ impl WorkerGroup { Some(worker) } + + pub async fn unregister(&self, worker_addr: &str) -> bool { + let mut workers = self.workers.write().await; + if let Some(index) = workers.iter().position(|x| x.addr == worker_addr) { + workers.remove(index); + true + } else { + false + } + } } fn random_index(size: usize) -> usize {