From e0017cadec0153eb440c5957e563012e94a5d086 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 12 Nov 2023 22:24:20 -0800 Subject: [PATCH] refactor: extract routes/ to share routes between commands (#774) * refactor: extract routes/ to share routes between commands * refactor: extract events api * extract EventLogger service * lift api into sub packages * services completions -> completion * remove useless code * fix test --- Cargo.lock | 3 +- Cargo.toml | 2 +- crates/juniper-axum/src/extract.rs | 49 ---------- crates/juniper-axum/src/lib.rs | 49 ---------- crates/tabby-common/Cargo.toml | 2 - crates/tabby-common/src/lib.rs | 1 - crates/tabby/Cargo.toml | 1 + crates/tabby/src/api/event.rs | 58 ++++++++++++ crates/tabby/src/api/mod.rs | 5 +- crates/tabby/src/main.rs | 2 +- crates/tabby/src/{serve => routes}/chat.rs | 2 +- .../src/{serve => routes}/completions.rs | 2 +- crates/tabby/src/{serve => routes}/events.rs | 33 +++---- crates/tabby/src/routes/health.rs | 17 ++++ crates/tabby/src/routes/mod.rs | 11 +++ crates/tabby/src/{serve => routes}/search.rs | 2 +- crates/tabby/src/serve/mod.rs | 90 ++++++++++--------- crates/tabby/src/services/code.rs | 2 +- .../{completions.rs => completion.rs} | 28 +++--- .../completion_prompt.rs} | 2 +- .../events.rs => tabby/src/services/event.rs} | 55 +++--------- .../tabby/src/{serve => services}/health.rs | 25 ++---- crates/tabby/src/services/mod.rs | 4 +- 23 files changed, 193 insertions(+), 252 deletions(-) create mode 100644 crates/tabby/src/api/event.rs rename crates/tabby/src/{serve => routes}/chat.rs (97%) rename crates/tabby/src/{serve => routes}/completions.rs (89%) rename crates/tabby/src/{serve => routes}/events.rs (65%) create mode 100644 crates/tabby/src/routes/health.rs create mode 100644 crates/tabby/src/routes/mod.rs rename crates/tabby/src/{serve => routes}/search.rs (95%) rename crates/tabby/src/services/{completions.rs => completion.rs} (92%) rename crates/tabby/src/services/{completions/completions_prompt.rs => completion/completion_prompt.rs} (99%) rename crates/{tabby-common/src/events.rs => tabby/src/services/event.rs} (66%) rename crates/tabby/src/{serve => services}/health.rs (79%) diff --git a/Cargo.lock b/Cargo.lock index 81653fb..dabb7e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4282,6 +4282,7 @@ dependencies = [ "axum", "axum-streams", "axum-tracing-opentelemetry", + "chrono", "clap 4.4.7", "futures", "http-api-bindings", @@ -4324,7 +4325,6 @@ name = "tabby-common" version = "0.6.0-dev" dependencies = [ "anyhow", - "chrono", "filenamify", "lazy_static", "reqwest", @@ -4332,7 +4332,6 @@ dependencies = [ "serde-jsonlines", "serdeconv", "tantivy", - "tokio", "uuid 1.4.1", ] diff --git a/Cargo.toml b/Cargo.toml index 1a11930..236815e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,4 +39,4 @@ thiserror = "1.0.49" utoipa = "3.3" axum = "0.6" hyper = "0.14" -juniper = "0.15" \ No newline at end of file +juniper = "0.15" diff --git a/crates/juniper-axum/src/extract.rs b/crates/juniper-axum/src/extract.rs index d8fbab6..82207ca 100644 --- a/crates/juniper-axum/src/extract.rs +++ b/crates/juniper-axum/src/extract.rs @@ -16,55 +16,6 @@ use juniper::{ }; use serde::Deserialize; -/// Extractor for [`axum`] to extract a [`JuniperRequest`]. -/// -/// # Example -/// -/// ```rust -/// use std::sync::Arc; -/// -/// use axum::{routing::post, Extension, Json, Router}; -/// use juniper::{ -/// RootNode, EmptySubscription, EmptyMutation, graphql_object, -/// }; -/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse}; -/// -/// #[derive(Clone, Copy, Debug)] -/// pub struct Context; -/// -/// impl juniper::Context for Context {} -/// -/// #[derive(Clone, Copy, Debug)] -/// pub struct Query; -/// -/// #[graphql_object(context = Context)] -/// impl Query { -/// fn add(a: i32, b: i32) -> i32 { -/// a + b -/// } -/// } -/// -/// type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; -/// -/// let schema = Schema::new( -/// Query, -/// EmptyMutation::::new(), -/// EmptySubscription::::new() -/// ); -/// -/// let app: Router = Router::new() -/// .route("/graphql", post(graphql)) -/// .layer(Extension(Arc::new(schema))) -/// .layer(Extension(Context)); -/// -/// # #[axum::debug_handler] -/// async fn graphql( -/// Extension(schema): Extension>, -/// Extension(context): Extension, -/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request` -/// ) -> JuniperResponse { -/// JuniperResponse(req.execute(&*schema, &context).await) -/// } #[derive(Debug, PartialEq)] pub struct JuniperRequest(pub GraphQLBatchRequest) where diff --git a/crates/juniper-axum/src/lib.rs b/crates/juniper-axum/src/lib.rs index 06d81e5..ae0b52a 100644 --- a/crates/juniper-axum/src/lib.rs +++ b/crates/juniper-axum/src/lib.rs @@ -15,55 +15,6 @@ pub trait FromStateAndClientAddr { fn build(state: S, client_addr: SocketAddr) -> C; } -/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing -/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one. -/// -/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need -/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s -/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a -/// > [`JuniperRequest`] (see its documentation for examples). -/// -/// # Example -/// -/// ```rust -/// use std::sync::Arc; -/// -/// use axum::{routing::post, Extension, Json, Router}; -/// use juniper::{ -/// RootNode, EmptySubscription, EmptyMutation, graphql_object, -/// }; -/// use juniper_axum::graphql; -/// -/// #[derive(Clone, Copy, Debug, Default)] -/// pub struct Context; -/// -/// impl juniper::Context for Context {} -/// -/// #[derive(Clone, Copy, Debug)] -/// pub struct Query; -/// -/// #[graphql_object(context = Context)] -/// impl Query { -/// fn add(a: i32, b: i32) -> i32 { -/// a + b -/// } -/// } -/// -/// type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; -/// -/// let schema = Schema::new( -/// Query, -/// EmptyMutation::::new(), -/// EmptySubscription::::new() -/// ); -/// -/// let app: Router = Router::new() -/// .route("/graphql", post(graphql::>)) -/// .layer(Extension(Arc::new(schema))); -/// ``` -/// -/// [`extract`]: axum::extract -/// [`Handler`]: axum::handler::Handler #[cfg_attr(text, axum::debug_handler)] pub async fn graphql( ConnectInfo(addr): ConnectInfo, diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml index ccaac95..c222370 100644 --- a/crates/tabby-common/Cargo.toml +++ b/crates/tabby-common/Cargo.toml @@ -4,14 +4,12 @@ version = "0.6.0-dev" edition = "2021" [dependencies] -chrono = "0.4.26" filenamify = "0.1.0" lazy_static = { workspace = true } serde = { workspace = true } serdeconv = { workspace = true } serde-jsonlines = { workspace = true } reqwest = { workspace = true, features = [ "json" ] } -tokio = { workspace = true, features = ["rt", "macros"] } uuid = { version = "1.4.1", features = ["v4"] } tantivy.workspace = true anyhow.workspace = true diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index 2458fef..ec17a42 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -1,5 +1,4 @@ pub mod config; -pub mod events; pub mod index; pub mod languages; pub mod path; diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 1d6450a..66d9e04 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -46,6 +46,7 @@ futures.workspace = true async-trait.workspace = true tabby-webserver = { path = "../../ee/tabby-webserver" } thiserror.workspace = true +chrono = "0.4.31" [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/src/api/event.rs b/crates/tabby/src/api/event.rs new file mode 100644 index 0000000..0eceba3 --- /dev/null +++ b/crates/tabby/src/api/event.rs @@ -0,0 +1,58 @@ +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct LogEventRequest { + /// Event type, should be `view` or `select`. + #[schema(example = "view")] + #[serde(rename = "type")] + pub event_type: String, + + pub completion_id: String, + + pub choice_index: u32, +} + +#[derive(Serialize)] +pub struct Choice<'a> { + pub index: u32, + pub text: &'a str, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +pub enum SelectKind { + Line, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Event<'a> { + View { + completion_id: &'a str, + choice_index: u32, + }, + Select { + completion_id: &'a str, + choice_index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + kind: Option, + }, + Completion { + completion_id: &'a str, + language: &'a str, + prompt: &'a str, + segments: &'a Option, + choices: Vec>, + user: Option<&'a str>, + }, +} +#[derive(Serialize)] +pub struct Segments { + pub prefix: String, + pub suffix: Option, +} + +pub trait EventLogger: Send + Sync { + fn log(&self, e: &Event); +} diff --git a/crates/tabby/src/api/mod.rs b/crates/tabby/src/api/mod.rs index 4c48d8b..cebf170 100644 --- a/crates/tabby/src/api/mod.rs +++ b/crates/tabby/src/api/mod.rs @@ -1,3 +1,2 @@ -mod code; - -pub use code::*; +pub mod code; +pub mod event; diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 11f10b2..4cfc093 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,7 +1,7 @@ mod api; mod download; +mod routes; mod serve; - mod services; use clap::{Parser, Subcommand}; diff --git a/crates/tabby/src/serve/chat.rs b/crates/tabby/src/routes/chat.rs similarity index 97% rename from crates/tabby/src/serve/chat.rs rename to crates/tabby/src/routes/chat.rs index 4b51683..1b54333 100644 --- a/crates/tabby/src/serve/chat.rs +++ b/crates/tabby/src/routes/chat.rs @@ -23,7 +23,7 @@ use crate::services::chat::{ChatCompletionRequest, ChatService}; ) )] #[instrument(skip(state, request))] -pub async fn completions( +pub async fn chat_completions( State(state): State>, Json(request): Json, ) -> Response { diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/routes/completions.rs similarity index 89% rename from crates/tabby/src/serve/completions.rs rename to crates/tabby/src/routes/completions.rs index c5743f6..d394d18 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -4,7 +4,7 @@ use axum::{extract::State, Json}; use hyper::StatusCode; use tracing::{instrument, warn}; -use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService}; +use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService}; #[utoipa::path( post, diff --git a/crates/tabby/src/serve/events.rs b/crates/tabby/src/routes/events.rs similarity index 65% rename from crates/tabby/src/serve/events.rs rename to crates/tabby/src/routes/events.rs index 7c1b4dd..61e8037 100644 --- a/crates/tabby/src/serve/events.rs +++ b/crates/tabby/src/routes/events.rs @@ -1,22 +1,12 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; -use axum::{extract::Query, Json}; +use axum::{ + extract::{Query, State}, + Json, +}; use hyper::StatusCode; -use serde::{Deserialize, Serialize}; -use tabby_common::events::{self, SelectKind}; -use utoipa::ToSchema; -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct LogEventRequest { - /// Event type, should be `view` or `select`. - #[schema(example = "view")] - #[serde(rename = "type")] - event_type: String, - - completion_id: String, - - choice_index: u32, -} +use crate::api::event::{Event, EventLogger, LogEventRequest, SelectKind}; #[utoipa::path( post, @@ -30,22 +20,22 @@ pub struct LogEventRequest { ) )] pub async fn log_event( + State(logger): State>, Query(params): Query>, Json(request): Json, ) -> StatusCode { if request.event_type == "view" { - events::Event::View { + logger.log(&Event::View { completion_id: &request.completion_id, choice_index: request.choice_index, - } - .log(); + }); StatusCode::OK } else if request.event_type == "select" { let is_line = params .get("select_kind") .map(|x| x == "line") .unwrap_or(false); - events::Event::Select { + logger.log(&Event::Select { completion_id: &request.completion_id, choice_index: request.choice_index, kind: if is_line { @@ -53,8 +43,7 @@ pub async fn log_event( } else { None }, - } - .log(); + }); StatusCode::OK } else { StatusCode::BAD_REQUEST diff --git a/crates/tabby/src/routes/health.rs b/crates/tabby/src/routes/health.rs new file mode 100644 index 0000000..9483ac2 --- /dev/null +++ b/crates/tabby/src/routes/health.rs @@ -0,0 +1,17 @@ +use std::sync::Arc; + +use axum::{extract::State, Json}; + +use crate::services::health; + +#[utoipa::path( + get, + path = "/v1/health", + tag = "v1", + responses( + (status = 200, description = "Success", body = HealthState, content_type = "application/json"), + ) +)] +pub async fn health(State(state): State>) -> Json { + Json(state.as_ref().clone()) +} diff --git a/crates/tabby/src/routes/mod.rs b/crates/tabby/src/routes/mod.rs new file mode 100644 index 0000000..2b52910 --- /dev/null +++ b/crates/tabby/src/routes/mod.rs @@ -0,0 +1,11 @@ +mod chat; +mod completions; +mod events; +mod health; +mod search; + +pub use chat::*; +pub use completions::*; +pub use events::*; +pub use health::*; +pub use search::*; diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/routes/search.rs similarity index 95% rename from crates/tabby/src/serve/search.rs rename to crates/tabby/src/routes/search.rs index 5bcea1a..89932b1 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/routes/search.rs @@ -10,7 +10,7 @@ use serde::Deserialize; use tracing::{instrument, warn}; use utoipa::IntoParams; -use crate::api::{CodeSearch, CodeSearchError, SearchResponse}; +use crate::api::code::{CodeSearch, CodeSearchError, SearchResponse}; #[derive(Deserialize, IntoParams)] pub struct SearchQuery { diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 2a5d3ae..88d0686 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -1,9 +1,3 @@ -mod chat; -mod completions; -mod events; -mod health; -mod search; - use std::{ fs, net::{Ipv4Addr, SocketAddr}, @@ -23,15 +17,10 @@ use tracing::info; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use self::health::HealthState; use crate::{ - api::{Hit, HitDocument, SearchResponse}, - fatal, - services::{ - chat::ChatService, - completions::CompletionService, - model::{load_text_generation, PromptInfo}, - }, + api::{self}, + fatal, routes, + services::{chat, completion, event::create_event_logger, health, model}, }; #[derive(OpenApi)] @@ -51,24 +40,24 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi servers( (url = "/", description = "Server"), ), - paths(events::log_event, completions::completions, chat::completions, health::health, search::search), + paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search), components(schemas( - events::LogEventRequest, - crate::services::completions::CompletionRequest, - crate::services::completions::CompletionResponse, - crate::services::completions::Segments, - crate::services::completions::Choice, - crate::services::completions::Snippet, - crate::services::completions::DebugOptions, - crate::services::completions::DebugData, - crate::services::chat::ChatCompletionRequest, - crate::services::chat::Message, - crate::services::chat::ChatCompletionChunk, + api::event::LogEventRequest, + completion::CompletionRequest, + completion::CompletionResponse, + completion::Segments, + completion::Choice, + completion::Snippet, + completion::DebugOptions, + completion::DebugData, + chat::ChatCompletionRequest, + chat::Message, + chat::ChatCompletionChunk, health::HealthState, health::Version, - SearchResponse, - Hit, - HitDocument + api::code::SearchResponse, + api::code::Hit, + api::code::HitDocument )) )] struct ApiDoc; @@ -174,25 +163,31 @@ async fn load_model(args: &ServeArgs) { } async fn api_router(args: &ServeArgs, config: &Config) -> Router { + let logger = Arc::new(create_event_logger()); let code = Arc::new(crate::services::code::create_code_search()); let completion_state = { let ( engine, - PromptInfo { + model::PromptInfo { prompt_template, .. }, - ) = load_text_generation(&args.model, &args.device, args.parallelism).await; - let state = CompletionService::new(engine.clone(), code.clone(), prompt_template); + ) = model::load_text_generation(&args.model, &args.device, args.parallelism).await; + let state = completion::CompletionService::new( + engine.clone(), + code.clone(), + logger.clone(), + prompt_template, + ); Arc::new(state) }; let chat_state = if let Some(chat_model) = &args.chat_model { - let (engine, PromptInfo { chat_template, .. }) = - load_text_generation(chat_model, &args.device, args.parallelism).await; + let (engine, model::PromptInfo { chat_template, .. }) = + model::load_text_generation(chat_model, &args.device, args.parallelism).await; let Some(chat_template) = chat_template else { panic!("Chat model requires specifying prompt template"); }; - let state = ChatService::new(engine, chat_template); + let state = chat::ChatService::new(engine, chat_template); Some(Arc::new(state)) } else { None @@ -200,17 +195,24 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { let mut routers = vec![]; - let health_state = Arc::new(health::HealthState::new(args)); + let health_state = Arc::new(health::HealthState::new( + &args.model, + args.chat_model.as_deref(), + &args.device, + )); routers.push({ Router::new() - .route("/v1/events", routing::post(events::log_event)) .route( - "/v1/health", - routing::post(health::health).with_state(health_state.clone()), + "/v1/events", + routing::post(routes::log_event).with_state(logger), ) .route( "/v1/health", - routing::get(health::health).with_state(health_state), + routing::post(routes::health).with_state(health_state.clone()), + ) + .route( + "/v1/health", + routing::get(routes::health).with_state(health_state), ) }); @@ -218,7 +220,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { Router::new() .route( "/v1/completions", - routing::post(completions::completions).with_state(completion_state), + routing::post(routes::completions).with_state(completion_state), ) .layer(TimeoutLayer::new(Duration::from_secs( config.server.completion_timeout, @@ -229,7 +231,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { routers.push({ Router::new().route( "/v1beta/chat/completions", - routing::post(chat::completions).with_state(chat_state), + routing::post(routes::chat_completions).with_state(chat_state), ) }) } @@ -237,7 +239,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { routers.push({ Router::new().route( "/v1beta/search", - routing::get(search::search).with_state(code), + routing::get(routes::search).with_state(code), ) }); @@ -250,7 +252,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { } fn start_heartbeat(args: &ServeArgs) { - let state = HealthState::new(args); + let state = health::HealthState::new(&args.model, args.chat_model.as_deref(), &args.device); tokio::spawn(async move { loop { usage::capture("ServeHealth", &state).await; diff --git a/crates/tabby/src/services/code.rs b/crates/tabby/src/services/code.rs index 71c1ea6..1abc586 100644 --- a/crates/tabby/src/services/code.rs +++ b/crates/tabby/src/services/code.rs @@ -16,7 +16,7 @@ use tantivy::{ use tokio::{sync::Mutex, time::sleep}; use tracing::{debug, log::info}; -use crate::api::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}; +use crate::api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}; struct CodeSearchImpl { reader: IndexReader, diff --git a/crates/tabby/src/services/completions.rs b/crates/tabby/src/services/completion.rs similarity index 92% rename from crates/tabby/src/services/completions.rs rename to crates/tabby/src/services/completion.rs index 80bf2b8..8c76ed8 100644 --- a/crates/tabby/src/services/completions.rs +++ b/crates/tabby/src/services/completion.rs @@ -1,15 +1,19 @@ -mod completions_prompt; +mod completion_prompt; use std::sync::Arc; use serde::{Deserialize, Serialize}; -use tabby_common::{events, languages::get_language}; +use tabby_common::languages::get_language; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use thiserror::Error; use tracing::debug; use utoipa::ToSchema; -use crate::api::CodeSearch; +use crate::api::{ + self, + code::CodeSearch, + event::{Event, EventLogger}, +}; #[derive(Error, Debug)] pub enum CompletionError { @@ -95,9 +99,9 @@ pub struct Segments { suffix: Option, } -impl From for events::Segments { +impl From for api::event::Segments { fn from(val: Segments) -> Self { - events::Segments { + Self { prefix: val.prefix, suffix: val.suffix, } @@ -157,18 +161,21 @@ pub struct DebugData { pub struct CompletionService { engine: Arc, - prompt_builder: completions_prompt::PromptBuilder, + logger: Arc, + prompt_builder: completion_prompt::PromptBuilder, } impl CompletionService { pub fn new( engine: Arc, code: Arc, + logger: Arc, prompt_template: Option, ) -> Self { Self { engine, - prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)), + prompt_builder: completion_prompt::PromptBuilder::new(prompt_template, Some(code)), + logger, } } @@ -226,18 +233,17 @@ impl CompletionService { let text = self.engine.generate(&prompt, options).await; let segments = segments.map(|s| s.into()); - events::Event::Completion { + self.logger.log(&Event::Completion { completion_id: &completion_id, language: &language, prompt: &prompt, segments: &segments, - choices: vec![events::Choice { + choices: vec![api::event::Choice { index: 0, text: &text, }], user: request.user.as_deref(), - } - .log(); + }); let debug_data = request .debug_options diff --git a/crates/tabby/src/services/completions/completions_prompt.rs b/crates/tabby/src/services/completion/completion_prompt.rs similarity index 99% rename from crates/tabby/src/services/completions/completions_prompt.rs rename to crates/tabby/src/services/completion/completion_prompt.rs index 4a0e79b..627916f 100644 --- a/crates/tabby/src/services/completions/completions_prompt.rs +++ b/crates/tabby/src/services/completion/completion_prompt.rs @@ -8,7 +8,7 @@ use textdistance::Algorithm; use tracing::warn; use super::{Segments, Snippet}; -use crate::api::{CodeSearch, CodeSearchError}; +use crate::api::code::{CodeSearch, CodeSearchError}; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; diff --git a/crates/tabby-common/src/events.rs b/crates/tabby/src/services/event.rs similarity index 66% rename from crates/tabby-common/src/events.rs rename to crates/tabby/src/services/event.rs index dacec9b..9c77391 100644 --- a/crates/tabby-common/src/events.rs +++ b/crates/tabby/src/services/event.rs @@ -7,17 +7,20 @@ use std::{ use chrono::Utc; use lazy_static::lazy_static; use serde::Serialize; +use tabby_common::path; use tokio::{ sync::mpsc::{unbounded_channel, UnboundedSender}, time::{self}, }; +use crate::api::event::{Event, EventLogger}; + lazy_static! { static ref WRITER: UnboundedSender = { let (tx, mut rx) = unbounded_channel::(); tokio::spawn(async move { - let events_dir = crate::path::events_dir(); + let events_dir = path::events_dir(); std::fs::create_dir_all(events_dir.as_path()).ok(); let now = Utc::now(); @@ -53,45 +56,7 @@ lazy_static! { }; } -#[derive(Serialize)] -pub struct Choice<'a> { - pub index: u32, - pub text: &'a str, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -pub enum SelectKind { - Line, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -pub enum Event<'a> { - View { - completion_id: &'a str, - choice_index: u32, - }, - Select { - completion_id: &'a str, - choice_index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - kind: Option, - }, - Completion { - completion_id: &'a str, - language: &'a str, - prompt: &'a str, - segments: &'a Option, - choices: Vec>, - user: Option<&'a str>, - }, -} -#[derive(Serialize)] -pub struct Segments { - pub prefix: String, - pub suffix: Option, -} +struct EventService; #[derive(Serialize)] struct Log<'a> { @@ -99,11 +64,11 @@ struct Log<'a> { event: &'a Event<'a>, } -impl Event<'_> { - pub fn log(&self) { +impl EventLogger for EventService { + fn log(&self, e: &Event) { let content = serdeconv::to_json_string(&Log { ts: timestamp(), - event: self, + event: e, }) .unwrap(); @@ -119,3 +84,7 @@ fn timestamp() -> u128 { .expect("Time went backwards") .as_millis() } + +pub fn create_event_logger() -> impl EventLogger { + EventService +} diff --git a/crates/tabby/src/serve/health.rs b/crates/tabby/src/services/health.rs similarity index 79% rename from crates/tabby/src/serve/health.rs rename to crates/tabby/src/services/health.rs index 0a5e0bf..7db4cf4 100644 --- a/crates/tabby/src/serve/health.rs +++ b/crates/tabby/src/services/health.rs @@ -1,12 +1,13 @@ -use std::{env::consts::ARCH, sync::Arc}; +use std::env::consts::ARCH; use anyhow::Result; -use axum::{extract::State, Json}; use nvml_wrapper::Nvml; use serde::{Deserialize, Serialize}; use sysinfo::{CpuExt, System, SystemExt}; use utoipa::ToSchema; +use crate::serve::Device; + #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct HealthState { model: String, @@ -21,7 +22,7 @@ pub struct HealthState { } impl HealthState { - pub fn new(args: &super::ServeArgs) -> Self { + pub fn new(model: &str, chat_model: Option<&str>, device: &Device) -> Self { let (cpu_info, cpu_count) = read_cpu_info(); let cuda_devices = match read_cuda_devices() { @@ -30,9 +31,9 @@ impl HealthState { }; Self { - model: args.model.clone(), - chat_model: args.chat_model.clone(), - device: args.device.to_string(), + model: model.to_owned(), + chat_model: chat_model.map(|x| x.to_owned()), + device: device.to_string(), arch: ARCH.to_string(), cpu_info, cpu_count, @@ -90,15 +91,3 @@ impl Version { } } } - -#[utoipa::path( - get, - path = "/v1/health", - tag = "v1", - responses( - (status = 200, description = "Success", body = HealthState, content_type = "application/json"), - ) -)] -pub async fn health(State(state): State>) -> Json { - Json(state.as_ref().clone()) -} diff --git a/crates/tabby/src/services/mod.rs b/crates/tabby/src/services/mod.rs index 9ef0b4d..e83cb06 100644 --- a/crates/tabby/src/services/mod.rs +++ b/crates/tabby/src/services/mod.rs @@ -1,4 +1,6 @@ pub mod chat; pub mod code; -pub mod completions; +pub mod completion; +pub mod event; +pub mod health; pub mod model;