refactor: extract routes/ to share routes between commands
parent
bad87a99a2
commit
98c50f8050
|
|
@ -1,7 +1,7 @@
|
|||
mod api;
|
||||
mod download;
|
||||
mod routes;
|
||||
mod serve;
|
||||
|
||||
mod services;
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
|
|||
)
|
||||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn completions(
|
||||
pub async fn chat_completions(
|
||||
State(state): State<Arc<ChatService>>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
use std::{sync::Arc};
|
||||
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
|
||||
|
||||
use sysinfo::{SystemExt};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::services::health;
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/v1/health",
|
||||
tag = "v1",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
|
||||
)
|
||||
)]
|
||||
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {
|
||||
Json(state.as_ref().clone())
|
||||
}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
mod chat;
|
||||
mod completions;
|
||||
mod events;
|
||||
mod health;
|
||||
mod search;
|
||||
|
||||
pub use chat::*;
|
||||
pub use completions::*;
|
||||
pub use events::*;
|
||||
pub use health::*;
|
||||
pub use search::*;
|
||||
|
|
@ -1,9 +1,3 @@
|
|||
mod chat;
|
||||
mod completions;
|
||||
mod events;
|
||||
mod health;
|
||||
mod search;
|
||||
|
||||
use std::{
|
||||
fs,
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
|
|
@ -23,15 +17,9 @@ use tracing::info;
|
|||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
use self::health::HealthState;
|
||||
use crate::{
|
||||
api::{Hit, HitDocument, SearchResponse},
|
||||
fatal,
|
||||
services::{
|
||||
chat::ChatService,
|
||||
completions::CompletionService,
|
||||
model::{load_text_generation, PromptInfo},
|
||||
},
|
||||
api, fatal, routes,
|
||||
services::{chat, completions, health, model},
|
||||
};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
|
|
@ -51,24 +39,24 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
servers(
|
||||
(url = "/", description = "Server"),
|
||||
),
|
||||
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
|
||||
paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search),
|
||||
components(schemas(
|
||||
events::LogEventRequest,
|
||||
crate::services::completions::CompletionRequest,
|
||||
crate::services::completions::CompletionResponse,
|
||||
crate::services::completions::Segments,
|
||||
crate::services::completions::Choice,
|
||||
crate::services::completions::Snippet,
|
||||
crate::services::completions::DebugOptions,
|
||||
crate::services::completions::DebugData,
|
||||
crate::services::chat::ChatCompletionRequest,
|
||||
crate::services::chat::Message,
|
||||
crate::services::chat::ChatCompletionChunk,
|
||||
routes::LogEventRequest,
|
||||
completions::CompletionRequest,
|
||||
completions::CompletionResponse,
|
||||
completions::Segments,
|
||||
completions::Choice,
|
||||
completions::Snippet,
|
||||
completions::DebugOptions,
|
||||
completions::DebugData,
|
||||
chat::ChatCompletionRequest,
|
||||
chat::Message,
|
||||
chat::ChatCompletionChunk,
|
||||
health::HealthState,
|
||||
health::Version,
|
||||
SearchResponse,
|
||||
Hit,
|
||||
HitDocument
|
||||
api::SearchResponse,
|
||||
api::Hit,
|
||||
api::HitDocument
|
||||
))
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
|
@ -178,21 +166,22 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
let completion_state = {
|
||||
let (
|
||||
engine,
|
||||
PromptInfo {
|
||||
model::PromptInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = load_text_generation(&args.model, &args.device, args.parallelism).await;
|
||||
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
||||
) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
|
||||
let state =
|
||||
completions::CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
||||
Arc::new(state)
|
||||
};
|
||||
|
||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||
let (engine, PromptInfo { chat_template, .. }) =
|
||||
load_text_generation(chat_model, &args.device, args.parallelism).await;
|
||||
let (engine, model::PromptInfo { chat_template, .. }) =
|
||||
model::load_text_generation(chat_model, &args.device, args.parallelism).await;
|
||||
let Some(chat_template) = chat_template else {
|
||||
panic!("Chat model requires specifying prompt template");
|
||||
};
|
||||
let state = ChatService::new(engine, chat_template);
|
||||
let state = chat::ChatService::new(engine, chat_template);
|
||||
Some(Arc::new(state))
|
||||
} else {
|
||||
None
|
||||
|
|
@ -200,17 +189,21 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
|
||||
let mut routers = vec![];
|
||||
|
||||
let health_state = Arc::new(health::HealthState::new(args));
|
||||
let health_state = Arc::new(health::HealthState::new(
|
||||
&args.model,
|
||||
args.chat_model.as_deref(),
|
||||
&args.device,
|
||||
));
|
||||
routers.push({
|
||||
Router::new()
|
||||
.route("/v1/events", routing::post(events::log_event))
|
||||
.route("/v1/events", routing::post(routes::log_event))
|
||||
.route(
|
||||
"/v1/health",
|
||||
routing::post(health::health).with_state(health_state.clone()),
|
||||
routing::post(routes::health).with_state(health_state.clone()),
|
||||
)
|
||||
.route(
|
||||
"/v1/health",
|
||||
routing::get(health::health).with_state(health_state),
|
||||
routing::get(routes::health).with_state(health_state),
|
||||
)
|
||||
});
|
||||
|
||||
|
|
@ -218,7 +211,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
Router::new()
|
||||
.route(
|
||||
"/v1/completions",
|
||||
routing::post(completions::completions).with_state(completion_state),
|
||||
routing::post(routes::completions).with_state(completion_state),
|
||||
)
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(
|
||||
config.server.completion_timeout,
|
||||
|
|
@ -229,7 +222,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
routers.push({
|
||||
Router::new().route(
|
||||
"/v1beta/chat/completions",
|
||||
routing::post(chat::completions).with_state(chat_state),
|
||||
routing::post(routes::chat_completions).with_state(chat_state),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
|
@ -237,7 +230,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
routers.push({
|
||||
Router::new().route(
|
||||
"/v1beta/search",
|
||||
routing::get(search::search).with_state(code),
|
||||
routing::get(routes::search).with_state(code),
|
||||
)
|
||||
});
|
||||
|
||||
|
|
@ -250,7 +243,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
}
|
||||
|
||||
fn start_heartbeat(args: &ServeArgs) {
|
||||
let state = HealthState::new(args);
|
||||
let state = health::HealthState::new(&args.model, args.chat_model.as_deref(), &args.device);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
usage::capture("ServeHealth", &state).await;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
use std::{env::consts::ARCH, sync::Arc};
|
||||
use std::{env::consts::ARCH};
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{extract::State, Json};
|
||||
|
||||
use nvml_wrapper::Nvml;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sysinfo::{CpuExt, System, SystemExt};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::serve::Device;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct HealthState {
|
||||
model: String,
|
||||
|
|
@ -21,7 +23,7 @@ pub struct HealthState {
|
|||
}
|
||||
|
||||
impl HealthState {
|
||||
pub fn new(args: &super::ServeArgs) -> Self {
|
||||
pub fn new(model: &str, chat_model: Option<&str>, device: &Device) -> Self {
|
||||
let (cpu_info, cpu_count) = read_cpu_info();
|
||||
|
||||
let cuda_devices = match read_cuda_devices() {
|
||||
|
|
@ -30,9 +32,9 @@ impl HealthState {
|
|||
};
|
||||
|
||||
Self {
|
||||
model: args.model.clone(),
|
||||
chat_model: args.chat_model.clone(),
|
||||
device: args.device.to_string(),
|
||||
model: model.to_owned(),
|
||||
chat_model: chat_model.map(|x| x.to_owned()),
|
||||
device: device.to_string(),
|
||||
arch: ARCH.to_string(),
|
||||
cpu_info,
|
||||
cpu_count,
|
||||
|
|
@ -90,15 +92,3 @@ impl Version {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/v1/health",
|
||||
tag = "v1",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
|
||||
)
|
||||
)]
|
||||
pub async fn health(State(state): State<Arc<HealthState>>) -> Json<HealthState> {
|
||||
Json(state.as_ref().clone())
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
pub mod chat;
|
||||
pub mod code;
|
||||
pub mod completions;
|
||||
pub mod health;
|
||||
pub mod model;
|
||||
|
|
|
|||
Loading…
Reference in New Issue