refactor: extract routes/ to share routes between commands (#774)

* refactor: extract routes/ to share routes between commands

* refactor: extract events api

* extract EventLogger service

* lift api into sub packages

* services completions -> completion

* remove useless code

* fix test
release-fix-intellij-update-support-version-range
Meng Zhang 2023-11-12 22:24:20 -08:00 committed by GitHub
parent 63c7da4f96
commit e0017cadec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 193 additions and 252 deletions

3
Cargo.lock generated
View File

@ -4282,6 +4282,7 @@ dependencies = [
"axum", "axum",
"axum-streams", "axum-streams",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"chrono",
"clap 4.4.7", "clap 4.4.7",
"futures", "futures",
"http-api-bindings", "http-api-bindings",
@ -4324,7 +4325,6 @@ name = "tabby-common"
version = "0.6.0-dev" version = "0.6.0-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"chrono",
"filenamify", "filenamify",
"lazy_static", "lazy_static",
"reqwest", "reqwest",
@ -4332,7 +4332,6 @@ dependencies = [
"serde-jsonlines", "serde-jsonlines",
"serdeconv", "serdeconv",
"tantivy", "tantivy",
"tokio",
"uuid 1.4.1", "uuid 1.4.1",
] ]

View File

@ -39,4 +39,4 @@ thiserror = "1.0.49"
utoipa = "3.3" utoipa = "3.3"
axum = "0.6" axum = "0.6"
hyper = "0.14" hyper = "0.14"
juniper = "0.15" juniper = "0.15"

View File

@ -16,55 +16,6 @@ use juniper::{
}; };
use serde::Deserialize; use serde::Deserialize;
/// Extractor for [`axum`] to extract a [`JuniperRequest`].
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
///
/// use axum::{routing::post, Extension, Json, Router};
/// use juniper::{
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
/// };
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Context;
///
/// impl juniper::Context for Context {}
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Query;
///
/// #[graphql_object(context = Context)]
/// impl Query {
/// fn add(a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
///
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
///
/// let schema = Schema::new(
/// Query,
/// EmptyMutation::<Context>::new(),
/// EmptySubscription::<Context>::new()
/// );
///
/// let app: Router = Router::new()
/// .route("/graphql", post(graphql))
/// .layer(Extension(Arc::new(schema)))
/// .layer(Extension(Context));
///
/// # #[axum::debug_handler]
/// async fn graphql(
/// Extension(schema): Extension<Arc<Schema>>,
/// Extension(context): Extension<Context>,
/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
/// ) -> JuniperResponse {
/// JuniperResponse(req.execute(&*schema, &context).await)
/// }
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>) pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where where

View File

@ -15,55 +15,6 @@ pub trait FromStateAndClientAddr<C, S> {
fn build(state: S, client_addr: SocketAddr) -> C; fn build(state: S, client_addr: SocketAddr) -> C;
} }
/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing
/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one.
///
/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need
/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s
/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a
/// > [`JuniperRequest`] (see its documentation for examples).
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
///
/// use axum::{routing::post, Extension, Json, Router};
/// use juniper::{
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
/// };
/// use juniper_axum::graphql;
///
/// #[derive(Clone, Copy, Debug, Default)]
/// pub struct Context;
///
/// impl juniper::Context for Context {}
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Query;
///
/// #[graphql_object(context = Context)]
/// impl Query {
/// fn add(a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
///
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
///
/// let schema = Schema::new(
/// Query,
/// EmptyMutation::<Context>::new(),
/// EmptySubscription::<Context>::new()
/// );
///
/// let app: Router = Router::new()
/// .route("/graphql", post(graphql::<Arc<Schema>>))
/// .layer(Extension(Arc::new(schema)));
/// ```
///
/// [`extract`]: axum::extract
/// [`Handler`]: axum::handler::Handler
#[cfg_attr(text, axum::debug_handler)] #[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S, C>( pub async fn graphql<S, C>(
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,

View File

@ -4,14 +4,12 @@ version = "0.6.0-dev"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
chrono = "0.4.26"
filenamify = "0.1.0" filenamify = "0.1.0"
lazy_static = { workspace = true } lazy_static = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serdeconv = { workspace = true } serdeconv = { workspace = true }
serde-jsonlines = { workspace = true } serde-jsonlines = { workspace = true }
reqwest = { workspace = true, features = [ "json" ] } reqwest = { workspace = true, features = [ "json" ] }
tokio = { workspace = true, features = ["rt", "macros"] }
uuid = { version = "1.4.1", features = ["v4"] } uuid = { version = "1.4.1", features = ["v4"] }
tantivy.workspace = true tantivy.workspace = true
anyhow.workspace = true anyhow.workspace = true

