use std::{
    net::{Ipv4Addr, SocketAddr},
    sync::Arc,
    time::Duration,
};

use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args;
use tabby_common::{config::Config, usage};
use tabby_webserver::attach_webserver;
use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::info;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

use crate::{
    api::{self},
    fatal, routes,
    services::{
        chat::{self, create_chat_service},
        completion::{self, create_completion_service},
        event::create_logger,
        health,
        model::download_model_if_needed,
    },
    Device,
};
#[derive(OpenApi)]
|
||
#[openapi(
|
||
info(title="Tabby Server",
|
||
description = "
|
||
[](https://github.com/TabbyML/tabby)
|
||
[](https://join.slack.com/t/tabbycommunity/shared_invite/zt-1xeiddizp-bciR2RtFTaJ37RBxr8VxpA)
|
||
|
||
Install following IDE / Editor extensions to get started with [Tabby](https://github.com/TabbyML/tabby).
|
||
* [VSCode Extension](https://github.com/TabbyML/tabby/tree/main/clients/vscode) – Install from the [marketplace](https://marketplace.visualstudio.com/items?itemName=TabbyML.vscode-tabby), or [open-vsx.org](https://open-vsx.org/extension/TabbyML/vscode-tabby)
|
||
* [VIM Extension](https://github.com/TabbyML/tabby/tree/main/clients/vim)
|
||
* [IntelliJ Platform Plugin](https://github.com/TabbyML/tabby/tree/main/clients/intellij) – Install from the [marketplace](https://plugins.jetbrains.com/plugin/22379-tabby)
|
||
",
|
||
license(name = "Apache 2.0", url="https://github.com/TabbyML/tabby/blob/main/LICENSE")
|
||
),
|
||
servers(
|
||
(url = "/", description = "Server"),
|
||
),
|
||
paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search),
|
||
components(schemas(
|
||
api::event::LogEventRequest,
|
||
completion::CompletionRequest,
|
||
completion::CompletionResponse,
|
||
completion::Segments,
|
||
completion::Choice,
|
||
completion::Snippet,
|
||
completion::DebugOptions,
|
||
completion::DebugData,
|
||
chat::ChatCompletionRequest,
|
||
chat::Message,
|
||
chat::ChatCompletionChunk,
|
||
health::HealthState,
|
||
health::Version,
|
||
api::code::SearchResponse,
|
||
api::code::Hit,
|
||
api::code::HitDocument
|
||
))
|
||
)]
|
||
struct ApiDoc;
|
||
|
||
// CLI arguments for the `serve` command, parsed via clap's derive API.
// NOTE(review): the `///` comments on fields below double as `--help` text,
// so new documentation is added only as plain `//` comments to leave the CLI
// help output unchanged.
#[derive(Args)]
pub struct ServeArgs {
    /// Model id for `/completions` API endpoint.
    #[clap(long)]
    model: String,

    /// Model id for `/chat/completions` API endpoints.
    // Optional: when absent, the chat service and its route are not created.
    #[clap(long)]
    chat_model: Option<String>,

    // TCP port to listen on; `main` binds it on all IPv4 interfaces.
    #[clap(long, default_value_t = 8080)]
    port: u16,

    /// Device to run model inference.
    #[clap(long, default_value_t=Device::Cpu)]
    device: Device,

    /// Parallelism for model serving - increasing this number will have a significant impact on the
    /// memory requirement e.g., GPU vRAM.
    #[clap(long, default_value_t = 1)]
    parallelism: u8,
}
pub async fn main(config: &Config, args: &ServeArgs) {
|
||
#[cfg(feature = "experimental-http")]
|
||
if args.device == Device::ExperimentalHttp {
|
||
tracing::warn!("HTTP device is unstable and does not comply with semver expectations.");
|
||
} else {
|
||
load_model(args).await;
|
||
}
|
||
#[cfg(not(feature = "experimental-http"))]
|
||
load_model(args).await;
|
||
|
||
info!("Starting server, this might takes a few minutes...");
|
||
|
||
let app = Router::new()
|
||
.merge(api_router(args, config).await)
|
||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()));
|
||
|
||
let app = attach_webserver(app).await;
|
||
|
||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||
info!("Listening at {}", address);
|
||
|
||
start_heartbeat(args);
|
||
Server::bind(&address)
|
||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||
.await
|
||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||
}
|
||
|
||
async fn load_model(args: &ServeArgs) {
|
||
download_model_if_needed(&args.model).await;
|
||
if let Some(chat_model) = &args.chat_model {
|
||
download_model_if_needed(chat_model).await
|
||
}
|
||
}
|
||
|
||
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||
let logger = Arc::new(create_logger());
|
||
let code = Arc::new(crate::services::code::create_code_search());
|
||
let completion = Arc::new(
|
||
create_completion_service(
|
||
code.clone(),
|
||
logger.clone(),
|
||
&args.model,
|
||
&args.device,
|
||
args.parallelism,
|
||
)
|
||
.await,
|
||
);
|
||
|
||
let chat_state = if let Some(_chat_model) = &args.chat_model {
|
||
Some(Arc::new(
|
||
create_chat_service(&args.model, &args.device, args.parallelism).await,
|
||
))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let mut routers = vec![];
|
||
|
||
let health_state = Arc::new(health::HealthState::new(
|
||
&args.model,
|
||
args.chat_model.as_deref(),
|
||
&args.device,
|
||
));
|
||
routers.push({
|
||
Router::new()
|
||
.route(
|
||
"/v1/events",
|
||
routing::post(routes::log_event).with_state(logger),
|
||
)
|
||
.route(
|
||
"/v1/health",
|
||
routing::post(routes::health).with_state(health_state.clone()),
|
||
)
|
||
.route(
|
||
"/v1/health",
|
||
routing::get(routes::health).with_state(health_state),
|
||
)
|
||
});
|
||
|
||
routers.push({
|
||
Router::new()
|
||
.route(
|
||
"/v1/completions",
|
||
routing::post(routes::completions).with_state(completion),
|
||
)
|
||
.layer(TimeoutLayer::new(Duration::from_secs(
|
||
config.server.completion_timeout,
|
||
)))
|
||
});
|
||
|
||
if let Some(chat_state) = chat_state {
|
||
routers.push({
|
||
Router::new().route(
|
||
"/v1beta/chat/completions",
|
||
routing::post(routes::chat_completions).with_state(chat_state),
|
||
)
|
||
})
|
||
}
|
||
|
||
routers.push({
|
||
Router::new().route(
|
||
"/v1beta/search",
|
||
routing::get(routes::search).with_state(code),
|
||
)
|
||
});
|
||
|
||
let mut root = Router::new();
|
||
for router in routers {
|
||
root = root.merge(router);
|
||
}
|
||
root.layer(CorsLayer::permissive())
|
||
.layer(opentelemetry_tracing_layer())
|
||
}
|
||
|
||
fn start_heartbeat(args: &ServeArgs) {
|
||
let state = health::HealthState::new(&args.model, args.chat_model.as_deref(), &args.device);
|
||
tokio::spawn(async move {
|
||
loop {
|
||
usage::capture("ServeHealth", &state).await;
|
||
sleep(Duration::from_secs(3000)).await;
|
||
}
|
||
});
|
||
}
|