feat: implement worker unregistration (#872)

* feat: implement worker unregistration logic

* refactor: rename HubError -> RegisterWorkerError
r0.6
Meng Zhang 2023-11-23 14:22:04 +08:00
parent 820465d59f
commit 2296b01d2f
4 changed files with 71 additions and 12 deletions

View File

@ -29,12 +29,15 @@ pub struct Worker {
} }
#[derive(Serialize, Deserialize, Error, Debug)] #[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError { pub enum RegisterWorkerError {
#[error("Invalid token")] #[error("Invalid token")]
InvalidToken(String), InvalidToken(String),
#[error("Feature requires enterprise license")] #[error("Feature requires enterprise license")]
RequiresEnterpriseLicense, RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
} }
#[tarpc::service] #[tarpc::service]
@ -49,7 +52,7 @@ pub trait Hub {
cpu_count: i32, cpu_count: i32,
cuda_devices: Vec<String>, cuda_devices: Vec<String>,
token: String, token: String,
) -> Result<Worker, HubError>; ) -> Result<Worker, RegisterWorkerError>;
async fn log_event(content: String); async fn log_event(content: String);

View File

@ -6,6 +6,7 @@ use tabby_common::api::{
code::{CodeSearch, SearchResponse}, code::{CodeSearch, SearchResponse},
event::RawEventLogger, event::RawEventLogger,
}; };
use tokio::sync::Mutex;
use tracing::{error, warn}; use tracing::{error, warn};
use websocket::WebSocketTransport; use websocket::WebSocketTransport;
@ -16,7 +17,7 @@ mod websocket;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use api::{Hub, HubError, Worker, WorkerKind}; use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
use axum::{ use axum::{
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade}, extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
http::Request, http::Request,
@ -83,11 +84,31 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
pub struct HubImpl { pub struct HubImpl {
ctx: Arc<ServerContext>, ctx: Arc<ServerContext>,
conn: SocketAddr, conn: SocketAddr,
worker_addr: Arc<Mutex<String>>,
} }
impl HubImpl { impl HubImpl {
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self { pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
Self { ctx, conn } Self {
ctx,
conn,
worker_addr: Arc::new(Mutex::new("".to_owned())),
}
}
}
/// Unregisters the worker recorded on this hub connection when the `HubImpl` is dropped.
///
/// `Drop::drop` is synchronous, so the async cleanup is handed off to a detached
/// `tokio::spawn` task; nothing waits for that task to finish.
///
/// NOTE(review): `tokio::spawn` panics when called outside a Tokio runtime context —
/// confirm every `HubImpl` is dropped on a runtime thread.
impl Drop for HubImpl {
fn drop(&mut self) {
// Clone the shared handles so the spawned task owns them independently of `self`.
let ctx = self.ctx.clone();
let worker_addr = self.worker_addr.clone();
tokio::spawn(async move {
let worker_addr = worker_addr.lock().await;
// The address starts out empty and is only filled in by a successful
// `register_worker` call, so an empty value means there is nothing to remove.
if !worker_addr.is_empty() {
ctx.unregister_worker(worker_addr.as_str()).await;
}
});
} }
} }
@ -105,27 +126,39 @@ impl Hub for Arc<HubImpl> {
cpu_count: i32, cpu_count: i32,
cuda_devices: Vec<String>, cuda_devices: Vec<String>,
token: String, token: String,
) -> Result<Worker, HubError> { ) -> Result<Worker, RegisterWorkerError> {
if token.is_empty() { if token.is_empty() {
return Err(HubError::InvalidToken("Empty worker token".to_string())); return Err(RegisterWorkerError::InvalidToken(
"Empty worker token".to_string(),
));
} }
let server_token = match self.ctx.read_registration_token().await { let server_token = match self.ctx.read_registration_token().await {
Ok(t) => t, Ok(t) => t,
Err(err) => { Err(err) => {
error!("fetch server token: {}", err.to_string()); error!("fetch server token: {}", err.to_string());
return Err(HubError::InvalidToken( return Err(RegisterWorkerError::InvalidToken(
"Failed to fetch server token".to_string(), "Failed to fetch server token".to_string(),
)); ));
} }
}; };
if server_token != token { if server_token != token {
return Err(HubError::InvalidToken("Token mismatch".to_string())); return Err(RegisterWorkerError::InvalidToken(
"Token mismatch".to_string(),
));
} }
let mut worker_addr = self.worker_addr.lock().await;
if !worker_addr.is_empty() {
return Err(RegisterWorkerError::RegisterWorkerOnce);
}
let addr = format!("http://{}:{}", self.conn.ip(), port);
*worker_addr = addr.clone();
let worker = Worker { let worker = Worker {
name, name,
kind, kind,
addr: format!("http://{}:{}", self.conn.ip(), port), addr,
device, device,
arch, arch,
cpu_info, cpu_info,

View File

@ -10,7 +10,7 @@ use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{ use crate::{
api::{HubError, Worker, WorkerKind}, api::{RegisterWorkerError, Worker, WorkerKind},
db::DbConn, db::DbConn,
}; };
@ -51,7 +51,7 @@ impl ServerContext {
self.db_conn.reset_registration_token().await self.db_conn.reset_registration_token().await
} }
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> { pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
let worker = match worker.kind { let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await, WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await, WorkerKind::Chat => self.chat.register(worker).await,
@ -64,10 +64,23 @@ impl ServerContext {
); );
Ok(worker) Ok(worker)
} else { } else {
Err(HubError::RequiresEnterpriseLicense) Err(RegisterWorkerError::RequiresEnterpriseLicense)
} }
} }
/// Removes the worker registered at `worker_addr` from whichever group holds it.
///
/// The chat group is tried first and the completion group is only consulted when the
/// chat group had no matching entry. A successful removal is logged at `info` level
/// together with the worker kind; if neither group knew the address, a warning is
/// logged instead.
pub async fn unregister_worker(&self, worker_addr: &str) {
    // Resolve which group (if any) actually dropped the worker before logging.
    let removed_kind = if self.chat.unregister(worker_addr).await {
        Some(WorkerKind::Chat)
    } else if self.completion.unregister(worker_addr).await {
        Some(WorkerKind::Completion)
    } else {
        None
    };

    match removed_kind {
        Some(kind) => info!("unregistering <{:?}> worker at {}", kind, worker_addr),
        None => warn!("Trying to unregister a worker missing in registry"),
    }
}
pub async fn list_workers(&self) -> Vec<Worker> { pub async fn list_workers(&self) -> Vec<Worker> {
[self.completion.list().await, self.chat.list().await].concat() [self.completion.list().await, self.chat.list().await].concat()
} }

View File

@ -37,6 +37,16 @@ impl WorkerGroup {
Some(worker) Some(worker)
} }
/// Removes the first worker whose `addr` equals `worker_addr`.
///
/// Returns `true` when a matching worker was found and removed, `false` when no
/// worker in this group has that address. The write guard on `self.workers` is held
/// for the whole lookup-and-remove sequence.
pub async fn unregister(&self, worker_addr: &str) -> bool {
    let mut workers = self.workers.write().await;
    let found = workers.iter().position(|w| w.addr == worker_addr);
    match found {
        Some(idx) => {
            // `remove` (not `swap_remove`) keeps the remaining workers in order.
            workers.remove(idx);
            true
        }
        None => false,
    }
}
} }
fn random_index(size: usize) -> usize { fn random_index(size: usize) -> usize {