View File

@ -1,5 +1,4 @@
pub mod config; pub mod config;
pub mod events;
pub mod index; pub mod index;
pub mod languages; pub mod languages;
pub mod path; pub mod path;

View File

@ -46,6 +46,7 @@ futures.workspace = true
async-trait.workspace = true async-trait.workspace = true
tabby-webserver = { path = "../../ee/tabby-webserver" } tabby-webserver = { path = "../../ee/tabby-webserver" }
thiserror.workspace = true thiserror.workspace = true
chrono = "0.4.31"
[dependencies.uuid] [dependencies.uuid]
version = "1.3.3" version = "1.3.3"

View File

@ -0,0 +1,58 @@
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct LogEventRequest {
/// Event type, should be `view` or `select`.
#[schema(example = "view")]
#[serde(rename = "type")]
pub event_type: String,
pub completion_id: String,
pub choice_index: u32,
}
#[derive(Serialize)]
pub struct Choice<'a> {
pub index: u32,
pub text: &'a str,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SelectKind {
Line,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Event<'a> {
View {
completion_id: &'a str,
choice_index: u32,
},
Select {
completion_id: &'a str,
choice_index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
kind: Option<SelectKind>,
},
Completion {
completion_id: &'a str,
language: &'a str,
prompt: &'a str,
segments: &'a Option<Segments>,
choices: Vec<Choice<'a>>,
user: Option<&'a str>,
},
}
#[derive(Serialize)]
pub struct Segments {
pub prefix: String,
pub suffix: Option<String>,
}
pub trait EventLogger: Send + Sync {
fn log(&self, e: &Event);
}

View File

@ -1,3 +1,2 @@
mod code; pub mod code;
pub mod event;
pub use code::*;

View File

@ -1,7 +1,7 @@
mod api; mod api;
mod download; mod download;
mod routes;
mod serve; mod serve;
mod services; mod services;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};

View File

@ -23,7 +23,7 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
) )
)] )]
#[instrument(skip(state, request))] #[instrument(skip(state, request))]
pub async fn completions( pub async fn chat_completions(
State(state): State<Arc<ChatService>>, State(state): State<Arc<ChatService>>,
Json(request): Json<ChatCompletionRequest>, Json(request): Json<ChatCompletionRequest>,
) -> Response { ) -> Response {

View File

@ -4,7 +4,7 @@ use axum::{extract::State, Json};
use hyper::StatusCode; use hyper::StatusCode;
use tracing::{instrument, warn}; use tracing::{instrument, warn};
use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService}; use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService};
#[utoipa::path( #[utoipa::path(
post, post,

View File

