feat: implement worker unregistration (#872)
* feat: implement worker unregistration logic
* refactor: rename HubError -> RegisterWorkerError
r0.6
parent
820465d59f
commit
2296b01d2f
|
|
@ -29,12 +29,15 @@ pub struct Worker {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Error, Debug)]
|
#[derive(Serialize, Deserialize, Error, Debug)]
|
||||||
pub enum HubError {
|
pub enum RegisterWorkerError {
|
||||||
#[error("Invalid token")]
|
#[error("Invalid token")]
|
||||||
InvalidToken(String),
|
InvalidToken(String),
|
||||||
|
|
||||||
#[error("Feature requires enterprise license")]
|
#[error("Feature requires enterprise license")]
|
||||||
RequiresEnterpriseLicense,
|
RequiresEnterpriseLicense,
|
||||||
|
|
||||||
|
#[error("Each hub client should only calls register_worker once")]
|
||||||
|
RegisterWorkerOnce,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tarpc::service]
|
#[tarpc::service]
|
||||||
|
|
@ -49,7 +52,7 @@ pub trait Hub {
|
||||||
cpu_count: i32,
|
cpu_count: i32,
|
||||||
cuda_devices: Vec<String>,
|
cuda_devices: Vec<String>,
|
||||||
token: String,
|
token: String,
|
||||||
) -> Result<Worker, HubError>;
|
) -> Result<Worker, RegisterWorkerError>;
|
||||||
|
|
||||||
async fn log_event(content: String);
|
async fn log_event(content: String);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ use tabby_common::api::{
|
||||||
code::{CodeSearch, SearchResponse},
|
code::{CodeSearch, SearchResponse},
|
||||||
event::RawEventLogger,
|
event::RawEventLogger,
|
||||||
};
|
};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use tracing::{error, warn};
|
use tracing::{error, warn};
|
||||||
use websocket::WebSocketTransport;
|
use websocket::WebSocketTransport;
|
||||||
|
|
||||||
|
|
@ -16,7 +17,7 @@ mod websocket;
|
||||||
|
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
use api::{Hub, HubError, Worker, WorkerKind};
|
use api::{Hub, RegisterWorkerError, Worker, WorkerKind};
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
extract::{ws::WebSocket, ConnectInfo, State, WebSocketUpgrade},
|
||||||
http::Request,
|
http::Request,
|
||||||
|
|
@ -83,11 +84,31 @@ async fn handle_socket(state: Arc<ServerContext>, socket: WebSocket, addr: Socke
|
||||||
pub struct HubImpl {
|
pub struct HubImpl {
|
||||||
ctx: Arc<ServerContext>,
|
ctx: Arc<ServerContext>,
|
||||||
conn: SocketAddr,
|
conn: SocketAddr,
|
||||||
|
|
||||||
|
worker_addr: Arc<Mutex<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HubImpl {
|
impl HubImpl {
|
||||||
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
pub fn new(ctx: Arc<ServerContext>, conn: SocketAddr) -> Self {
|
||||||
Self { ctx, conn }
|
Self {
|
||||||
|
ctx,
|
||||||
|
conn,
|
||||||
|
worker_addr: Arc::new(Mutex::new("".to_owned())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for HubImpl {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let ctx = self.ctx.clone();
|
||||||
|
let worker_addr = self.worker_addr.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let worker_addr = worker_addr.lock().await;
|
||||||
|
if !worker_addr.is_empty() {
|
||||||
|
ctx.unregister_worker(worker_addr.as_str()).await;
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -105,27 +126,39 @@ impl Hub for Arc<HubImpl> {
|
||||||
cpu_count: i32,
|
cpu_count: i32,
|
||||||
cuda_devices: Vec<String>,
|
cuda_devices: Vec<String>,
|
||||||
token: String,
|
token: String,
|
||||||
) -> Result<Worker, HubError> {
|
) -> Result<Worker, RegisterWorkerError> {
|
||||||
if token.is_empty() {
|
if token.is_empty() {
|
||||||
return Err(HubError::InvalidToken("Empty worker token".to_string()));
|
return Err(RegisterWorkerError::InvalidToken(
|
||||||
|
"Empty worker token".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
let server_token = match self.ctx.read_registration_token().await {
|
let server_token = match self.ctx.read_registration_token().await {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("fetch server token: {}", err.to_string());
|
error!("fetch server token: {}", err.to_string());
|
||||||
return Err(HubError::InvalidToken(
|
return Err(RegisterWorkerError::InvalidToken(
|
||||||
"Failed to fetch server token".to_string(),
|
"Failed to fetch server token".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if server_token != token {
|
if server_token != token {
|
||||||
return Err(HubError::InvalidToken("Token mismatch".to_string()));
|
return Err(RegisterWorkerError::InvalidToken(
|
||||||
|
"Token mismatch".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut worker_addr = self.worker_addr.lock().await;
|
||||||
|
if !worker_addr.is_empty() {
|
||||||
|
return Err(RegisterWorkerError::RegisterWorkerOnce);
|
||||||
|
}
|
||||||
|
|
||||||
|
let addr = format!("http://{}:{}", self.conn.ip(), port);
|
||||||
|
*worker_addr = addr.clone();
|
||||||
|
|
||||||
let worker = Worker {
|
let worker = Worker {
|
||||||
name,
|
name,
|
||||||
kind,
|
kind,
|
||||||
addr: format!("http://{}:{}", self.conn.ip(), port),
|
addr,
|
||||||
device,
|
device,
|
||||||
arch,
|
arch,
|
||||||
cpu_info,
|
cpu_info,
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{HubError, Worker, WorkerKind},
|
api::{RegisterWorkerError, Worker, WorkerKind},
|
||||||
db::DbConn,
|
db::DbConn,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -51,7 +51,7 @@ impl ServerContext {
|
||||||
self.db_conn.reset_registration_token().await
|
self.db_conn.reset_registration_token().await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, HubError> {
|
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, RegisterWorkerError> {
|
||||||
let worker = match worker.kind {
|
let worker = match worker.kind {
|
||||||
WorkerKind::Completion => self.completion.register(worker).await,
|
WorkerKind::Completion => self.completion.register(worker).await,
|
||||||
WorkerKind::Chat => self.chat.register(worker).await,
|
WorkerKind::Chat => self.chat.register(worker).await,
|
||||||
|
|
@ -64,10 +64,23 @@ impl ServerContext {
|
||||||
);
|
);
|
||||||
Ok(worker)
|
Ok(worker)
|
||||||
} else {
|
} else {
|
||||||
Err(HubError::RequiresEnterpriseLicense)
|
Err(RegisterWorkerError::RequiresEnterpriseLicense)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn unregister_worker(&self, worker_addr: &str) {
|
||||||
|
let kind = if self.chat.unregister(worker_addr).await {
|
||||||
|
WorkerKind::Chat
|
||||||
|
} else if self.completion.unregister(worker_addr).await {
|
||||||
|
WorkerKind::Completion
|
||||||
|
} else {
|
||||||
|
warn!("Trying to unregister a worker missing in registry");
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
info!("unregistering <{:?}> worker at {}", kind, worker_addr);
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn list_workers(&self) -> Vec<Worker> {
|
pub async fn list_workers(&self) -> Vec<Worker> {
|
||||||
[self.completion.list().await, self.chat.list().await].concat()
|
[self.completion.list().await, self.chat.list().await].concat()
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,16 @@ impl WorkerGroup {
|
||||||
|
|
||||||
Some(worker)
|
Some(worker)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn unregister(&self, worker_addr: &str) -> bool {
|
||||||
|
let mut workers = self.workers.write().await;
|
||||||
|
if let Some(index) = workers.iter().position(|x| x.addr == worker_addr) {
|
||||||
|
workers.remove(index);
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn random_index(size: usize) -> usize {
|
fn random_index(size: usize) -> usize {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue