feat: add worker command worker::completion and worker::chat (#778)
parent
510eddca89
commit
e521f0637c
|
|
@ -1711,6 +1711,15 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-introspection-query"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f2a4732cf5140bd6c082434494f785a19cfb566ab07d1382c3671f5812fed6d"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-parser"
|
||||
version = "0.3.0"
|
||||
|
|
@ -1721,6 +1730,56 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-parser"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474"
|
||||
dependencies = [
|
||||
"combine",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql_client"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cdf7b487d864c2939b23902291a5041bc4a84418268f25fda1c8d4e15ad8fa"
|
||||
dependencies = [
|
||||
"graphql_query_derive",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql_client_codegen"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a40f793251171991c4eb75bd84bc640afa8b68ff6907bc89d3b712a22f700506"
|
||||
dependencies = [
|
||||
"graphql-introspection-query",
|
||||
"graphql-parser 0.4.0",
|
||||
"heck 0.4.1",
|
||||
"lazy_static",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql_query_derive"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00bda454f3d313f909298f626115092d348bc231025699f557b27e248475f48c"
|
||||
dependencies = [
|
||||
"graphql_client_codegen",
|
||||
"proc-macro2",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.19"
|
||||
|
|
@ -2148,7 +2207,7 @@ dependencies = [
|
|||
"fnv",
|
||||
"futures",
|
||||
"futures-enum",
|
||||
"graphql-parser",
|
||||
"graphql-parser 0.3.0",
|
||||
"indexmap 1.9.3",
|
||||
"juniper_codegen",
|
||||
"serde",
|
||||
|
|
@ -4285,6 +4344,7 @@ dependencies = [
|
|||
"chrono",
|
||||
"clap 4.4.7",
|
||||
"futures",
|
||||
"graphql_client",
|
||||
"http-api-bindings",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
|
|
|
|||
3
Makefile
3
Makefile
|
|
@ -35,3 +35,6 @@ update-openapi-doc:
|
|||
["components", "schemas", "DebugOptions"] \
|
||||
])' | jq '.servers[0] |= { url: "https://playground.app.tabbyml.com", description: "Playground server" }' \
|
||||
> website/static/openapi.json
|
||||
|
||||
update-graphql-schema:
|
||||
cargo run --package tabby-webserver --example update-schema
|
||||
|
|
@ -47,6 +47,8 @@ async-trait.workspace = true
|
|||
tabby-webserver = { path = "../../ee/tabby-webserver" }
|
||||
thiserror.workspace = true
|
||||
chrono = "0.4.31"
|
||||
graphql_client = { version = "0.13.0", features = ["reqwest"] }
|
||||
reqwest.workspace = true
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.3.3"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
mutation RegisterWorker(
|
||||
$port: Int!
|
||||
$kind: WorkerKind!
|
||||
$name: String!
|
||||
$device: String!
|
||||
$arch: String!
|
||||
$cpuInfo: String!
|
||||
$cpuCount: Int!
|
||||
$cudaDevices: [String!]!
|
||||
) {
|
||||
worker: registerWorker(
|
||||
port: $port
|
||||
kind: $kind
|
||||
name: $name
|
||||
device: $device
|
||||
arch: $arch
|
||||
cpuInfo: $cpuInfo
|
||||
cpuCount: $cpuCount
|
||||
cudaDevices: $cudaDevices
|
||||
) {
|
||||
addr
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
mod api;
|
||||
mod download;
|
||||
mod routes;
|
||||
mod serve;
|
||||
mod services;
|
||||
|
||||
mod download;
|
||||
mod serve;
|
||||
mod worker;
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
use opentelemetry::{
|
||||
global,
|
||||
|
|
@ -36,6 +38,16 @@ pub enum Commands {
|
|||
|
||||
/// Run scheduler progress for cron jobs integrating external code repositories.
|
||||
Scheduler(SchedulerArgs),
|
||||
|
||||
/// Run completion model as worker
|
||||
#[clap(name = "worker::completion")]
|
||||
#[command(arg_required_else_help = true)]
|
||||
WorkerCompletion(worker::WorkerArgs),
|
||||
|
||||
/// Run chat model as worker
|
||||
#[clap(name = "worker::chat")]
|
||||
#[command(arg_required_else_help = true)]
|
||||
WorkerChat(worker::WorkerArgs),
|
||||
}
|
||||
|
||||
#[derive(clap::Args)]
|
||||
|
|
@ -45,6 +57,41 @@ pub struct SchedulerArgs {
|
|||
now: bool,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||
pub enum Device {
|
||||
#[strum(serialize = "cpu")]
|
||||
Cpu,
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
#[strum(serialize = "cuda")]
|
||||
Cuda,
|
||||
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
#[strum(serialize = "metal")]
|
||||
Metal,
|
||||
|
||||
#[cfg(feature = "experimental-http")]
|
||||
#[strum(serialize = "experimental_http")]
|
||||
ExperimentalHttp,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
*self == Device::Metal
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
*self == Device::Cuda
|
||||
}
|
||||
|
||||
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let cli = Cli::parse();
|
||||
|
|
@ -58,6 +105,10 @@ async fn main() {
|
|||
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
||||
.await
|
||||
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
|
||||
Commands::WorkerCompletion(args) => {
|
||||
worker::main(worker::WorkerKind::Completion, args).await
|
||||
}
|
||||
Commands::WorkerChat(args) => worker::main(worker::WorkerKind::Chat, args).await,
|
||||
}
|
||||
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
use std::{
|
||||
fs,
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
|
|
@ -9,7 +8,6 @@ use axum::{routing, Router, Server};
|
|||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use clap::Args;
|
||||
use tabby_common::{config::Config, usage};
|
||||
use tabby_download::download_model;
|
||||
use tabby_webserver::attach_webserver;
|
||||
use tokio::time::sleep;
|
||||
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
|
||||
|
|
@ -20,7 +18,14 @@ use utoipa_swagger_ui::SwaggerUi;
|
|||
use crate::{
|
||||
api::{self},
|
||||
fatal, routes,
|
||||
services::{chat, completion, event::create_event_logger, health, model},
|
||||
services::{
|
||||
chat::{self, create_chat_service},
|
||||
completion::{self, create_completion_service},
|
||||
event::create_logger,
|
||||
health,
|
||||
model::download_model_if_needed,
|
||||
},
|
||||
Device,
|
||||
};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
|
|
@ -62,41 +67,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
|
||||
pub enum Device {
|
||||
#[strum(serialize = "cpu")]
|
||||
Cpu,
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
#[strum(serialize = "cuda")]
|
||||
Cuda,
|
||||
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
#[strum(serialize = "metal")]
|
||||
Metal,
|
||||
|
||||
#[cfg(feature = "experimental-http")]
|
||||
#[strum(serialize = "experimental_http")]
|
||||
ExperimentalHttp,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
*self == Device::Metal
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
*self == Device::Cuda
|
||||
}
|
||||
|
||||
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
|
||||
pub fn ggml_use_gpu(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct ServeArgs {
|
||||
/// Model id for `/completions` API endpoint.
|
||||
|
|
@ -152,43 +122,30 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
}
|
||||
|
||||
async fn load_model(args: &ServeArgs) {
|
||||
if fs::metadata(&args.model).is_ok() {
|
||||
info!("Loading model from local path {}", &args.model);
|
||||
} else {
|
||||
download_model(&args.model, true).await;
|
||||
if let Some(chat_model) = &args.chat_model {
|
||||
download_model(chat_model, true).await;
|
||||
}
|
||||
download_model_if_needed(&args.model).await;
|
||||
if let Some(chat_model) = &args.chat_model {
|
||||
download_model_if_needed(chat_model).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
let logger = Arc::new(create_event_logger());
|
||||
let logger = Arc::new(create_logger());
|
||||
let code = Arc::new(crate::services::code::create_code_search());
|
||||
let completion_state = {
|
||||
let (
|
||||
engine,
|
||||
model::PromptInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
|
||||
let state = completion::CompletionService::new(
|
||||
engine.clone(),
|
||||
let completion = Arc::new(
|
||||
create_completion_service(
|
||||
code.clone(),
|
||||
logger.clone(),
|
||||
prompt_template,
|
||||
);
|
||||
Arc::new(state)
|
||||
};
|
||||
&args.model,
|
||||
&args.device,
|
||||
args.parallelism,
|
||||
)
|
||||
.await,
|
||||
);
|
||||
|
||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||
let (engine, model::PromptInfo { chat_template, .. }) =
|
||||
model::load_text_generation(chat_model, &args.device, args.parallelism).await;
|
||||
let Some(chat_template) = chat_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
let state = chat::ChatService::new(engine, chat_template);
|
||||
Some(Arc::new(state))
|
||||
let chat_state = if let Some(_chat_model) = &args.chat_model {
|
||||
Some(Arc::new(
|
||||
create_chat_service(&args.model, &args.device, args.parallelism).await,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
|
@ -220,7 +177,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
Router::new()
|
||||
.route(
|
||||
"/v1/completions",
|
||||
routing::post(routes::completions).with_state(completion_state),
|
||||
routing::post(routes::completions).with_state(completion),
|
||||
)
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(
|
||||
config.server.completion_timeout,
|
||||
|
|
@ -11,6 +11,9 @@ use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptio
|
|||
use tracing::debug;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use super::model;
|
||||
use crate::{fatal, Device};
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"messages": [
|
||||
|
|
@ -40,7 +43,7 @@ pub struct ChatService {
|
|||
}
|
||||
|
||||
impl ChatService {
|
||||
pub fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
|
||||
fn new(engine: Arc<dyn TextGeneration>, chat_template: String) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: ChatPromptBuilder::new(chat_template),
|
||||
|
|
@ -73,3 +76,14 @@ impl ChatService {
|
|||
Box::pin(s)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
|
||||
let (engine, model::PromptInfo { chat_template, .. }) =
|
||||
model::load_text_generation(model, device, parallelism).await;
|
||||
|
||||
let Some(chat_template) = chat_template else {
|
||||
fatal!("Chat model requires specifying prompt template");
|
||||
};
|
||||
|
||||
ChatService::new(engine, chat_template)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,10 +9,14 @@ use thiserror::Error;
|
|||
use tracing::debug;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::api::{
|
||||
self,
|
||||
code::CodeSearch,
|
||||
event::{Event, EventLogger},
|
||||
use super::model;
|
||||
use crate::{
|
||||
api::{
|
||||
self,
|
||||
code::CodeSearch,
|
||||
event::{Event, EventLogger},
|
||||
},
|
||||
Device,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
|
@ -166,7 +170,7 @@ pub struct CompletionService {
|
|||
}
|
||||
|
||||
impl CompletionService {
|
||||
pub fn new(
|
||||
fn new(
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
logger: Arc<dyn EventLogger>,
|
||||
|
|
@ -260,3 +264,20 @@ impl CompletionService {
|
|||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_completion_service(
|
||||
code: Arc<dyn CodeSearch>,
|
||||
logger: Arc<dyn EventLogger>,
|
||||
model: &str,
|
||||
device: &Device,
|
||||
parallelism: u8,
|
||||
) -> CompletionService {
|
||||
let (
|
||||
engine,
|
||||
model::PromptInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = model::load_text_generation(model, device, parallelism).await;
|
||||
|
||||
CompletionService::new(engine.clone(), code, logger, prompt_template)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,6 +85,16 @@ fn timestamp() -> u128 {
|
|||
.as_millis()
|
||||
}
|
||||
|
||||
pub fn create_event_logger() -> impl EventLogger {
|
||||
pub fn create_logger() -> impl EventLogger {
|
||||
EventService
|
||||
}
|
||||
|
||||
struct NullLogger;
|
||||
|
||||
impl EventLogger for NullLogger {
|
||||
fn log(&self, _e: &Event) {}
|
||||
}
|
||||
|
||||
pub fn create_null_logger() -> impl EventLogger {
|
||||
NullLogger
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
|
|||
use sysinfo::{CpuExt, System, SystemExt};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::serve::Device;
|
||||
use crate::Device;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct HealthState {
|
||||
|
|
@ -43,7 +43,7 @@ impl HealthState {
|
|||
}
|
||||
}
|
||||
|
||||
fn read_cpu_info() -> (String, usize) {
|
||||
pub fn read_cpu_info() -> (String, usize) {
|
||||
let mut system = System::new_all();
|
||||
system.refresh_cpu();
|
||||
let cpus = system.cpus();
|
||||
|
|
@ -58,7 +58,7 @@ fn read_cpu_info() -> (String, usize) {
|
|||
(info, count)
|
||||
}
|
||||
|
||||
fn read_cuda_devices() -> Result<Vec<String>> {
|
||||
pub fn read_cuda_devices() -> Result<Vec<String>> {
|
||||
// In cases of MacOS or docker containers where --gpus are not specified,
|
||||
// the Nvml::init() would return an error. In these scenarios, we
|
||||
// assign cuda_devices to be empty, indicating that the current runtime
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ use std::{fs, path::PathBuf, sync::Arc};
|
|||
|
||||
use serde::Deserialize;
|
||||
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
|
||||
use tabby_download::download_model;
|
||||
use tabby_inference::TextGeneration;
|
||||
use tracing::info;
|
||||
|
||||
use crate::{fatal, serve::Device};
|
||||
use crate::{fatal, Device};
|
||||
|
||||
pub async fn load_text_generation(
|
||||
model_id: &str,
|
||||
|
|
@ -72,3 +74,11 @@ fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> imp
|
|||
|
||||
llama_cpp_bindings::LlamaTextGeneration::new(options)
|
||||
}
|
||||
|
||||
pub async fn download_model_if_needed(model: &str) {
|
||||
if fs::metadata(model).is_ok() {
|
||||
info!("Loading model from local path {}", model);
|
||||
} else {
|
||||
download_model(model, true).await;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,146 @@
|
|||
use std::{
|
||||
env::consts::ARCH,
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::{routing, Router};
|
||||
use clap::Args;
|
||||
use graphql_client::{reqwest::post_graphql, GraphQLQuery};
|
||||
use hyper::Server;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
fatal, routes,
|
||||
services::{
|
||||
chat::create_chat_service,
|
||||
code,
|
||||
completion::create_completion_service,
|
||||
event::{self},
|
||||
health::{read_cpu_info, read_cuda_devices},
|
||||
model::download_model_if_needed,
|
||||
},
|
||||
Device,
|
||||
};
|
||||
|
||||
#[derive(GraphQLQuery)]
|
||||
#[graphql(
|
||||
schema_path = "../../ee/tabby-webserver/graphql/schema.graphql",
|
||||
query_path = "./graphql/worker.query.graphql"
|
||||
)]
|
||||
struct RegisterWorker;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct WorkerArgs {
|
||||
/// URL to register this worker.
|
||||
#[clap(long)]
|
||||
url: String,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
/// Model id
|
||||
#[clap(long, help_heading=Some("Model Options"))]
|
||||
model: String,
|
||||
|
||||
/// Device to run model inference.
|
||||
#[clap(long, default_value_t=Device::Cpu, help_heading=Some("Model Options"))]
|
||||
device: Device,
|
||||
|
||||
/// Parallelism for model serving - increasing this number will have a significant impact on the
|
||||
/// memory requirement e.g., GPU vRAM.
|
||||
#[clap(long, default_value_t = 1, help_heading=Some("Model Options"))]
|
||||
parallelism: u8,
|
||||
}
|
||||
|
||||
async fn make_chat_route(args: &WorkerArgs) -> Router {
|
||||
let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await);
|
||||
|
||||
request_register(register_worker::WorkerKind::CHAT, args).await;
|
||||
|
||||
Router::new().route(
|
||||
"/v1beta/chat/completions",
|
||||
routing::post(routes::chat_completions).with_state(state),
|
||||
)
|
||||
}
|
||||
|
||||
async fn make_completion_route(args: &WorkerArgs) -> Router {
|
||||
let code = Arc::new(code::create_code_search());
|
||||
let logger = Arc::new(event::create_null_logger());
|
||||
let state = Arc::new(
|
||||
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
|
||||
);
|
||||
|
||||
request_register(register_worker::WorkerKind::COMPLETION, args).await;
|
||||
|
||||
Router::new().route(
|
||||
"/v1/completions",
|
||||
routing::post(routes::completions).with_state(state),
|
||||
)
|
||||
}
|
||||
|
||||
pub enum WorkerKind {
|
||||
Chat,
|
||||
Completion,
|
||||
}
|
||||
|
||||
pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
|
||||
download_model_if_needed(&args.model).await;
|
||||
|
||||
info!("Starting worker, this might takes a few minutes...");
|
||||
|
||||
let app = match kind {
|
||||
WorkerKind::Completion => make_completion_route(args).await,
|
||||
WorkerKind::Chat => make_chat_route(args).await,
|
||||
};
|
||||
|
||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||
info!("Listening at {}", address);
|
||||
|
||||
Server::bind(&address)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||
}
|
||||
|
||||
async fn request_register(kind: register_worker::WorkerKind, args: &WorkerArgs) {
|
||||
request_register_impl(
|
||||
kind,
|
||||
args.url.clone(),
|
||||
args.port as i64,
|
||||
args.model.to_owned(),
|
||||
args.device.to_string(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn request_register_impl(
|
||||
kind: register_worker::WorkerKind,
|
||||
url: String,
|
||||
port: i64,
|
||||
name: String,
|
||||
device: String,
|
||||
) {
|
||||
let client = reqwest::Client::new();
|
||||
let (cpu_info, cpu_count) = read_cpu_info();
|
||||
let cuda_devices = read_cuda_devices().unwrap_or_default();
|
||||
let variables = register_worker::Variables {
|
||||
port,
|
||||
kind,
|
||||
name,
|
||||
device,
|
||||
arch: ARCH.to_string(),
|
||||
cpu_info,
|
||||
cpu_count: cpu_count as i64,
|
||||
cuda_devices,
|
||||
};
|
||||
|
||||
let url = format!("{}/graphql", url);
|
||||
match post_graphql::<RegisterWorker, _>(&client, &url, variables).await {
|
||||
Ok(x) => {
|
||||
let addr = x.data.unwrap().worker.addr;
|
||||
info!("Worker alive at {}", addr);
|
||||
}
|
||||
Err(err) => warn!("Failed to register worker: {}", err),
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
use std::fs::write;
|
||||
|
||||
use juniper::EmptySubscription;
|
||||
use tabby_webserver::schema::{Mutation, Query, Schema};
|
||||
|
||||
fn main() {
|
||||
let schema = Schema::new(Query, Mutation, EmptySubscription::new());
|
||||
write(
|
||||
"ee/tabby-webserver/graphql/schema.graphql",
|
||||
schema.as_schema_language(),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
enum WorkerKind {
|
||||
COMPLETION
|
||||
CHAT
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
registerWorker(port: Int!, kind: WorkerKind!, name: String!, device: String!, arch: String!, cpuInfo: String!, cpuCount: Int!, cudaDevices: [String!]!): Worker!
|
||||
}
|
||||
|
||||
type Query {
|
||||
workers: [Worker!]!
|
||||
}
|
||||
|
||||
type Worker {
|
||||
kind: WorkerKind!
|
||||
name: String!
|
||||
addr: String!
|
||||
device: String!
|
||||
arch: String!
|
||||
cpuInfo: String!
|
||||
cpuCount: Int!
|
||||
cudaDevices: [String!]!
|
||||
}
|
||||
|
||||
schema {
|
||||
query: Query
|
||||
mutation: Mutation
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
mod schema;
|
||||
pub mod schema;
|
||||
mod ui;
|
||||
mod webserver;
|
||||
mod worker;
|
||||
|
|
|
|||
|
|
@ -30,14 +30,14 @@ pub enum WorkerKind {
|
|||
|
||||
#[derive(GraphQLObject, Clone, Debug)]
|
||||
pub struct Worker {
|
||||
kind: WorkerKind,
|
||||
addr: String,
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
pub fn new(kind: WorkerKind, addr: String) -> Self {
|
||||
Self { kind, addr }
|
||||
}
|
||||
pub kind: WorkerKind,
|
||||
pub name: String,
|
||||
pub addr: String,
|
||||
pub device: String,
|
||||
pub arch: String,
|
||||
pub cpu_info: String,
|
||||
pub cpu_count: i32,
|
||||
pub cuda_devices: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
|
@ -56,13 +56,27 @@ pub struct Mutation;
|
|||
impl Mutation {
|
||||
async fn register_worker(
|
||||
request: &Request,
|
||||
token: String,
|
||||
kind: WorkerKind,
|
||||
port: i32,
|
||||
kind: WorkerKind,
|
||||
name: String,
|
||||
device: String,
|
||||
arch: String,
|
||||
cpu_info: String,
|
||||
cpu_count: i32,
|
||||
cuda_devices: Vec<String>,
|
||||
) -> Result<Worker, WebserverError> {
|
||||
let ws = &request.ws;
|
||||
ws.register_worker(token, request.client_addr, kind, port)
|
||||
.await
|
||||
let worker = Worker {
|
||||
name,
|
||||
kind,
|
||||
addr: format!("http://{}:{}", request.client_addr.ip(), port),
|
||||
device,
|
||||
arch,
|
||||
cpu_info,
|
||||
cpu_count,
|
||||
cuda_devices,
|
||||
};
|
||||
ws.register_worker(worker).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -28,47 +28,26 @@ pub struct Webserver {
|
|||
chat: worker::WorkerGroup,
|
||||
}
|
||||
|
||||
// FIXME: generate token and support refreshing in database.
|
||||
static WORKER_TOKEN: &str = "4c749fad-2be7-45a3-849e-7714ccade382";
|
||||
|
||||
impl Webserver {
|
||||
pub async fn register_worker(
|
||||
&self,
|
||||
token: String,
|
||||
client_addr: SocketAddr,
|
||||
kind: WorkerKind,
|
||||
port: i32,
|
||||
) -> Result<Worker, WebserverError> {
|
||||
if token != WORKER_TOKEN {
|
||||
return Err(WebserverError::InvalidToken(token));
|
||||
}
|
||||
|
||||
let addr = SocketAddr::new(client_addr.ip(), port as u16);
|
||||
let addr = match kind {
|
||||
WorkerKind::Completion => self.completion.register(addr).await,
|
||||
WorkerKind::Chat => self.chat.register(addr).await,
|
||||
pub async fn register_worker(&self, worker: Worker) -> Result<Worker, WebserverError> {
|
||||
let worker = match worker.kind {
|
||||
WorkerKind::Completion => self.completion.register(worker).await,
|
||||
WorkerKind::Chat => self.chat.register(worker).await,
|
||||
};
|
||||
|
||||
if let Some(addr) = addr {
|
||||
info!("registering <{:?}> worker running at {}", kind, addr);
|
||||
Ok(Worker::new(kind, addr))
|
||||
if let Some(worker) = worker {
|
||||
info!(
|
||||
"registering <{:?}> worker running at {}",
|
||||
worker.kind, worker.addr
|
||||
);
|
||||
Ok(worker)
|
||||
} else {
|
||||
Err(WebserverError::RequiresEnterpriseLicense)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_workers(&self) -> Vec<Worker> {
|
||||
let make_workers = |x: WorkerKind, lst: Vec<String>| -> Vec<Worker> {
|
||||
lst.into_iter()
|
||||
.map(|addr| Worker::new(x.clone(), addr))
|
||||
.collect()
|
||||
};
|
||||
|
||||
[
|
||||
make_workers(WorkerKind::Completion, self.completion.list().await),
|
||||
make_workers(WorkerKind::Chat, self.chat.list().await),
|
||||
]
|
||||
.concat()
|
||||
[self.completion.list().await, self.chat.list().await].concat()
|
||||
}
|
||||
|
||||
pub async fn dispatch_request(
|
||||
|
|
|
|||
|
|
@ -1,42 +1,41 @@
|
|||
use std::{
|
||||
net::SocketAddr,
|
||||
time::{SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::error;
|
||||
|
||||
use crate::schema::Worker;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct WorkerGroup {
|
||||
workers: RwLock<Vec<String>>,
|
||||
workers: RwLock<Vec<Worker>>,
|
||||
}
|
||||
|
||||
impl WorkerGroup {
|
||||
pub async fn select(&self) -> Option<String> {
|
||||
let workers = self.workers.read().await;
|
||||
if workers.len() > 0 {
|
||||
Some(workers[random_index(workers.len())].clone())
|
||||
Some(workers[random_index(workers.len())].addr.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list(&self) -> Vec<String> {
|
||||
pub async fn list(&self) -> Vec<Worker> {
|
||||
self.workers.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn register(&self, addr: SocketAddr) -> Option<String> {
|
||||
let addr = format!("http://{}", addr);
|
||||
pub async fn register(&self, worker: Worker) -> Option<Worker> {
|
||||
let mut workers = self.workers.write().await;
|
||||
if workers.len() >= 1 {
|
||||
error!("You need enterprise license to utilize more than 1 workers, please contact hi@tabbyml.com for information.");
|
||||
return None;
|
||||
}
|
||||
|
||||
if !workers.contains(&addr) {
|
||||
workers.push(addr.clone());
|
||||
if workers.iter().all(|x| x.addr != worker.addr) {
|
||||
workers.push(worker.clone());
|
||||
}
|
||||
Some(addr)
|
||||
|
||||
Some(worker)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -50,23 +49,38 @@ fn random_index(size: usize) -> usize {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::schema::WorkerKind;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_worker_group() {
|
||||
let wg = WorkerGroup::default();
|
||||
|
||||
let addr1 = "127.0.0.1:8080".parse().unwrap();
|
||||
let addr2 = "127.0.0.2:8080".parse().unwrap();
|
||||
let worker1 = make_worker("http://127.0.0.1:8080");
|
||||
let worker2 = make_worker("http://127.0.0.2:8080");
|
||||
|
||||
// Register success.
|
||||
assert!(wg.register(addr1).await.is_some());
|
||||
assert!(wg.register(worker1.clone()).await.is_some());
|
||||
|
||||
// Register failed, as > 1 workers requires enterprise license.
|
||||
assert!(wg.register(addr2).await.is_none());
|
||||
assert!(wg.register(worker2).await.is_none());
|
||||
|
||||
let workers = wg.list().await;
|
||||
assert_eq!(workers.len(), 1);
|
||||
assert_eq!(workers[0], format!("http://{}", addr1));
|
||||
assert_eq!(workers[0].addr, worker1.addr);
|
||||
}
|
||||
|
||||
fn make_worker(addr: &str) -> Worker {
|
||||
Worker {
|
||||
name: "Fake worker".to_owned(),
|
||||
kind: WorkerKind::Chat,
|
||||
addr: addr.to_owned(),
|
||||
device: "cuda".to_owned(),
|
||||
arch: "x86_64".to_owned(),
|
||||
cpu_info: "Fake CPU".to_owned(),
|
||||
cpu_count: 32,
|
||||
cuda_devices: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue