feat: implement worker unregistration (#872)

* feat: implement worker unregistration logic

* refactor: rename HubError -> RegisterWorkerError
wsxiaoys-patch-3
Meng Zhang 2023-11-23 14:22:04 +08:00 committed by GitHub
parent a067366b5d
commit 3bd3a3304d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 12 deletions

View File

@ -29,12 +29,15 @@ pub struct Worker {
}
#[derive(Serialize, Deserialize, Error, Debug)]
pub enum HubError {
pub enum RegisterWorkerError {
#[error("Invalid token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
#[error("Each hub client should only calls register_worker once")]
RegisterWorkerOnce,
}
#[tarpc::service]
@ -49,7 +52,7 @@ pub trait Hub {
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, HubError>;
) -> Result<Worker, RegisterWorkerError>;
async fn log_event(content: String);

View File

@ -6,6 +6,7 @@ use tabby_common::api::{
code::{CodeSearch, SearchResponse},
event::RawEventLogger,
};
use tokio::sync::Mutex;
use tracing::{error, warn};
use websocket::WebSocketTransport;
@ -16,7 +17,7 @@ mod websocket;
use std::{net::SocketAddr, sync::Arc};
use api::{Hub, HubError, Worker, WorkerKind};
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
use axum::{
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
http::Request,
@ -83,11 +84,31 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
/// Per-connection hub state.
///
/// Tracks which worker address (if any) was registered through this
/// connection so it can be unregistered when the value is dropped.
pub struct HubImpl {
/// Shared server context used to register and unregister workers.
ctx: Arc<ServerContext>,
/// Socket address tied to this connection; its IP is combined with the
/// worker-supplied port to build the worker's address.
conn: SocketAddr,
/// Address of the worker registered via this connection. An empty string
/// means no worker has been registered yet.
worker_addr: Arc<Mutex<String>>,
}
impl HubImpl {
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
Self { ctx, conn }
Self {
ctx,
conn,
worker_addr: Arc::new(Mutex::new("".to_owned())),
}
}
}
impl Drop for HubImpl {
fn drop(&mut self) {
let ctx = self.ctx.clone();
let worker_addr = self.worker_addr.clone();
tokio::spawn(async move {
let worker_addr = worker_addr.lock().await;
if !worker_addr.is_empty() {
ctx.unregister_worker(worker_addr.as_str()).await;
}
});
}
}
@ -105,27 +126,39 @@ impl Hub for Arc<HubImpl> {
cpu_count: i32,
cuda_devices: Vec<String>,
token: String,
) -> Result<Worker, HubError> {
) -> Result<Worker, RegisterWorkerError> {
if token.is_empty() {
return Err(HubError::InvalidToken("Empty worker token".to_string()));
return Err(RegisterWorkerError::InvalidToken(
"Empty worker token".to_string(),
));
}
let server_token = match self.ctx.read_registration_token().await {
Ok(t) => t,
Err(err) => {
error!("fetch server token: {}", err.to_string());
return Err(HubError::InvalidToken(
return Err(RegisterWorkerError::InvalidToken(
"Failed to fetch server token".to_string(),
));
}
};
if server_token != token {
return Err(HubError::InvalidToken("Token mismatch".to_string()));
return Err(RegisterWorkerError::InvalidToken(
"Token mismatch".to_string(),
));
}
let mut worker_addr = self.worker_addr.lock().await;
if !worker_addr.is_empty() {
return Err(RegisterWorkerError::RegisterWorkerOnce);
}
let addr = format!("http://{}:{}", self.conn.ip(), port);
*worker_addr = addr.clone();
let worker = Worker {
name,
kind,
addr: format!("http://{}:{}", self.conn.ip(), port),
addr,
device,
arch,
cpu_info,

View File

@ -10,7 +10,7 @@ use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn};
use crate::{
api::{HubError, Worker, WorkerKind},
api::{RegisterWorkerError, Worker, WorkerKind},
db::DbConn,
};
@ -51,7 +51,7 @@ impl ServerContext {
self.db_conn.reset_registration_token().await
}
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await,
@ -64,10 +64,23 @@ impl ServerContext {
);
Ok(worker)
} else {
Err(HubError::RequiresEnterpriseLicense)
Err(RegisterWorkerError::RequiresEnterpriseLicense)
}
}
/// Removes the worker with the given address from whichever group holds it.
///
/// The chat group is tried first; the completion group is only consulted
/// when the chat group did not contain the address. Logs the kind of the
/// removed worker, or a warning when neither group knew the address.
pub async fn unregister_worker(&self, worker_addr: &str) {
    let removed_kind = if self.chat.unregister(worker_addr).await {
        Some(WorkerKind::Chat)
    } else if self.completion.unregister(worker_addr).await {
        Some(WorkerKind::Completion)
    } else {
        None
    };

    match removed_kind {
        Some(kind) => info!("unregistering <{:?}> worker at {}", kind, worker_addr),
        None => warn!("Trying to unregister a worker missing in registry"),
    }
}
/// Returns every registered worker: completion workers first, followed by
/// chat workers.
pub async fn list_workers(&self) -> Vec<Worker> {
    let completion_workers = self.completion.list().await;
    let chat_workers = self.chat.list().await;
    [completion_workers, chat_workers].concat()
}

View File

@ -37,6 +37,16 @@ impl WorkerGroup {
Some(worker)
}
/// Removes the first worker whose `addr` equals `worker_addr`.
///
/// Returns `true` when a worker was removed and `false` when no worker in
/// this group has that address.
pub async fn unregister(&self, worker_addr: &str) -> bool {
    let mut registry = self.workers.write().await;
    match registry.iter().position(|w| w.addr == worker_addr) {
        Some(slot) => {
            registry.remove(slot);
            true
        }
        None => false,
    }
}
}
fn random_index(size: usize) -> usize {