feat: add worker command worker::completion and worker::chat (#778)

release-fix-intellij-update-support-version-range
Meng Zhang 2023-11-13 15:21:57 -08:00 committed by GitHub
parent 510eddca89
commit e521f0637c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 489 additions and 144 deletions

62
Cargo.lock generated
View File

@ -1711,6 +1711,15 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "graphql-introspection-query"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d"
dependencies = [
"serde",
]
[[package]]
name = "graphql-parser"
version = "0.3.0"
@ -1721,6 +1730,56 @@ dependencies = [
"thiserror",
]
[[package]]
name = "graphql-parser"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474"
dependencies = [
"combine",
"thiserror",
]
[[package]]
name = "graphql_client"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa"
dependencies = [
"graphql_query_derive",
"reqwest",
"serde",
"serde_json",
]
[[package]]
name = "graphql_client_codegen"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506"
dependencies = [
"graphql-introspection-query",
"graphql-parser 0.4.0",
"heck 0.4.1",
"lazy_static",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 1.0.109",
]
[[package]]
name = "graphql_query_derive"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c"
dependencies = [
"graphql_client_codegen",
"proc-macro2",
"syn 1.0.109",
]
[[package]]
name = "h2"
version = "0.3.19"
@ -2148,7 +2207,7 @@ dependencies = [
"fnv",
"futures",
"futures-enum",
"graphql-parser",
"graphql-parser 0.3.0",
"indexmap 1.9.3",
"juniper_codegen",
"serde",
@ -4285,6 +4344,7 @@ dependencies = [
"chrono",
"clap 4.4.7",
"futures",
"graphql_client",
"http-api-bindings",
"hyper",
"lazy_static",

View File

@ -35,3 +35,6 @@ update-openapi-doc:
["components", "schemas", "DebugOptions"] \
])' | jq '.servers[0] |= { url: "https://playground.app.tabbyml.com", description: "Playground server" }' \
> website/static/openapi.json
update-graphql-schema:
cargo run --package tabby-webserver --example update-schema

View File

@ -47,6 +47,8 @@ async-trait.workspace = true
tabby-webserver = { path = "../../ee/tabby-webserver" }
thiserror.workspace = true
chrono = "0.4.31"
graphql_client = { version = "0.13.0", features = ["reqwest"] }
reqwest.workspace = true
[dependencies.uuid]
version = "1.3.3"

View File

@ -0,0 +1,23 @@
mutation RegisterWorker(
$port: Int!
$kind: WorkerKind!
$name: String!
$device: String!
$arch: String!
$cpuInfo: String!
$cpuCount: Int!
$cudaDevices: [String!]!
) {
worker: registerWorker(
port: $port
kind: $kind
name: $name
device: $device
arch: $arch
cpuInfo: $cpuInfo
cpuCount: $cpuCount
cudaDevices: $cudaDevices
) {
addr
}
}

View File

@ -1,9 +1,11 @@
mod api;
mod download;
mod routes;
mod serve;
mod services;
mod download;
mod serve;
mod worker;
use clap::{Parser, Subcommand};
use opentelemetry::{
global,
@ -36,6 +38,16 @@ pub enum Commands {
/// Run scheduler progress for cron jobs integrating external code repositories.
Scheduler(SchedulerArgs),
/// Run completion model as worker
#[clap(name = "worker::completion")]
#[command(arg_required_else_help = true)]
WorkerCompletion(worker::WorkerArgs),
/// Run chat model as worker
#[clap(name = "worker::chat")]
#[command(arg_required_else_help = true)]
WorkerChat(worker::WorkerArgs),
}
#[derive(clap::Args)]
@ -45,6 +57,41 @@ pub struct SchedulerArgs {
now: bool,
}
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
pub enum Device {
#[strum(serialize = "cpu")]
Cpu,
#[cfg(feature = "cuda")]
#[strum(serialize = "cuda")]
Cuda,
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[strum(serialize = "metal")]
Metal,
#[cfg(feature = "experimental-http")]
#[strum(serialize = "experimental_http")]
ExperimentalHttp,
}
impl Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal
}
#[cfg(feature = "cuda")]
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Cuda
}
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
pub fn ggml_use_gpu(&self) -> bool {
false
}
}
#[tokio::main]
async fn main() {
let cli = Cli::parse();
@ -58,6 +105,10 @@ async fn main() {
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
.await
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
Commands::WorkerCompletion(args) => {
worker::main(worker::WorkerKind::Completion, args).await
}
Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await,
}
opentelemetry::global::shutdown_tracer_provider();

View File

@ -1,5 +1,4 @@
use std::{
fs,
net::{Ipv4Addr, SocketAddr},
sync::Arc,
time::Duration,
@ -9,7 +8,6 @@ use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args;
use tabby_common::{config::Config, usage};
use tabby_download::download_model;
use tabby_webserver::attach_webserver;
use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
@ -20,7 +18,14 @@ use utoipa_swagger_ui::SwaggerUi;
use crate::{
api::{self},
fatal, routes,
services::{chat, completion, event::create_event_logger, health, model},
services::{
chat::{self, create_chat_service},
completion::{self, create_completion_service},
event::create_logger,
health,
model::download_model_if_needed,
},
Device,
};
#[derive(OpenApi)]
@ -62,41 +67,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
)]
struct ApiDoc;
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
pub enum Device {
#[strum(serialize = "cpu")]
Cpu,
#[cfg(feature = "cuda")]
#[strum(serialize = "cuda")]
Cuda,
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[strum(serialize = "metal")]
Metal,
#[cfg(feature = "experimental-http")]
#[strum(serialize = "experimental_http")]
ExperimentalHttp,
}
impl Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal
}
#[cfg(feature = "cuda")]
pub fn ggml_use_gpu(&self) -> bool {
*self == Device::Cuda
}
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
pub fn ggml_use_gpu(&self) -> bool {
false
}
}
#[derive(Args)]
pub struct ServeArgs {
/// Model id for `/completions` API endpoint.
@ -152,43 +122,30 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}
async fn load_model(args: &ServeArgs) {
if fs::metadata(&args.model).is_ok() {
info!("Loading model from local path {}", &args.model);
} else {
download_model(&args.model, true).await;
if let Some(chat_model) = &args.chat_model {
download_model(chat_model, true).await;
}
download_model_if_needed(&args.model).await;
if let Some(chat_model) = &args.chat_model {
download_model_if_needed(chat_model).await
}
}
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let logger = Arc::new(create_event_logger());
let logger = Arc::new(create_logger());
let code = Arc::new(crate::services::code::create_code_search());
let completion_state = {
let (
engine,
model::PromptInfo {
prompt_template, ..
},
) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
let state = completion::CompletionService::new(
engine.clone(),
let completion = Arc::new(
create_completion_service(
code.clone(),
logger.clone(),
prompt_template,
);
Arc::new(state)
};
&args.model,
&args.device,
args.parallelism,
)
.await,
);
let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, model::PromptInfo { chat_template, .. }) =
model::load_text_generation(chat_model, &args.device, args.parallelism).await;
let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template");
};
let state = chat::ChatService::new(engine, chat_template);
Some(Arc::new(state))
let chat_state = if let Some(_chat_model) = &args.chat_model {
Some(Arc::new(
create_chat_service(&args.model, &args.device, args.parallelism).await,
))
} else {
None
};
@ -220,7 +177,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
Router::new()
.route(
"/v1/completions",
routing::post(routes::completions).with_state(completion_state),
routing::post(routes::completions).with_state(completion),
)
.layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout,

