diff --git a/Cargo.lock b/Cargo.lock index 490c128..48808e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4475,8 +4475,10 @@ dependencies = [ "reqwest", "serde", "serde-jsonlines", + "serde_json", "serdeconv", "tantivy", + "utoipa", "uuid 1.4.1", ] diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml index 9ef3e33..c39af3a 100644 --- a/crates/tabby-common/Cargo.toml +++ b/crates/tabby-common/Cargo.toml @@ -14,6 +14,14 @@ uuid = { version = "1.4.1", features = ["v4"] } tantivy.workspace = true anyhow.workspace = true glob = "0.3.1" +utoipa.workspace = true +serde_json.workspace = true [features] testutils = [] + +[package.metadata.cargo-machete] +ignored = [ + # required in utoipa ToSchema. + "serde_json" +] \ No newline at end of file diff --git a/crates/tabby/src/api/event.rs b/crates/tabby-common/src/api/event.rs similarity index 66% rename from crates/tabby/src/api/event.rs rename to crates/tabby-common/src/api/event.rs index 0eceba3..70fd5e6 100644 --- a/crates/tabby/src/api/event.rs +++ b/crates/tabby-common/src/api/event.rs @@ -56,3 +56,34 @@ pub struct Segments { pub trait EventLogger: Send + Sync { fn log(&self, e: &Event); } + +#[derive(Serialize)] +struct Log<'a> { + ts: u128, + event: &'a Event<'a>, +} + +pub trait RawEventLogger: Send + Sync { + fn log(&self, content: String); +} + +impl EventLogger for T { + fn log(&self, e: &Event) { + let content = serdeconv::to_json_string(&Log { + ts: timestamp(), + event: e, + }) + .unwrap(); + + self.log(content); + } +} + +fn timestamp() -> u128 { + use std::time::{SystemTime, UNIX_EPOCH}; + let start = SystemTime::now(); + start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() +} diff --git a/crates/tabby-common/src/api/mod.rs b/crates/tabby-common/src/api/mod.rs new file mode 100644 index 0000000..53f1126 --- /dev/null +++ b/crates/tabby-common/src/api/mod.rs @@ -0,0 +1 @@ +pub mod event; diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index cce151d..b157cbc 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -1,3 +1,4 @@ +pub mod api; pub mod config; pub mod index; pub mod languages; diff --git a/crates/tabby/src/api/mod.rs b/crates/tabby/src/api/mod.rs index cebf170..9de50d4 100644 --- a/crates/tabby/src/api/mod.rs +++ b/crates/tabby/src/api/mod.rs @@ -1,2 +1 @@ pub mod code; -pub mod event; diff --git a/crates/tabby/src/routes/events.rs b/crates/tabby/src/routes/events.rs index 61e8037..c8f747d 100644 --- a/crates/tabby/src/routes/events.rs +++ b/crates/tabby/src/routes/events.rs @@ -5,8 +5,7 @@ use axum::{ Json, }; use hyper::StatusCode; - -use crate::api::event::{Event, EventLogger, LogEventRequest, SelectKind}; +use tabby_common::api::event::{Event, EventLogger, LogEventRequest, SelectKind}; #[utoipa::path( post, diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index b2a7348..ab15b68 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -7,7 +7,7 @@ use std::{ use axum::{routing, Router, Server}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use clap::Args; -use tabby_common::{config::Config, usage}; +use tabby_common::{api::event::EventLogger, config::Config, usage}; use tokio::time::sleep; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; use tracing::info; @@ -46,7 +46,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi ), paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search), components(schemas( - api::event::LogEventRequest, + tabby_common::api::event::LogEventRequest, completion::CompletionRequest, completion::CompletionResponse, completion::Segments, @@ -101,12 +101,14 @@ pub async fn main(config: &Config, args: &ServeArgs) { info!("Starting server, this might takes a few minutes..."); + let logger = Arc::new(create_logger()); + let app = Router::new() - .merge(api_router(args, config).await) + .merge(api_router(args, config, logger.clone()).await) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())); #[cfg(feature = "ee")] - let app = tabby_webserver::attach_webserver(app).await; + let app = tabby_webserver::attach_webserver(app, logger).await; #[cfg(not(feature = "ee"))] let app = app.fallback(|| async { axum::response::Redirect::permanent("/swagger-ui") }); @@ -131,8 +133,7 @@ async fn load_model(args: &ServeArgs) { } } -async fn api_router(args: &ServeArgs, config: &Config) -> Router { - let logger = Arc::new(create_logger()); +async fn api_router(args: &ServeArgs, config: &Config, logger: Arc) -> Router { let code = Arc::new(crate::services::code::create_code_search()); let completion_state = if let Some(model) = &args.model { diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index a76289b..11aca1c 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -3,21 +3,18 @@ mod completion_prompt; use std::sync::Arc; use serde::{Deserialize, Serialize}; -use tabby_common::languages::get_language; +use tabby_common::{ + api, + api::event::{Event, EventLogger}, + languages::get_language, +}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use thiserror::Error; use tracing::debug; use utoipa::ToSchema; use super::model; -use crate::{ - api::{ - self, - code::CodeSearch, - event::{Event, EventLogger}, - }, - Device, -}; +use crate::{api::code::CodeSearch, Device}; #[derive(Error, Debug)] pub enum CompletionError { diff --git a/crates/tabby/src/services/event.rs b/crates/tabby/src/services/event.rs index 380f2d7..da83e10 100644 --- a/crates/tabby/src/services/event.rs +++ b/crates/tabby/src/services/event.rs @@ -6,15 +6,12 @@ use std::{ use chrono::Utc; use lazy_static::lazy_static; -use serde::Serialize; -use tabby_common::path; +use tabby_common::{api::event::RawEventLogger, path}; use tokio::{ sync::mpsc::{unbounded_channel, UnboundedSender}, time::{self}, }; -use crate::api::event::{Event, EventLogger}; - lazy_static! { static ref WRITER: UnboundedSender = { let (tx, mut rx) = unbounded_channel::(); @@ -58,43 +55,12 @@ lazy_static! { struct EventService; -#[derive(Serialize)] -struct Log<'a> { - ts: u128, - event: &'a Event<'a>, -} - -impl EventLogger for EventService { - fn log(&self, e: &Event) { - let content = serdeconv::to_json_string(&Log { - ts: timestamp(), - event: e, - }) - .unwrap(); - +impl RawEventLogger for EventService { + fn log(&self, content: String) { WRITER.send(content).unwrap(); } } -fn timestamp() -> u128 { - use std::time::{SystemTime, UNIX_EPOCH}; - let start = SystemTime::now(); - start - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis() -} - -pub fn create_logger() -> impl EventLogger { +pub fn create_logger() -> impl RawEventLogger { EventService } - -struct NullLogger; - -impl EventLogger for NullLogger { - fn log(&self, _e: &Event) {} -} - -pub fn create_null_logger() -> impl EventLogger { - NullLogger -} diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs index 0fb390e..4991f78 100644 --- a/crates/tabby/src/worker.rs +++ b/crates/tabby/src/worker.rs @@ -8,7 +8,8 @@ use anyhow::Result; use axum::{routing, Router}; use clap::Args; use hyper::Server; -use tabby_webserver::api::WorkerKind; +use tabby_common::api::event::EventLogger; +use tabby_webserver::api::{tracing_context, HubClient, WorkerKind}; use tracing::{info, warn}; use crate::{ @@ -17,7 +18,6 @@ use crate::{ chat::create_chat_service, code, completion::create_completion_service, - event::{self}, health::{read_cpu_info, read_cuda_devices}, model::download_model_if_needed, }, @@ -51,29 +51,30 @@ pub struct WorkerArgs { parallelism: u8, } -async fn make_chat_route(args: &WorkerArgs) -> Router { - let state = Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await); +async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router { + context.register(WorkerKind::Chat, args).await; - request_register(WorkerKind::Chat, args).await; + let chat_state = + Arc::new(create_chat_service(&args.model, &args.device, args.parallelism).await); Router::new().route( "/v1beta/chat/completions", - routing::post(routes::chat_completions).with_state(state), + routing::post(routes::chat_completions).with_state(chat_state), ) } -async fn make_completion_route(args: &WorkerArgs) -> Router { +async fn make_completion_route(context: WorkerContext, args: &WorkerArgs) -> Router { + context.register(WorkerKind::Completion, args).await; + let code = Arc::new(code::create_code_search()); - let logger = Arc::new(event::create_null_logger()); - let state = Arc::new( + let logger: Arc = Arc::new(context.client); + let completion_state = Arc::new( create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await, ); - request_register(WorkerKind::Completion, args).await; - Router::new().route( "/v1/completions", - routing::post(routes::completions).with_state(state), + routing::post(routes::completions).with_state(completion_state), ) } @@ -82,9 +83,10 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) { info!("Starting worker, this might takes a few minutes..."); + let context = WorkerContext::new(&args.url).await; let app = match kind { - WorkerKind::Completion => make_completion_route(args).await, - WorkerKind::Chat => make_chat_route(args).await, + WorkerKind::Completion => make_completion_route(context, args).await, + WorkerKind::Chat => make_chat_route(context, args).await, }; let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); @@ -96,48 +98,44 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) { .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } -async fn request_register(kind: WorkerKind, args: &WorkerArgs) { - if let Err(err) = request_register_impl( - kind, - args.url.clone(), - args.port, - args.model.to_owned(), - args.device.to_string(), - args.token.clone(), - ) - .await - { - warn!("Failed to register worker: {}", err) +struct WorkerContext { + client: HubClient, +} + +impl WorkerContext { + async fn new(url: &str) -> Self { + Self { + client: tabby_webserver::api::create_client(url).await, + } + } + + async fn register(&self, kind: WorkerKind, args: &WorkerArgs) { + if let Err(err) = self.register_impl(kind, args).await { + warn!("Failed to register worker: {}", err) + } + } + + async fn register_impl(&self, kind: WorkerKind, args: &WorkerArgs) -> Result<()> { + let (cpu_info, cpu_count) = read_cpu_info(); + let cuda_devices = read_cuda_devices().unwrap_or_default(); + let worker = self + .client + .register_worker( + tracing_context(), + kind, + args.port as i32, + args.model.to_owned(), + args.device.to_string(), + ARCH.to_string(), + cpu_info, + cpu_count as i32, + cuda_devices, + args.token.clone(), + ) + .await??; + + info!("Worker alive at {}", worker.addr); + + Ok(()) } } - -async fn request_register_impl( - kind: WorkerKind, - url: String, - port: u16, - name: String, - device: String, - token: String, -) -> Result<()> { - let client = tabby_webserver::api::create_client(url).await; - let (cpu_info, cpu_count) = read_cpu_info(); - let cuda_devices = read_cuda_devices().unwrap_or_default(); - let worker = client - .register_worker( - tabby_webserver::api::tracing_context(), - kind, - port as i32, - name, - device, - ARCH.to_string(), - cpu_info, - cpu_count as i32, - cuda_devices, - token, - ) - .await??; - - info!("Worker alive at {}", worker.addr); - - Ok(()) -} diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs index 4929cb8..e05bce2 100644 --- a/ee/tabby-webserver/src/api.rs +++ b/ee/tabby-webserver/src/api.rs @@ -1,5 +1,6 @@ use juniper::{GraphQLEnum, GraphQLObject}; use serde::{Deserialize, Serialize}; +use tabby_common::api::event::RawEventLogger; use thiserror::Error; use tokio_tungstenite::connect_async; @@ -45,14 +46,25 @@ pub trait Hub { cuda_devices: Vec, token: String, ) -> Result; + + async fn log_event(content: String); } pub fn tracing_context() -> tarpc::context::Context { tarpc::context::current() } -pub async fn create_client(addr: String) -> HubClient { +pub async fn create_client(addr: &str) -> HubClient { let addr = format!("ws://{}/hub", addr); let (socket, _) = connect_async(&addr).await.unwrap(); HubClient::new(Default::default(), WebSocketTransport::from(socket)).spawn() } + +impl RawEventLogger for HubClient { + fn log(&self, content: String) { + let context = tarpc::context::current(); + let client = self.clone(); + + tokio::spawn(async move { client.log_event(context, content).await }); + } +} diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 8c85cd4..54553b6 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -2,6 +2,7 @@ pub mod api; mod schema; pub use schema::create_schema; +use tabby_common::api::event::RawEventLogger; use tracing::error; use websocket::WebSocketTransport; @@ -26,9 +27,9 @@ use schema::Schema; use server::ServerContext; use tarpc::server::{BaseChannel, Channel}; -pub async fn attach_webserver(router: Router) -> Router { +pub async fn attach_webserver(router: Router, logger: Arc) -> Router { let conn = db::DbConn::new().await.unwrap(); - let ctx = Arc::new(ServerContext::new(conn)); + let ctx = Arc::new(ServerContext::new(conn, logger)); let schema = Arc::new(create_schema()); let app = Router::new() @@ -124,4 +125,8 @@ impl Hub for Arc { }; self.ctx.register_worker(worker).await } + + async fn log_event(self, _context: tarpc::context::Context, content: String) { + self.ctx.logger.log(content) + } } diff --git a/ee/tabby-webserver/src/server.rs b/ee/tabby-webserver/src/server.rs index eca0c38..b108fa0 100644 --- a/ee/tabby-webserver/src/server.rs +++ b/ee/tabby-webserver/src/server.rs @@ -1,11 +1,12 @@ mod proxy; mod worker; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use anyhow::Result; use axum::{http::Request, middleware::Next, response::IntoResponse}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; +use tabby_common::api::event::RawEventLogger; use tracing::{info, warn}; use crate::{ @@ -18,15 +19,18 @@ pub struct ServerContext { completion: worker::WorkerGroup, chat: worker::WorkerGroup, db_conn: DbConn, + + pub logger: Arc, } impl ServerContext { - pub fn new(db_conn: DbConn) -> Self { + pub fn new(db_conn: DbConn, logger: Arc) -> Self { Self { client: Client::default(), completion: worker::WorkerGroup::default(), chat: worker::WorkerGroup::default(), db_conn, + logger, } }