refactor: extract routes/ to share routes between commands

extract-routes
Meng Zhang 2023-11-12 20:59:40 -08:00
parent bad87a99a2
commit 98c50f8050
10 changed files with 80 additions and 63 deletions

View File

@ -1,7 +1,7 @@
mod api; mod api;
mod download; mod download;
mod routes;
mod serve; mod serve;
mod services; mod services;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};

View File

@ -23,7 +23,7 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
) )
)] )]
#[instrument(skip(state, request))] #[instrument(skip(state, request))]
pub async fn completions( pub async fn chat_completions(
State(state): State<Arc<ChatService>>, State(state): State<Arc<ChatService>>,
Json(request): Json<ChatCompletionRequest>, Json(request): Json<ChatCompletionRequest>,
) -> Response { ) -> Response {

View File

@ -0,0 +1,22 @@
use std::{sync::Arc};
use axum::{extract::State, Json};
use sysinfo::{SystemExt};
use utoipa::ToSchema;
use crate::services::health;
// Handler for `/v1/health`: replies with a JSON copy of the shared health state.
// Plain `//` comments are used here on purpose, since `///` doc comments on a
// `#[utoipa::path]` handler are folded into the generated operation description.
#[utoipa::path(
get,
path = "/v1/health",
tag = "v1",
responses(
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
)
)]
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {
    // Clone the value behind the `Arc`; the shared state itself is left untouched.
    let snapshot = health::HealthState::clone(&state);
    Json(snapshot)
}

View File

@ -0,0 +1,11 @@
// Private submodules, one per group of handlers (presumably the root of the
// `routes` module introduced by this commit — confirm against the file path).
mod chat;
mod completions;
mod events;
mod health;
mod search;
// Glob re-exports flatten every submodule's public items into this module's
// namespace. A public name exported by two submodules would be ambiguous when
// referenced through this module, so handler names must stay unique across
// submodules (e.g. the chat handler is named `chat_completions`, not `completions`).
pub use chat::*;
pub use completions::*;
pub use events::*;
pub use health::*;
pub use search::*;

View File