View File

@ -11,6 +11,9 @@ use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptio
use tracing::debug;
use utoipa::ToSchema;
use super::model;
use crate::{fatal, Device};
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"messages": [
@ -40,7 +43,7 @@ pub struct ChatService {
}
impl ChatService {
pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
Self {
engine,
prompt_builder: ChatPromptBuilder::new(chat_template),
@ -73,3 +76,14 @@ impl ChatService {
Box::pin(s)
}
}
pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
let (engine, model::PromptInfo { chat_template, .. }) =
model::load_text_generation(model, device, parallelism).await;
let Some(chat_template) = chat_template else {
fatal!("Chat model requires specifying prompt template");
};
ChatService::new(engine, chat_template)
}

View File

@ -9,10 +9,14 @@ use thiserror::Error;
use tracing::debug;
use utoipa::ToSchema;
use crate::api::{
self,
code::CodeSearch,
event::{Event, EventLogger},
use super::model;
use crate::{
api::{
self,
code::CodeSearch,
event::{Event, EventLogger},
},
Device,
};
#[derive(Error, Debug)]
@ -166,7 +170,7 @@ pub struct CompletionService {
}
impl CompletionService {
pub fn new(
fn new(
engine: Arc<dyn TextGeneration>,
code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
@ -260,3 +264,20 @@ impl CompletionService {
))
}
}
pub async fn create_completion_service(
code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
model: &str,
device: &Device,
parallelism: u8,
) -> CompletionService {
let (
engine,
model::PromptInfo {
prompt_template, ..
},
) = model::load_text_generation(model, device, parallelism).await;
CompletionService::new(engine.clone(), code, logger, prompt_template)
}

View File

@ -85,6 +85,16 @@ fn timestamp() -> u128 {
.as_millis()
}
pub fn create_event_logger() -> impl EventLogger {
pub fn create_logger() -> impl EventLogger {
EventService
}
struct NullLogger;
impl EventLogger for NullLogger {
fn log(&self, _e: &Event) {}
}
pub fn create_null_logger() -> impl EventLogger {
NullLogger
}

View File

@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
use sysinfo::{CpuExt, System, SystemExt};
use utoipa::ToSchema;
use crate::serve::Device;
use crate::Device;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct HealthState {
@ -43,7 +43,7 @@ impl HealthState {
}
}
fn read_cpu_info() -> (String, usize) {
pub fn read_cpu_info() -> (String, usize) {
let mut system = System::new_all();
system.refresh_cpu();
let cpus = system.cpus();
@ -58,7 +58,7 @@ fn read_cpu_info() -> (String, usize) {
(info, count)
}
fn read_cuda_devices() -> Result<Vec<String>> {
pub fn read_cuda_devices() -> Result<Vec<String>> {
// In cases of MacOS or docker containers where --gpus are not specified,
// the Nvml::init() would return an error. In these scenarios, we
// assign cuda_devices to be empty, indicating that the current runtime

View File

@ -2,9 +2,11 @@ use std::{fs, path::PathBuf, sync::Arc};
use serde::Deserialize;
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
use tabby_download::download_model;
use tabby_inference::TextGeneration;
use tracing::info;
use crate::{fatal, serve::Device};
use crate::{fatal, Device};
pub async fn load_text_generation(
model_id: &str,
@ -72,3 +74,11 @@ fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> imp
llama_cpp_bindings::LlamaTextGeneration::new(options)
}
pub async fn download_model_if_needed(model: &str) {
if fs::metadata(model).is_ok() {
info!("Loading model from local path {}", model);
} else {
download_model(model, true).await;
}
}

146
crates/tabby/src/worker.rs Normal file
View File

@ -0,0 +1,146 @@
use std::{
env::consts::ARCH,
net::{Ipv4Addr, SocketAddr},
sync::Arc,
};
use axum::{routing, Router};
use clap::Args;
use graphql_client::{reqwest::post_graphql, GraphQLQuery};
use hyper::Server;
use tracing::{info, warn};
use crate::{
fatal, routes,
services::{
chat::create_chat_service,
code,
completion::create_completion_service,
event::{self},
health::{read_cpu_info, read_cuda_devices},
model::download_model_if_needed,
},
Device,
};
#[derive(GraphQLQuery)]
#[graphql(
schema_path = "../../ee/tabby-webserver/graphql/schema.graphql",
query_path = "./graphql/worker.query.graphql"
)]
struct RegisterWorker;
#[derive(Args)]
pub struct WorkerArgs {
/// URL to register this worker.
#[clap(long)]
url: String,
#[clap(long, default_value_t = 8080)]
port: u16,
/// Model id
#[clap(long, help_heading=Some("Model Options"))]
model: String,
/// Device to run model inference.
#[clap(long, default_value_t=Device::Cpu, help_heading=Some("Model Options"))]
device: Device,
/// Parallelism for model serving - increasing this number will have a significant impact on the
/// memory requirement e.g., GPU vRAM.
#[clap(long, default_value_t = 1, help_heading=Some("Model Options"))]
parallelism: u8,
}
async fn make_chat_route(args: &WorkerArgs) -> Router {
let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);
request_register(register_worker::WorkerKind::CHAT, args).await;
Router::new().route(
"/v1beta/chat/completions",
routing::post(routes::chat_completions).with_state(state),
)
}
async fn make_completion_route(args: &WorkerArgs) -> Router {
let code = Arc::new(code::create_code_search());
let logger = Arc::new(event::create_null_logger());
let state = Arc::new(
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
);
request_register(register_worker::WorkerKind::COMPLETION, args).await;
Router::new().route(
"/v1/completions",
routing::post(routes::completions).with_state(state),
)
}
pub enum WorkerKind {
Chat,
Completion,
}
pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
download_model_if_needed(&args.model).await;
info!("Starting worker, this might takes a few minutes...");
let app = match kind {
WorkerKind::Completion => make_completion_route(args).await,
WorkerKind::Chat => make_chat_route(args).await,
};
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
info!("Listening at {}", address);
Server::bind(&address)
.serve(app.into_make_service())
.await
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}
async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) {
request_register_impl(
kind,
args.url.clone(),
args.port as i64,
args.model.to_owned(),
args.device.to_string(),
)
.await;
}
async fn request_register_impl(
kind: register_worker::WorkerKind,
url: String,
port: i64,
name: String,
device: String,
) {
let client = reqwest::Client::new();
let (cpu_info, cpu_count) = read_cpu_info();
let cuda_devices = read_cuda_devices().unwrap_or_default();
let variables = register_worker::Variables {
port,
kind,
name,
device,
arch: ARCH.to_string(),
cpu_info,
cpu_count: cpu_count as i64,
cuda_devices,
};
let url = format!("{}/graphql", url);
match post_graphql::<RegisterWorker, _>(&client, &url, variables).await {
Ok(x) => {
let addr = x.data.unwrap().worker.addr;
info!("Worker alive at {}", addr);
}
Err(err) => warn!("Failed to register worker: {}", err),
}
}

View File

@ -0,0 +1,13 @@
use std::fs::write;
use juniper::EmptySubscription;
use tabby_webserver::schema::{Mutation, Query, Schema};
fn main() {
let schema = Schema::new(Query, Mutation, EmptySubscription::new());
write(
"ee/tabby-webserver/graphql/schema.graphql",
schema.as_schema_language(),
)
.unwrap();
}

View File

@ -0,0 +1,28 @@
enum WorkerKind {
COMPLETION
CHAT
}
type Mutation {
registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker!
}
type Query {
workers: [Worker!]!
}
type Worker {
kind: WorkerKind!
name: String!
addr: String!
device: String!
arch: String!
cpuInfo: String!
cpuCount: Int!
cudaDevices: [String!]!
}
schema {
query: Query
mutation: Mutation
}

View File

@ -1,4 +1,4 @@
mod schema;
pub mod schema;
mod ui;
mod webserver;
mod worker;

View File

@ -30,14 +30,14 @@ pub enum WorkerKind {
#[derive(GraphQLObject, Clone, Debug)]
pub struct Worker {
kind: WorkerKind,
addr: String,
}
impl Worker {
pub fn new(kind: WorkerKind, addr: String) -> Self {
Self { kind, addr }
}
pub kind: WorkerKind,
pub name: String,
pub addr: String,
pub device: String,
pub arch: String,
pub cpu_info: String,
pub cpu_count: i32,
pub cuda_devices: Vec<String>,
}
#[derive(Default)]
@ -56,13 +56,27 @@ pub struct Mutation;
impl Mutation {
async fn register_worker(
request: &Request,
token: String,
kind: WorkerKind,
port: i32,
kind: WorkerKind,
name: String,
device: String,
arch: String,
cpu_info: String,
cpu_count: i32,
cuda_devices: Vec<String>,
) -> Result<Worker, WebserverError> {
let ws = &request.ws;
ws.register_worker(token, request.client_addr, kind, port)
.await
let worker = Worker {
name,
kind,
addr: format!("http://{}:{}", request.client_addr.ip(), port),
device,
arch,
cpu_info,
cpu_count,
cuda_devices,
};
ws.register_worker(worker).await
}
}

View File

@ -28,47 +28,26 @@ pub struct Webserver {
chat: worker::WorkerGroup,
}
// FIXME: generate token and support refreshing in database.
static WORKER_TOKEN: &str = "4c749fad-2be7-45a3-849e-7714ccade382";
impl Webserver {
pub async fn register_worker(
&self,
token: String,
client_addr: SocketAddr,
kind: WorkerKind,
port: i32,
) -> Result<Worker, WebserverError> {
if token != WORKER_TOKEN {
return Err(WebserverError::InvalidToken(token));
}
let addr = SocketAddr::new(client_addr.ip(), port as u16);
let addr = match kind {
WorkerKind::Completion => self.completion.register(addr).await,
WorkerKind::Chat => self.chat.register(addr).await,
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, WebserverError> {
let worker = match worker.kind {
WorkerKind::Completion => self.completion.register(worker).await,
WorkerKind::Chat => self.chat.register(worker).await,
};
if let Some(addr) = addr {
info!("registering <{:?}> worker running at {}", kind, addr);
Ok(Worker::new(kind, addr))
if let Some(worker) = worker {
info!(
"registering <{:?}> worker running at {}",
worker.kind, worker.addr
);
Ok(worker)
} else {
Err(WebserverError::RequiresEnterpriseLicense)
}
}
pub async fn list_workers(&self) -> Vec<Worker> {
let make_workers = |x: WorkerKind, lst: Vec<String>| -> Vec<Worker> {
lst.into_iter()
.map(|addr| Worker::new(x.clone(), addr))
.collect()
};
[
make_workers(WorkerKind::Completion, self.completion.list().await),
make_workers(WorkerKind::Chat, self.chat.list().await),
]
.concat()
[self.completion.list().await, self.chat.list().await].concat()
}
pub async fn dispatch_request(

View File

@ -1,42 +1,41 @@
use std::{
net::SocketAddr,
time::{SystemTime, UNIX_EPOCH},
};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tracing::error;
use crate::schema::Worker;
#[derive(Default)]
pub struct WorkerGroup {
workers: RwLock<Vec<String>>,
workers: RwLock<Vec<Worker>>,
}
impl WorkerGroup {
pub async fn select(&self) -> Option<String> {
let workers = self.workers.read().await;
if workers.len() > 0 {
Some(workers[random_index(workers.len())].clone())
Some(workers[random_index(workers.len())].addr.clone())
} else {
None
}
}
pub async fn list(&self) -> Vec<String> {
pub async fn list(&self) -> Vec<Worker> {
self.workers.read().await.clone()
}
pub async fn register(&self, addr: SocketAddr) -> Option<String> {
let addr = format!("http://{}", addr);
pub async fn register(&self, worker: Worker) -> Option<Worker> {
let mut workers = self.workers.write().await;
if workers.len() >= 1 {
error!("You need enterprise license to utilize more than 1 workers, please contact hi@tabbyml.com for information.");
return None;
}
if !workers.contains(&addr) {
workers.push(addr.clone());
if workers.iter().all(|x| x.addr != worker.addr) {
workers.push(worker.clone());
}
Some(addr)
Some(worker)
}
}
@ -50,23 +49,38 @@ fn random_index(size: usize) -> usize {
#[cfg(test)]
mod tests {
use super::*;
use crate::schema::WorkerKind;
#[tokio::test]
async fn test_worker_group() {
let wg = WorkerGroup::default();
let addr1 = "127.0.0.1:8080".parse().unwrap();
let addr2 = "127.0.0.2:8080".parse().unwrap();
let worker1 = make_worker("http://127.0.0.1:8080");
let worker2 = make_worker("http://127.0.0.2:8080");
// Register success.
assert!(wg.register(addr1).await.is_some());
assert!(wg.register(worker1.clone()).await.is_some());
// Register failed, as > 1 workers requires enterprise license.
assert!(wg.register(addr2).await.is_none());
assert!(wg.register(worker2).await.is_none());
let workers = wg.list().await;
assert_eq!(workers.len(), 1);
assert_eq!(workers[0], format!("http://{}", addr1));
assert_eq!(workers[0].addr, worker1.addr);
}
fn make_worker(addr: &str) -> Worker {
Worker {
name: "Fake worker".to_owned(),
kind: WorkerKind::Chat,
addr: addr.to_owned(),
device: "cuda".to_owned(),
arch: "x86_64".to_owned(),
cpu_info: "Fake CPU".to_owned(),
cpu_count: 32,
cuda_devices: vec![],
}
}
}