diff --git a/Cargo.lock b/Cargo.lock index dabb7e6..c037737 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1711,6 +1711,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "graphql-introspection-query" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d" +dependencies = [ + "serde", +] + [[package]] name = "graphql-parser" version = "0.3.0" @@ -1721,6 +1730,56 @@ dependencies = [ "thiserror", ] +[[package]] +name = "graphql-parser" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474" +dependencies = [ + "combine", + "thiserror", +] + +[[package]] +name = "graphql_client" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa" +dependencies = [ + "graphql_query_derive", + "reqwest", + "serde", + "serde_json", +] + +[[package]] +name = "graphql_client_codegen" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506" +dependencies = [ + "graphql-introspection-query", + "graphql-parser 0.4.0", + "heck 0.4.1", + "lazy_static", + "proc-macro2", + "quote", + "serde", + "serde_json", + "syn 1.0.109", +] + +[[package]] +name = "graphql_query_derive" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c" +dependencies = [ + "graphql_client_codegen", + "proc-macro2", + "syn 1.0.109", +] + [[package]] name = "h2" version = "0.3.19" @@ -2148,7 +2207,7 @@ dependencies = [ "fnv", "futures", "futures-enum", - "graphql-parser", + "graphql-parser 0.3.0", "indexmap 1.9.3", "juniper_codegen", "serde", @@ -4285,6 +4344,7 @@ dependencies = [ "chrono", "clap 4.4.7", "futures", + "graphql_client", "http-api-bindings", "hyper", "lazy_static", diff --git a/Makefile b/Makefile index 442ff7a..9c56e95 100644 --- a/Makefile +++ b/Makefile @@ -35,3 +35,6 @@ update-openapi-doc: ["components", "schemas", "DebugOptions"] \ ])' | jq '.servers[0] |= { url: "https://playground.app.tabbyml.com", description: "Playground server" }' \ > website/static/openapi.json + +update-graphql-schema: + cargo run --package tabby-webserver --example update-schema \ No newline at end of file diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 66d9e04..3178fc8 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -47,6 +47,8 @@ async-trait.workspace = true tabby-webserver = { path = "../../ee/tabby-webserver" } thiserror.workspace = true chrono = "0.4.31" +graphql_client = { version = "0.13.0", features = ["reqwest"] } +reqwest.workspace = true [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/graphql/worker.query.graphql b/crates/tabby/graphql/worker.query.graphql new file mode 100644 index 0000000..157299e --- /dev/null +++ b/crates/tabby/graphql/worker.query.graphql @@ -0,0 +1,23 @@ +mutation RegisterWorker( + $port: Int! + $kind: WorkerKind! + $name: String! + $device: String! + $arch: String! + $cpuInfo: String! + $cpuCount: Int! + $cudaDevices: [String!]! +) { + worker: registerWorker( + port: $port + kind: $kind + name: $name + device: $device + arch: $arch + cpuInfo: $cpuInfo + cpuCount: $cpuCount + cudaDevices: $cudaDevices + ) { + addr + } +} diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 4cfc093..7617627 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,9 +1,11 @@ mod api; -mod download; mod routes; -mod serve; mod services; +mod download; +mod serve; +mod worker; + use clap::{Parser, Subcommand}; use opentelemetry::{ global, @@ -36,6 +38,16 @@ pub enum Commands { /// Run scheduler progress for cron jobs integrating external code repositories. Scheduler(SchedulerArgs), + + /// Run completion model as worker + #[clap(name = "worker::completion")] + #[command(arg_required_else_help = true)] + WorkerCompletion(worker::WorkerArgs), + + /// Run chat model as worker + #[clap(name = "worker::chat")] + #[command(arg_required_else_help = true)] + WorkerChat(worker::WorkerArgs), } #[derive(clap::Args)] @@ -45,6 +57,41 @@ pub struct SchedulerArgs { now: bool, } +#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] +pub enum Device { + #[strum(serialize = "cpu")] + Cpu, + + #[cfg(feature = "cuda")] + #[strum(serialize = "cuda")] + Cuda, + + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + #[strum(serialize = "metal")] + Metal, + + #[cfg(feature = "experimental-http")] + #[strum(serialize = "experimental_http")] + ExperimentalHttp, +} + +impl Device { + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + pub fn ggml_use_gpu(&self) -> bool { + *self == Device::Metal + } + + #[cfg(feature = "cuda")] + pub fn ggml_use_gpu(&self) -> bool { + *self == Device::Cuda + } + + #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))] + pub fn ggml_use_gpu(&self) -> bool { + false + } +} + #[tokio::main] async fn main() { let cli = Cli::parse(); @@ -58,6 +105,10 @@ async fn main() { Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now) .await .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)), + Commands::WorkerCompletion(args) => { + worker::main(worker::WorkerKind::Completion, args).await + } + Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await, } opentelemetry::global::shutdown_tracer_provider(); diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve.rs similarity index 75% rename from crates/tabby/src/serve/mod.rs rename to crates/tabby/src/serve.rs index 88d0686..ff8e9e5 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve.rs @@ -1,5 +1,4 @@ use std::{ - fs, net::{Ipv4Addr, SocketAddr}, sync::Arc, time::Duration, @@ -9,7 +8,6 @@ use axum::{routing, Router, Server}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use clap::Args; use tabby_common::{config::Config, usage}; -use tabby_download::download_model; use tabby_webserver::attach_webserver; use tokio::time::sleep; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; @@ -20,7 +18,14 @@ use utoipa_swagger_ui::SwaggerUi; use crate::{ api::{self}, fatal, routes, - services::{chat, completion, event::create_event_logger, health, model}, + services::{ + chat::{self, create_chat_service}, + completion::{self, create_completion_service}, + event::create_logger, + health, + model::download_model_if_needed, + }, + Device, }; #[derive(OpenApi)] @@ -62,41 +67,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi )] struct ApiDoc; -#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)] -pub enum Device { - #[strum(serialize = "cpu")] - Cpu, - - #[cfg(feature = "cuda")] - #[strum(serialize = "cuda")] - Cuda, - - #[cfg(all(target_os = "macos", target_arch = "aarch64"))] - #[strum(serialize = "metal")] - Metal, - - #[cfg(feature = "experimental-http")] - #[strum(serialize = "experimental_http")] - ExperimentalHttp, -} - -impl Device { - #[cfg(all(target_os = "macos", target_arch = "aarch64"))] - pub fn ggml_use_gpu(&self) -> bool { - *self == Device::Metal - } - - #[cfg(feature = "cuda")] - pub fn ggml_use_gpu(&self) -> bool { - *self == Device::Cuda - } - - #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))] - pub fn ggml_use_gpu(&self) -> bool { - false - } -} - #[derive(Args)] pub struct ServeArgs { /// Model id for `/completions` API endpoint. @@ -152,43 +122,30 @@ pub async fn main(config: &Config, args: &ServeArgs) { } async fn load_model(args: &ServeArgs) { - if fs::metadata(&args.model).is_ok() { - info!("Loading model from local path {}", &args.model); - } else { - download_model(&args.model, true).await; - if let Some(chat_model) = &args.chat_model { - download_model(chat_model, true).await; - } + download_model_if_needed(&args.model).await; + if let Some(chat_model) = &args.chat_model { + download_model_if_needed(chat_model).await } } async fn api_router(args: &ServeArgs, config: &Config) -> Router { - let logger = Arc::new(create_event_logger()); + let logger = Arc::new(create_logger()); let code = Arc::new(crate::services::code::create_code_search()); - let completion_state = { - let ( - engine, - model::PromptInfo { - prompt_template, .. - }, - ) = model::load_text_generation(&args.model, &args.device, args.parallelism).await; - let state = completion::CompletionService::new( - engine.clone(), + let completion = Arc::new( + create_completion_service( code.clone(), logger.clone(), - prompt_template, - ); - Arc::new(state) - }; + &args.model, + &args.device, + args.parallelism, + ) + .await, + ); - let chat_state = if let Some(chat_model) = &args.chat_model { - let (engine, model::PromptInfo { chat_template, .. }) = - model::load_text_generation(chat_model, &args.device, args.parallelism).await; - let Some(chat_template) = chat_template else { - panic!("Chat model requires specifying prompt template"); - }; - let state = chat::ChatService::new(engine, chat_template); - Some(Arc::new(state)) + let chat_state = if let Some(_chat_model) = &args.chat_model { + Some(Arc::new( + create_chat_service(&args.model, &args.device, args.parallelism).await, + )) } else { None }; @@ -220,7 +177,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { Router::new() .route( "/v1/completions", - routing::post(routes::completions).with_state(completion_state), + routing::post(routes::completions).with_state(completion), ) .layer(TimeoutLayer::new(Duration::from_secs( config.server.completion_timeout, diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index 37286ad..d791496 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -11,6 +11,9 @@ use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptio use tracing::debug; use utoipa::ToSchema; +use super::model; +use crate::{fatal, Device}; + #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ "messages": [ @@ -40,7 +43,7 @@ pub struct ChatService { } impl ChatService { - pub fn new(engine: Arc, chat_template: String) -> Self { + fn new(engine: Arc, chat_template: String) -> Self { Self { engine, prompt_builder: ChatPromptBuilder::new(chat_template), @@ -73,3 +76,14 @@ impl ChatService { Box::pin(s) } } + +pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService { + let (engine, model::PromptInfo { chat_template, .. }) = + model::load_text_generation(model, device, parallelism).await; + + let Some(chat_template) = chat_template else { + fatal!("Chat model requires specifying prompt template"); + }; + + ChatService::new(engine, chat_template) +} diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 8c76ed8..a76289b 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -9,10 +9,14 @@ use thiserror::Error; use tracing::debug; use utoipa::ToSchema; -use crate::api::{ - self, - code::CodeSearch, - event::{Event, EventLogger}, +use super::model; +use crate::{ + api::{ + self, + code::CodeSearch, + event::{Event, EventLogger}, + }, + Device, }; #[derive(Error, Debug)] @@ -166,7 +170,7 @@ pub struct CompletionService { } impl CompletionService { - pub fn new( + fn new( engine: Arc, code: Arc, logger: Arc, @@ -260,3 +264,20 @@ impl CompletionService { )) } } + +pub async fn create_completion_service( + code: Arc, + logger: Arc, + model: &str, + device: &Device, + parallelism: u8, +) -> CompletionService { + let ( + engine, + model::PromptInfo { + prompt_template, .. + }, + ) = model::load_text_generation(model, device, parallelism).await; + + CompletionService::new(engine.clone(), code, logger, prompt_template) +} diff --git a/crates/tabby/src/services/event.rs b/crates/tabby/src/services/event.rs index 9c77391..380f2d7 100644 --- a/crates/tabby/src/services/event.rs +++ b/crates/tabby/src/services/event.rs @@ -85,6 +85,16 @@ fn timestamp() -> u128 { .as_millis() } -pub fn create_event_logger() -> impl EventLogger { +pub fn create_logger() -> impl EventLogger { EventService } + +struct NullLogger; + +impl EventLogger for NullLogger { + fn log(&self, _e: &Event) {} +} + +pub fn create_null_logger() -> impl EventLogger { + NullLogger +} diff --git a/crates/tabby/src/services/health.rs b/crates/tabby/src/services/health.rs index 7db4cf4..a55db99 100644 --- a/crates/tabby/src/services/health.rs +++ b/crates/tabby/src/services/health.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use sysinfo::{CpuExt, System, SystemExt}; use utoipa::ToSchema; -use crate::serve::Device; +use crate::Device; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct HealthState { @@ -43,7 +43,7 @@ impl HealthState { } } -fn read_cpu_info() -> (String, usize) { +pub fn read_cpu_info() -> (String, usize) { let mut system = System::new_all(); system.refresh_cpu(); let cpus = system.cpus(); @@ -58,7 +58,7 @@ fn read_cpu_info() -> (String, usize) { (info, count) } -fn read_cuda_devices() -> Result> { +pub fn read_cuda_devices() -> Result> { // In cases of MacOS or docker containers where --gpus are not specified, // the Nvml::init() would return an error. In these scenarios, we // assign cuda_devices to be empty, indicating that the current runtime diff --git a/crates/tabby/src/services/model.rs b/crates/tabby/src/services/model.rs index 21f9545..8b35681 100644 --- a/crates/tabby/src/services/model.rs +++ b/crates/tabby/src/services/model.rs @@ -2,9 +2,11 @@ use std::{fs, path::PathBuf, sync::Arc}; use serde::Deserialize; use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}; +use tabby_download::download_model; use tabby_inference::TextGeneration; +use tracing::info; -use crate::{fatal, serve::Device}; +use crate::{fatal, Device}; pub async fn load_text_generation( model_id: &str, @@ -72,3 +74,11 @@ fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> imp llama_cpp_bindings::LlamaTextGeneration::new(options) } + +pub async fn download_model_if_needed(model: &str) { + if fs::metadata(model).is_ok() { + info!("Loading model from local path {}", model); + } else { + download_model(model, true).await; + } +} diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs new file mode 100644 index 0000000..d034434 --- /dev/null +++ b/crates/tabby/src/worker.rs @@ -0,0 +1,146 @@ +use std::{ + env::consts::ARCH, + net::{Ipv4Addr, SocketAddr}, + sync::Arc, +}; + +use axum::{routing, Router}; +use clap::Args; +use graphql_client::{reqwest::post_graphql, GraphQLQuery}; +use hyper::Server; +use tracing::{info, warn}; + +use crate::{ + fatal, routes, + services::{ + chat::create_chat_service, + code, + completion::create_completion_service, + event::{self}, + health::{read_cpu_info, read_cuda_devices}, + model::download_model_if_needed, + }, + Device, +}; + +#[derive(GraphQLQuery)] +#[graphql( + schema_path = "../../ee/tabby-webserver/graphql/schema.graphql", + query_path = "./graphql/worker.query.graphql" +)] +struct RegisterWorker; + +#[derive(Args)] +pub struct WorkerArgs { + /// URL to register this worker. + #[clap(long)] + url: String, + + #[clap(long, default_value_t = 8080)] + port: u16, + + /// Model id + #[clap(long, help_heading=Some("Model Options"))] + model: String, + + /// Device to run model inference. + #[clap(long, default_value_t=Device::Cpu, help_heading=Some("Model Options"))] + device: Device, + + /// Parallelism for model serving - increasing this number will have a significant impact on the + /// memory requirement e.g., GPU vRAM. + #[clap(long, default_value_t = 1, help_heading=Some("Model Options"))] + parallelism: u8, +} + +async fn make_chat_route(args: &WorkerArgs) -> Router { + let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await); + + request_register(register_worker::WorkerKind::CHAT, args).await; + + Router::new().route( + "/v1beta/chat/completions", + routing::post(routes::chat_completions).with_state(state), + ) +} + +async fn make_completion_route(args: &WorkerArgs) -> Router { + let code = Arc::new(code::create_code_search()); + let logger = Arc::new(event::create_null_logger()); + let state = Arc::new( + create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await, + ); + + request_register(register_worker::WorkerKind::COMPLETION, args).await; + + Router::new().route( + "/v1/completions", + routing::post(routes::completions).with_state(state), + ) +} + +pub enum WorkerKind { + Chat, + Completion, +} + +pub async fn main(kind: WorkerKind, args: &WorkerArgs) { + download_model_if_needed(&args.model).await; + + info!("Starting worker, this might takes a few minutes..."); + + let app = match kind { + WorkerKind::Completion => make_completion_route(args).await, + WorkerKind::Chat => make_chat_route(args).await, + }; + + let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); + info!("Listening at {}", address); + + Server::bind(&address) + .serve(app.into_make_service()) + .await + .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) +} + +async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) { + request_register_impl( + kind, + args.url.clone(), + args.port as i64, + args.model.to_owned(), + args.device.to_string(), + ) + .await; +} + +async fn request_register_impl( + kind: register_worker::WorkerKind, + url: String, + port: i64, + name: String, + device: String, +) { + let client = reqwest::Client::new(); + let (cpu_info, cpu_count) = read_cpu_info(); + let cuda_devices = read_cuda_devices().unwrap_or_default(); + let variables = register_worker::Variables { + port, + kind, + name, + device, + arch: ARCH.to_string(), + cpu_info, + cpu_count: cpu_count as i64, + cuda_devices, + }; + + let url = format!("{}/graphql", url); + match post_graphql::(&client, &url, variables).await { + Ok(x) => { + let addr = x.data.unwrap().worker.addr; + info!("Worker alive at {}", addr); + } + Err(err) => warn!("Failed to register worker: {}", err), + } +} diff --git a/ee/tabby-webserver/examples/update-schema.rs b/ee/tabby-webserver/examples/update-schema.rs new file mode 100644 index 0000000..5ecce9f --- /dev/null +++ b/ee/tabby-webserver/examples/update-schema.rs @@ -0,0 +1,13 @@ +use std::fs::write; + +use juniper::EmptySubscription; +use tabby_webserver::schema::{Mutation, Query, Schema}; + +fn main() { + let schema = Schema::new(Query, Mutation, EmptySubscription::new()); + write( + "ee/tabby-webserver/graphql/schema.graphql", + schema.as_schema_language(), + ) + .unwrap(); +} diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql new file mode 100644 index 0000000..aa836eb --- /dev/null +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -0,0 +1,28 @@ +enum WorkerKind { + COMPLETION + CHAT +} + +type Mutation { + registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker! +} + +type Query { + workers: [Worker!]! +} + +type Worker { + kind: WorkerKind! + name: String! + addr: String! + device: String! + arch: String! + cpuInfo: String! + cpuCount: Int! + cudaDevices: [String!]! +} + +schema { + query: Query + mutation: Mutation +} diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 1a77c89..7f5dd63 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -1,4 +1,4 @@ -mod schema; +pub mod schema; mod ui; mod webserver; mod worker; diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs index d4f9956..5f130df 100644 --- a/ee/tabby-webserver/src/schema.rs +++ b/ee/tabby-webserver/src/schema.rs @@ -30,14 +30,14 @@ pub enum WorkerKind { #[derive(GraphQLObject, Clone, Debug)] pub struct Worker { - kind: WorkerKind, - addr: String, -} - -impl Worker { - pub fn new(kind: WorkerKind, addr: String) -> Self { - Self { kind, addr } - } + pub kind: WorkerKind, + pub name: String, + pub addr: String, + pub device: String, + pub arch: String, + pub cpu_info: String, + pub cpu_count: i32, + pub cuda_devices: Vec, } #[derive(Default)] @@ -56,13 +56,27 @@ pub struct Mutation; impl Mutation { async fn register_worker( request: &Request, - token: String, - kind: WorkerKind, port: i32, + kind: WorkerKind, + name: String, + device: String, + arch: String, + cpu_info: String, + cpu_count: i32, + cuda_devices: Vec, ) -> Result { let ws = &request.ws; - ws.register_worker(token, request.client_addr, kind, port) - .await + let worker = Worker { + name, + kind, + addr: format!("http://{}:{}", request.client_addr.ip(), port), + device, + arch, + cpu_info, + cpu_count, + cuda_devices, + }; + ws.register_worker(worker).await } } diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs index 2cb0a56..67576dd 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/webserver.rs @@ -28,47 +28,26 @@ pub struct Webserver { chat: worker::WorkerGroup, } -// FIXME: generate token and support refreshing in database. -static WORKER_TOKEN: &str = "4c749fad-2be7-45a3-849e-7714ccade382"; - impl Webserver { - pub async fn register_worker( - &self, - token: String, - client_addr: SocketAddr, - kind: WorkerKind, - port: i32, - ) -> Result { - if token != WORKER_TOKEN { - return Err(WebserverError::InvalidToken(token)); - } - - let addr = SocketAddr::new(client_addr.ip(), port as u16); - let addr = match kind { - WorkerKind::Completion => self.completion.register(addr).await, - WorkerKind::Chat => self.chat.register(addr).await, + pub async fn register_worker(&self, worker: Worker) -> Result { + let worker = match worker.kind { + WorkerKind::Completion => self.completion.register(worker).await, + WorkerKind::Chat => self.chat.register(worker).await, }; - if let Some(addr) = addr { - info!("registering <{:?}> worker running at {}", kind, addr); - Ok(Worker::new(kind, addr)) + if let Some(worker) = worker { + info!( + "registering <{:?}> worker running at {}", + worker.kind, worker.addr + ); + Ok(worker) } else { Err(WebserverError::RequiresEnterpriseLicense) } } pub async fn list_workers(&self) -> Vec { - let make_workers = |x: WorkerKind, lst: Vec| -> Vec { - lst.into_iter() - .map(|addr| Worker::new(x.clone(), addr)) - .collect() - }; - - [ - make_workers(WorkerKind::Completion, self.completion.list().await), - make_workers(WorkerKind::Chat, self.chat.list().await), - ] - .concat() + [self.completion.list().await, self.chat.list().await].concat() } pub async fn dispatch_request( diff --git a/ee/tabby-webserver/src/worker.rs b/ee/tabby-webserver/src/worker.rs index b6c7f82..53eb755 100644 --- a/ee/tabby-webserver/src/worker.rs +++ b/ee/tabby-webserver/src/worker.rs @@ -1,42 +1,41 @@ -use std::{ - net::SocketAddr, - time::{SystemTime, UNIX_EPOCH}, -}; +use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use tracing::error; +use crate::schema::Worker; + #[derive(Default)] pub struct WorkerGroup { - workers: RwLock>, + workers: RwLock>, } impl WorkerGroup { pub async fn select(&self) -> Option { let workers = self.workers.read().await; if workers.len() > 0 { - Some(workers[random_index(workers.len())].clone()) + Some(workers[random_index(workers.len())].addr.clone()) } else { None } } - pub async fn list(&self) -> Vec { + pub async fn list(&self) -> Vec { self.workers.read().await.clone() } - pub async fn register(&self, addr: SocketAddr) -> Option { - let addr = format!("http://{}", addr); + pub async fn register(&self, worker: Worker) -> Option { let mut workers = self.workers.write().await; if workers.len() >= 1 { error!("You need enterprise license to utilize more than 1 workers, please contact hi@tabbyml.com for information."); return None; } - if !workers.contains(&addr) { - workers.push(addr.clone()); + if workers.iter().all(|x| x.addr != worker.addr) { + workers.push(worker.clone()); } - Some(addr) + + Some(worker) } } @@ -50,23 +49,38 @@ fn random_index(size: usize) -> usize { #[cfg(test)] mod tests { + use super::*; + use crate::schema::WorkerKind; #[tokio::test] async fn test_worker_group() { let wg = WorkerGroup::default(); - let addr1 = "127.0.0.1:8080".parse().unwrap(); - let addr2 = "127.0.0.2:8080".parse().unwrap(); + let worker1 = make_worker("http://127.0.0.1:8080"); + let worker2 = make_worker("http://127.0.0.2:8080"); // Register success. - assert!(wg.register(addr1).await.is_some()); + assert!(wg.register(worker1.clone()).await.is_some()); // Register failed, as > 1 workers requires enterprise license. - assert!(wg.register(addr2).await.is_none()); + assert!(wg.register(worker2).await.is_none()); let workers = wg.list().await; assert_eq!(workers.len(), 1); - assert_eq!(workers[0], format!("http://{}", addr1)); + assert_eq!(workers[0].addr, worker1.addr); + } + + fn make_worker(addr: &str) -> Worker { + Worker { + name: "Fake worker".to_owned(), + kind: WorkerKind::Chat, + addr: addr.to_owned(), + device: "cuda".to_owned(), + arch: "x86_64".to_owned(), + cpu_info: "Fake CPU".to_owned(), + cpu_count: 32, + cuda_devices: vec![], + } } }