@ -1,22 +1,12 @@
use std::collections::HashMap; use std::{collections::HashMap, sync::Arc};
use axum::{extract::Query, Json}; use axum::{
extract::{Query, State},
Json,
};
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::events::{self, SelectKind};
use utoipa::ToSchema;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] use crate::api::event::{Event, EventLogger, LogEventRequest, SelectKind};
pub struct LogEventRequest {
/// Event type, should be `view` or `select`.
#[schema(example = "view")]
#[serde(rename = "type")]
event_type: String,
completion_id: String,
choice_index: u32,
}
#[utoipa::path( #[utoipa::path(
post, post,
@ -30,22 +20,22 @@ pub struct LogEventRequest {
) )
)] )]
pub async fn log_event( pub async fn log_event(
State(logger): State<Arc<dyn EventLogger>>,
Query(params): Query<HashMap<String, String>>, Query(params): Query<HashMap<String, String>>,
Json(request): Json<LogEventRequest>, Json(request): Json<LogEventRequest>,
) -> StatusCode { ) -> StatusCode {
if request.event_type == "view" { if request.event_type == "view" {
events::Event::View { logger.log(&Event::View {
completion_id: &request.completion_id, completion_id: &request.completion_id,
choice_index: request.choice_index, choice_index: request.choice_index,
} });
.log();
StatusCode::OK StatusCode::OK
} else if request.event_type == "select" { } else if request.event_type == "select" {
let is_line = params let is_line = params
.get("select_kind") .get("select_kind")
.map(|x| x == "line") .map(|x| x == "line")
.unwrap_or(false); .unwrap_or(false);
events::Event::Select { logger.log(&Event::Select {
completion_id: &request.completion_id, completion_id: &request.completion_id,
choice_index: request.choice_index, choice_index: request.choice_index,
kind: if is_line { kind: if is_line {
@ -53,8 +43,7 @@ pub async fn log_event(
} else { } else {
None None
}, },
} });
.log();
StatusCode::OK StatusCode::OK
} else { } else {
StatusCode::BAD_REQUEST StatusCode::BAD_REQUEST

View File

@ -0,0 +1,17 @@
use std::sync::Arc;
use axum::{extract::State, Json};
use crate::services::health;
#[utoipa::path(
get,
path = "/v1/health",
tag = "v1",
responses(
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
)
)]
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {
Json(state.as_ref().clone())
}

View File

@ -0,0 +1,11 @@
mod chat;
mod completions;
mod events;
mod health;
mod search;
pub use chat::*;
pub use completions::*;
pub use events::*;
pub use health::*;
pub use search::*;

View File

@ -10,7 +10,7 @@ use serde::Deserialize;
use tracing::{instrument, warn}; use tracing::{instrument, warn};
use utoipa::IntoParams; use utoipa::IntoParams;
use crate::api::{CodeSearch, CodeSearchError, SearchResponse}; use crate::api::code::{CodeSearch, CodeSearchError, SearchResponse};
#[derive(Deserialize, IntoParams)] #[derive(Deserialize, IntoParams)]
pub struct SearchQuery { pub struct SearchQuery {

View File

@ -1,9 +1,3 @@
mod chat;
mod completions;
mod events;
mod health;
mod search;
use std::{ use std::{
fs, fs,
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
@ -23,15 +17,10 @@ use tracing::info;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use self::health::HealthState;
use crate::{ use crate::{
api::{Hit, HitDocument, SearchResponse}, api::{self},
fatal, fatal, routes,
services::{ services::{chat, completion, event::create_event_logger, health, model},
chat::ChatService,
completions::CompletionService,
model::{load_text_generation, PromptInfo},
},
}; };
#[derive(OpenApi)] #[derive(OpenApi)]
@ -51,24 +40,24 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers( servers(
(url = "/", description = "Server"), (url = "/", description = "Server"),
), ),
paths(events::log_event, completions::completions, chat::completions, health::health, search::search), paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search),
components(schemas( components(schemas(
events::LogEventRequest, api::event::LogEventRequest,
crate::services::completions::CompletionRequest, completion::CompletionRequest,
crate::services::completions::CompletionResponse, completion::CompletionResponse,
crate::services::completions::Segments, completion::Segments,
crate::services::completions::Choice, completion::Choice,
crate::services::completions::Snippet, completion::Snippet,
crate::services::completions::DebugOptions, completion::DebugOptions,
crate::services::completions::DebugData, completion::DebugData,
crate::services::chat::ChatCompletionRequest, chat::ChatCompletionRequest,
crate::services::chat::Message, chat::Message,
crate::services::chat::ChatCompletionChunk, chat::ChatCompletionChunk,
health::HealthState, health::HealthState,
health::Version, health::Version,
SearchResponse, api::code::SearchResponse,
Hit, api::code::Hit,
HitDocument api::code::HitDocument
)) ))
)] )]
struct ApiDoc; struct ApiDoc;
@ -174,25 +163,31 @@ async fn load_model(args: &ServeArgs) {
} }
async fn api_router(args: &ServeArgs, config: &Config) -> Router { async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let logger = Arc::new(create_event_logger());
let code = Arc::new(crate::services::code::create_code_search()); let code = Arc::new(crate::services::code::create_code_search());
let completion_state = { let completion_state = {
let ( let (
engine, engine,
PromptInfo { model::PromptInfo {
prompt_template, .. prompt_template, ..
}, },
) = load_text_generation(&args.model, &args.device, args.parallelism).await; ) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template); let state = completion::CompletionService::new(
engine.clone(),
code.clone(),
logger.clone(),
prompt_template,
);
Arc::new(state) Arc::new(state)
}; };
let chat_state = if let Some(chat_model) = &args.chat_model { let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, PromptInfo { chat_template, .. }) = let (engine, model::PromptInfo { chat_template, .. }) =
load_text_generation(chat_model, &args.device, args.parallelism).await; model::load_text_generation(chat_model, &args.device, args.parallelism).await;
let Some(chat_template) = chat_template else { let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };
let state = ChatService::new(engine, chat_template); let state = chat::ChatService::new(engine, chat_template);
Some(Arc::new(state)) Some(Arc::new(state))
} else { } else {
None None
@ -200,17 +195,24 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let mut routers = vec![]; let mut routers = vec![];
let health_state = Arc::new(health::HealthState::new(args)); let health_state = Arc::new(health::HealthState::new(
&args.model,
args.chat_model.as_deref(),
&args.device,
));
routers.push({ routers.push({
Router::new() Router::new()
.route("/v1/events", routing::post(events::log_event))
.route( .route(
"/v1/health", "/v1/events",
routing::post(health::health).with_state(health_state.clone()), routing::post(routes::log_event).with_state(logger),
) )
.route( .route(
"/v1/health", "/v1/health",
routing::get(health::health).with_state(health_state), routing::post(routes::health).with_state(health_state.clone()),
)
.route(
"/v1/health",
routing::get(routes::health).with_state(health_state),
) )
}); });
@ -218,7 +220,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
Router::new() Router::new()
.route( .route(
"/v1/completions", "/v1/completions",
routing::post(completions::completions).with_state(completion_state), routing::post(routes::completions).with_state(completion_state),
) )
.layer(TimeoutLayer::new(Duration::from_secs( .layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout, config.server.completion_timeout,
@ -229,7 +231,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
routers.push({ routers.push({
Router::new().route( Router::new().route(
"/v1beta/chat/completions", "/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state), routing::post(routes::chat_completions).with_state(chat_state),
) )
}) })
} }
@ -237,7 +239,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
routers.push({ routers.push({
Router::new().route( Router::new().route(
"/v1beta/search", "/v1beta/search",
routing::get(search::search).with_state(code), routing::get(routes::search).with_state(code),
) )
}); });
@ -250,7 +252,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
} }
fn start_heartbeat(args: &ServeArgs) { fn start_heartbeat(args: &ServeArgs) {
let state = HealthState::new(args); let state = health::HealthState::new(&args.model, args.chat_model.as_deref(), &args.device);
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
usage::capture("ServeHealth", &state).await; usage::capture("ServeHealth", &state).await;

View File

@ -16,7 +16,7 @@ use tantivy::{
use tokio::{sync::Mutex, time::sleep}; use tokio::{sync::Mutex, time::sleep};
use tracing::{debug, log::info}; use tracing::{debug, log::info};
use crate::api::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}; use crate::api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse};
struct CodeSearchImpl { struct CodeSearchImpl {
reader: IndexReader, reader: IndexReader,

View File

@ -1,15 +1,19 @@
mod completions_prompt; mod completion_prompt;
use std::sync::Arc; use std::sync::Arc;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::{events, languages::get_language}; use tabby_common::languages::get_language;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use thiserror::Error; use thiserror::Error;
use tracing::debug; use tracing::debug;
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::api::CodeSearch; use crate::api::{
self,
code::CodeSearch,
event::{Event, EventLogger},
};
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum CompletionError { pub enum CompletionError {
@ -95,9 +99,9 @@ pub struct Segments {
suffix: Option<String>, suffix: Option<String>,
} }
impl From<Segments> for events::Segments { impl From<Segments> for api::event::Segments {
fn from(val: Segments) -> Self { fn from(val: Segments) -> Self {
events::Segments { Self {
prefix: val.prefix, prefix: val.prefix,
suffix: val.suffix, suffix: val.suffix,
} }
@ -157,18 +161,21 @@ pub struct DebugData {
pub struct CompletionService { pub struct CompletionService {
engine: Arc<dyn TextGeneration>, engine: Arc<dyn TextGeneration>,
prompt_builder: completions_prompt::PromptBuilder, logger: Arc<dyn EventLogger>,
prompt_builder: completion_prompt::PromptBuilder,
} }
impl CompletionService { impl CompletionService {
pub fn new( pub fn new(
engine: Arc<dyn TextGeneration>, engine: Arc<dyn TextGeneration>,
code: Arc<dyn CodeSearch>, code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
prompt_template: Option<String>, prompt_template: Option<String>,
) -> Self { ) -> Self {
Self { Self {
engine, engine,
prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)), prompt_builder: completion_prompt::PromptBuilder::new(prompt_template, Some(code)),
logger,
} }
} }
@ -226,18 +233,17 @@ impl CompletionService {
let text = self.engine.generate(&prompt, options).await; let text = self.engine.generate(&prompt, options).await;
let segments = segments.map(|s| s.into()); let segments = segments.map(|s| s.into());
events::Event::Completion { self.logger.log(&Event::Completion {
completion_id: &completion_id, completion_id: &completion_id,
language: &language, language: &language,
prompt: &prompt, prompt: &prompt,
segments: &segments, segments: &segments,
choices: vec![events::Choice { choices: vec![api::event::Choice {
index: 0, index: 0,
text: &text, text: &text,
}], }],
user: request.user.as_deref(), user: request.user.as_deref(),
} });
.log();
let debug_data = request let debug_data = request
.debug_options .debug_options

View File

@ -8,7 +8,7 @@ use textdistance::Algorithm;
use tracing::warn; use tracing::warn;
use super::{Segments, Snippet}; use super::{Segments, Snippet};
use crate::api::{CodeSearch, CodeSearchError}; use crate::api::code::{CodeSearch, CodeSearchError};
static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;

View File

@ -7,17 +7,20 @@ use std::{
use chrono::Utc; use chrono::Utc;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde::Serialize; use serde::Serialize;
use tabby_common::path;
use tokio::{ use tokio::{
sync::mpsc::{unbounded_channel, UnboundedSender}, sync::mpsc::{unbounded_channel, UnboundedSender},
time::{self}, time::{self},
}; };
use crate::api::event::{Event, EventLogger};
lazy_static! { lazy_static! {
static ref WRITER: UnboundedSender<String> = { static ref WRITER: UnboundedSender<String> = {
let (tx, mut rx) = unbounded_channel::<String>(); let (tx, mut rx) = unbounded_channel::<String>();
tokio::spawn(async move { tokio::spawn(async move {
let events_dir = crate::path::events_dir(); let events_dir = path::events_dir();
std::fs::create_dir_all(events_dir.as_path()).ok(); std::fs::create_dir_all(events_dir.as_path()).ok();
let now = Utc::now(); let now = Utc::now();
@ -53,45 +56,7 @@ lazy_static! {
}; };
} }
#[derive(Serialize)] struct EventService;
pub struct Choice<'a> {
pub index: u32,
pub text: &'a str,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SelectKind {
Line,
}
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Event<'a> {
View {
completion_id: &'a str,
choice_index: u32,
},
Select {
completion_id: &'a str,
choice_index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
kind: Option<SelectKind>,
},
Completion {
completion_id: &'a str,
language: &'a str,
prompt: &'a str,
segments: &'a Option<Segments>,
choices: Vec<Choice<'a>>,
user: Option<&'a str>,
},
}
#[derive(Serialize)]
pub struct Segments {
pub prefix: String,
pub suffix: Option<String>,
}
#[derive(Serialize)] #[derive(Serialize)]
struct Log<'a> { struct Log<'a> {
@ -99,11 +64,11 @@ struct Log<'a> {
event: &'a Event<'a>, event: &'a Event<'a>,
} }
impl Event<'_> { impl EventLogger for EventService {
pub fn log(&self) { fn log(&self, e: &Event) {
let content = serdeconv::to_json_string(&Log { let content = serdeconv::to_json_string(&Log {
ts: timestamp(), ts: timestamp(),
event: self, event: e,
}) })
.unwrap(); .unwrap();
@ -119,3 +84,7 @@ fn timestamp() -> u128 {
.expect("Time went backwards") .expect("Time went backwards")
.as_millis() .as_millis()
} }
pub fn create_event_logger() -> impl EventLogger {
EventService
}

View File

@ -1,12 +1,13 @@
use std::{env::consts::ARCH, sync::Arc}; use std::env::consts::ARCH;
use anyhow::Result; use anyhow::Result;
use axum::{extract::State, Json};
use nvml_wrapper::Nvml; use nvml_wrapper::Nvml;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sysinfo::{CpuExt, System, SystemExt}; use sysinfo::{CpuExt, System, SystemExt};
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::serve::Device;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct HealthState { pub struct HealthState {
model: String, model: String,
@ -21,7 +22,7 @@ pub struct HealthState {
} }
impl HealthState { impl HealthState {
pub fn new(args: &super::ServeArgs) -> Self { pub fn new(model: &str, chat_model: Option<&str>, device: &Device) -> Self {
let (cpu_info, cpu_count) = read_cpu_info(); let (cpu_info, cpu_count) = read_cpu_info();
let cuda_devices = match read_cuda_devices() { let cuda_devices = match read_cuda_devices() {
@ -30,9 +31,9 @@ impl HealthState {
}; };
Self { Self {
model: args.model.clone(), model: model.to_owned(),
chat_model: args.chat_model.clone(), chat_model: chat_model.map(|x| x.to_owned()),
device: args.device.to_string(), device: device.to_string(),
arch: ARCH.to_string(), arch: ARCH.to_string(),
cpu_info, cpu_info,
cpu_count, cpu_count,
@ -90,15 +91,3 @@ impl Version {
} }
} }
} }
#[utoipa::path(
get,
path = "/v1/health",
tag = "v1",
responses(
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
)
)]
pub async fn health(State(state): State<Arc<HealthState>>) -> Json<HealthState> {
Json(state.as_ref().clone())
}

View File

@ -1,4 +1,6 @@
pub mod chat; pub mod chat;
pub mod code; pub mod code;
pub mod completions; pub mod completion;
pub mod event;
pub mod health;
pub mod model; pub mod model;