@ -1,9 +1,3 @@
mod chat;
mod completions;
mod events;
mod health;
mod search;
use std::{ use std::{
fs, fs,
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
@ -23,15 +17,9 @@ use tracing::info;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
use self::health::HealthState;
use crate::{ use crate::{
api::{Hit, HitDocument, SearchResponse}, api, fatal, routes,
fatal, services::{chat, completions, health, model},
services::{
chat::ChatService,
completions::CompletionService,
model::{load_text_generation, PromptInfo},
},
}; };
#[derive(OpenApi)] #[derive(OpenApi)]
@ -51,24 +39,24 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers( servers(
(url = "/", description = "Server"), (url = "/", description = "Server"),
), ),
paths(events::log_event, completions::completions, chat::completions, health::health, search::search), paths(routes::log_event, routes::completions, routes::chat_completions, routes::health, routes::search),
components(schemas( components(schemas(
events::LogEventRequest, routes::LogEventRequest,
crate::services::completions::CompletionRequest, completions::CompletionRequest,
crate::services::completions::CompletionResponse, completions::CompletionResponse,
crate::services::completions::Segments, completions::Segments,
crate::services::completions::Choice, completions::Choice,
crate::services::completions::Snippet, completions::Snippet,
crate::services::completions::DebugOptions, completions::DebugOptions,
crate::services::completions::DebugData, completions::DebugData,
crate::services::chat::ChatCompletionRequest, chat::ChatCompletionRequest,
crate::services::chat::Message, chat::Message,
crate::services::chat::ChatCompletionChunk, chat::ChatCompletionChunk,
health::HealthState, health::HealthState,
health::Version, health::Version,
SearchResponse, api::SearchResponse,
Hit, api::Hit,
HitDocument api::HitDocument
)) ))
)] )]
struct ApiDoc; struct ApiDoc;
@ -178,21 +166,22 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let completion_state = { let completion_state = {
let ( let (
engine, engine,
PromptInfo { model::PromptInfo {
prompt_template, .. prompt_template, ..
}, },
) = load_text_generation(&args.model, &args.device, args.parallelism).await; ) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template); let state =
completions::CompletionService::new(engine.clone(), code.clone(), prompt_template);
Arc::new(state) Arc::new(state)
}; };
let chat_state = if let Some(chat_model) = &args.chat_model { let chat_state = if let Some(chat_model) = &args.chat_model {
let (engine, PromptInfo { chat_template, .. }) = let (engine, model::PromptInfo { chat_template, .. }) =
load_text_generation(chat_model, &args.device, args.parallelism).await; model::load_text_generation(chat_model, &args.device, args.parallelism).await;
let Some(chat_template) = chat_template else { let Some(chat_template) = chat_template else {
panic!("Chat model requires specifying prompt template"); panic!("Chat model requires specifying prompt template");
}; };
let state = ChatService::new(engine, chat_template); let state = chat::ChatService::new(engine, chat_template);
Some(Arc::new(state)) Some(Arc::new(state))
} else { } else {
None None
@ -200,17 +189,21 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let mut routers = vec![]; let mut routers = vec![];
let health_state = Arc::new(health::HealthState::new(args)); let health_state = Arc::new(health::HealthState::new(
&args.model,
args.chat_model.as_deref(),
&args.device,
));
routers.push({ routers.push({
Router::new() Router::new()
.route("/v1/events", routing::post(events::log_event)) .route("/v1/events", routing::post(routes::log_event))
.route( .route(
"/v1/health", "/v1/health",
routing::post(health::health).with_state(health_state.clone()), routing::post(routes::health).with_state(health_state.clone()),
) )
.route( .route(
"/v1/health", "/v1/health",
routing::get(health::health).with_state(health_state), routing::get(routes::health).with_state(health_state),
) )
}); });
@ -218,7 +211,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
Router::new() Router::new()
.route( .route(
"/v1/completions", "/v1/completions",
routing::post(completions::completions).with_state(completion_state), routing::post(routes::completions).with_state(completion_state),
) )
.layer(TimeoutLayer::new(Duration::from_secs( .layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout, config.server.completion_timeout,
@ -229,7 +222,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
routers.push({ routers.push({
Router::new().route( Router::new().route(
"/v1beta/chat/completions", "/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state), routing::post(routes::chat_completions).with_state(chat_state),
) )
}) })
} }
@ -237,7 +230,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
routers.push({ routers.push({
Router::new().route( Router::new().route(
"/v1beta/search", "/v1beta/search",
routing::get(search::search).with_state(code), routing::get(routes::search).with_state(code),
) )
}); });
@ -250,7 +243,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
} }
fn start_heartbeat(args: &ServeArgs) { fn start_heartbeat(args: &ServeArgs) {
let state = HealthState::new(args); let state = health::HealthState::new(&args.model, args.chat_model.as_deref(), &args.device);
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
usage::capture("ServeHealth", &state).await; usage::capture("ServeHealth", &state).await;

View File

@ -1,12 +1,14 @@
use std::{env::consts::ARCH, sync::Arc}; use std::{env::consts::ARCH};
use anyhow::Result; use anyhow::Result;
use axum::{extract::State, Json};
use nvml_wrapper::Nvml; use nvml_wrapper::Nvml;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sysinfo::{CpuExt, System, SystemExt}; use sysinfo::{CpuExt, System, SystemExt};
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::serve::Device;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct HealthState { pub struct HealthState {
model: String, model: String,
@ -21,7 +23,7 @@ pub struct HealthState {
} }
impl HealthState { impl HealthState {
pub fn new(args: &super::ServeArgs) -> Self { pub fn new(model: &str, chat_model: Option<&str>, device: &Device) -> Self {
let (cpu_info, cpu_count) = read_cpu_info(); let (cpu_info, cpu_count) = read_cpu_info();
let cuda_devices = match read_cuda_devices() { let cuda_devices = match read_cuda_devices() {
@ -30,9 +32,9 @@ impl HealthState {
}; };
Self { Self {
model: args.model.clone(), model: model.to_owned(),
chat_model: args.chat_model.clone(), chat_model: chat_model.map(|x| x.to_owned()),
device: args.device.to_string(), device: device.to_string(),
arch: ARCH.to_string(), arch: ARCH.to_string(),
cpu_info, cpu_info,
cpu_count, cpu_count,
@ -90,15 +92,3 @@ impl Version {
} }
} }
} }
#[utoipa::path(
get,
path = "/v1/health",
tag = "v1",
responses(
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
)
)]
pub async fn health(State(state): State<Arc<HealthState>>) -> Json<HealthState> {
Json(state.as_ref().clone())
}

View File

@ -1,4 +1,5 @@
pub mod chat; pub mod chat;
pub mod code; pub mod code;
pub mod completions; pub mod completions;
pub mod health;
pub mod model; pub mod model;