refactor: extract routes/ to share routes between commands (#774)
* refactor: extract routes/ to share routes between commands * refactor: extract events api * extract EventLogger service * lift api into sub packages * services completions -> completion * remove useless code * fix testrelease-fix-intellij-update-support-version-range
parent
63c7da4f96
commit
e0017cadec
|
|
@ -4282,6 +4282,7 @@ dependencies = [
|
||||||
"axum",
|
"axum",
|
||||||
"axum-streams",
|
"axum-streams",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
|
"chrono",
|
||||||
"clap 4.4.7",
|
"clap 4.4.7",
|
||||||
"futures",
|
"futures",
|
||||||
"http-api-bindings",
|
"http-api-bindings",
|
||||||
|
|
@ -4324,7 +4325,6 @@ name = "tabby-common"
|
||||||
version = "0.6.0-dev"
|
version = "0.6.0-dev"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"chrono",
|
|
||||||
"filenamify",
|
"filenamify",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
|
@ -4332,7 +4332,6 @@ dependencies = [
|
||||||
"serde-jsonlines",
|
"serde-jsonlines",
|
||||||
"serdeconv",
|
"serdeconv",
|
||||||
"tantivy",
|
"tantivy",
|
||||||
"tokio",
|
|
||||||
"uuid 1.4.1",
|
"uuid 1.4.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,4 +39,4 @@ thiserror = "1.0.49"
|
||||||
utoipa = "3.3"
|
utoipa = "3.3"
|
||||||
axum = "0.6"
|
axum = "0.6"
|
||||||
hyper = "0.14"
|
hyper = "0.14"
|
||||||
juniper = "0.15"
|
juniper = "0.15"
|
||||||
|
|
|
||||||
|
|
@ -16,55 +16,6 @@ use juniper::{
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
/// Extractor for [`axum`] to extract a [`JuniperRequest`].
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use std::sync::Arc;
|
|
||||||
///
|
|
||||||
/// use axum::{routing::post, Extension, Json, Router};
|
|
||||||
/// use juniper::{
|
|
||||||
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
|
|
||||||
/// };
|
|
||||||
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
|
|
||||||
///
|
|
||||||
/// #[derive(Clone, Copy, Debug)]
|
|
||||||
/// pub struct Context;
|
|
||||||
///
|
|
||||||
/// impl juniper::Context for Context {}
|
|
||||||
///
|
|
||||||
/// #[derive(Clone, Copy, Debug)]
|
|
||||||
/// pub struct Query;
|
|
||||||
///
|
|
||||||
/// #[graphql_object(context = Context)]
|
|
||||||
/// impl Query {
|
|
||||||
/// fn add(a: i32, b: i32) -> i32 {
|
|
||||||
/// a + b
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
|
|
||||||
///
|
|
||||||
/// let schema = Schema::new(
|
|
||||||
/// Query,
|
|
||||||
/// EmptyMutation::<Context>::new(),
|
|
||||||
/// EmptySubscription::<Context>::new()
|
|
||||||
/// );
|
|
||||||
///
|
|
||||||
/// let app: Router = Router::new()
|
|
||||||
/// .route("/graphql", post(graphql))
|
|
||||||
/// .layer(Extension(Arc::new(schema)))
|
|
||||||
/// .layer(Extension(Context));
|
|
||||||
///
|
|
||||||
/// # #[axum::debug_handler]
|
|
||||||
/// async fn graphql(
|
|
||||||
/// Extension(schema): Extension<Arc<Schema>>,
|
|
||||||
/// Extension(context): Extension<Context>,
|
|
||||||
/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
|
|
||||||
/// ) -> JuniperResponse {
|
|
||||||
/// JuniperResponse(req.execute(&*schema, &context).await)
|
|
||||||
/// }
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
|
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
|
||||||
where
|
where
|
||||||
|
|
|
||||||
|
|
@ -15,55 +15,6 @@ pub trait FromStateAndClientAddr<C, S> {
|
||||||
fn build(state: S, client_addr: SocketAddr) -> C;
|
fn build(state: S, client_addr: SocketAddr) -> C;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing
|
|
||||||
/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one.
|
|
||||||
///
|
|
||||||
/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need
|
|
||||||
/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s
|
|
||||||
/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a
|
|
||||||
/// > [`JuniperRequest`] (see its documentation for examples).
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use std::sync::Arc;
|
|
||||||
///
|
|
||||||
/// use axum::{routing::post, Extension, Json, Router};
|
|
||||||
/// use juniper::{
|
|
||||||
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
|
|
||||||
/// };
|
|
||||||
/// use juniper_axum::graphql;
|
|
||||||
///
|
|
||||||
/// #[derive(Clone, Copy, Debug, Default)]
|
|
||||||
/// pub struct Context;
|
|
||||||
///
|
|
||||||
/// impl juniper::Context for Context {}
|
|
||||||
///
|
|
||||||
/// #[derive(Clone, Copy, Debug)]
|
|
||||||
/// pub struct Query;
|
|
||||||
///
|
|
||||||
/// #[graphql_object(context = Context)]
|
|
||||||
/// impl Query {
|
|
||||||
/// fn add(a: i32, b: i32) -> i32 {
|
|
||||||
/// a + b
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
|
|
||||||
///
|
|
||||||
/// let schema = Schema::new(
|
|
||||||
/// Query,
|
|
||||||
/// EmptyMutation::<Context>::new(),
|
|
||||||
/// EmptySubscription::<Context>::new()
|
|
||||||
/// );
|
|
||||||
///
|
|
||||||
/// let app: Router = Router::new()
|
|
||||||
/// .route("/graphql", post(graphql::<Arc<Schema>>))
|
|
||||||
/// .layer(Extension(Arc::new(schema)));
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// [`extract`]: axum::extract
|
|
||||||
/// [`Handler`]: axum::handler::Handler
|
|
||||||
#[cfg_attr(text, axum::debug_handler)]
|
#[cfg_attr(text, axum::debug_handler)]
|
||||||
pub async fn graphql<S, C>(
|
pub async fn graphql<S, C>(
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,12 @@ version = "0.6.0-dev"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
chrono = "0.4.26"
|
|
||||||
filenamify = "0.1.0"
|
filenamify = "0.1.0"
|
||||||
lazy_static = { workspace = true }
|
lazy_static = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serdeconv = { workspace = true }
|
serdeconv = { workspace = true }
|
||||||
serde-jsonlines = { workspace = true }
|
serde-jsonlines = { workspace = true }
|
||||||
reqwest = { workspace = true, features = [ "json" ] }
|
reqwest = { workspace = true, features = [ "json" ] }
|
||||||
tokio = { workspace = true, features = ["rt", "macros"] }
|
|
||||||
uuid = { version = "1.4.1", features = ["v4"] }
|
uuid = { version = "1.4.1", features = ["v4"] }
|
||||||
tantivy.workspace = true
|
tantivy.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod events;
|
|
||||||
pub mod index;
|
pub mod index;
|
||||||
pub mod languages;
|
pub mod languages;
|
||||||
pub mod path;
|
pub mod path;
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ futures.workspace = true
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
tabby-webserver = { path = "../../ee/tabby-webserver" }
|
tabby-webserver = { path = "../../ee/tabby-webserver" }
|
||||||
thiserror.workspace = true
|
thiserror.workspace = true
|
||||||
|
chrono = "0.4.31"
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
version = "1.3.3"
|
version = "1.3.3"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct LogEventRequest {
|
||||||
|
/// Event type, should be `view` or `select`.
|
||||||
|
#[schema(example = "view")]
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub event_type: String,
|
||||||
|
|
||||||
|
pub completion_id: String,
|
||||||
|
|
||||||
|
pub choice_index: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct Choice<'a> {
|
||||||
|
pub index: u32,
|
||||||
|
pub text: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum SelectKind {
|
||||||
|
Line,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum Event<'a> {
|
||||||
|
View {
|
||||||
|
completion_id: &'a str,
|
||||||
|
choice_index: u32,
|
||||||
|
},
|
||||||
|
Select {
|
||||||
|
completion_id: &'a str,
|
||||||
|
choice_index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
kind: Option<SelectKind>,
|
||||||
|
},
|
||||||
|
Completion {
|
||||||
|
completion_id: &'a str,
|
||||||
|
language: &'a str,
|
||||||
|
prompt: &'a str,
|
||||||
|
segments: &'a Option<Segments>,
|
||||||
|
choices: Vec<Choice<'a>>,
|
||||||
|
user: Option<&'a str>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct Segments {
|
||||||
|
pub prefix: String,
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait EventLogger: Send + Sync {
|
||||||
|
fn log(&self, e: &Event);
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,2 @@
|
||||||
mod code;
|
pub mod code;
|
||||||
|
pub mod event;
|
||||||
pub use code::*;
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
mod api;
|
mod api;
|
||||||
mod download;
|
mod download;
|
||||||
|
mod routes;
|
||||||
mod serve;
|
mod serve;
|
||||||
|
|
||||||
mod services;
|
mod services;
|
||||||
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, request))]
|
#[instrument(skip(state, request))]
|
||||||
pub async fn completions(
|
pub async fn chat_completions(
|
||||||
State(state): State<Arc<ChatService>>,
|
State(state): State<Arc<ChatService>>,
|
||||||
Json(request): Json<ChatCompletionRequest>,
|
Json(request): Json<ChatCompletionRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
|
|
@ -4,7 +4,7 @@ use axum::{extract::State, Json};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use tracing::{instrument, warn};
|
use tracing::{instrument, warn};
|
||||||
|
|
||||||
use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService};
|
use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService};
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
|
|
@ -1,22 +1,12 @@
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use axum::{extract::Query, Json};
|
use axum::{
|
||||||
|
extract::{Query, State},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tabby_common::events::{self, SelectKind};
|
|
||||||
use utoipa::ToSchema;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
use crate::api::event::{Event, EventLogger, LogEventRequest, SelectKind};
|
||||||
pub struct LogEventRequest {
|
|
||||||
/// Event type, should be `view` or `select`.
|
|
||||||
#[schema(example = "view")]
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
event_type: String,
|
|
||||||
|
|
||||||
completion_id: String,
|
|
||||||
|
|
||||||
choice_index: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
|
|
@ -30,22 +20,22 @@ pub struct LogEventRequest {
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn log_event(
|
pub async fn log_event(
|
||||||
|
State(logger): State<Arc<dyn EventLogger>>,
|
||||||
Query(params): Query<HashMap<String, String>>,
|
Query(params): Query<HashMap<String, String>>,
|
||||||
Json(request): Json<LogEventRequest>,
|
Json(request): Json<LogEventRequest>,
|
||||||
) -> StatusCode {
|
) -> StatusCode {
|
||||||
if request.event_type == "view" {
|
if request.event_type == "view" {
|
||||||
events::Event::View {
|
logger.log(&Event::View {
|
||||||
completion_id: &request.completion_id,
|
completion_id: &request.completion_id,
|
||||||
choice_index: request.choice_index,
|
choice_index: request.choice_index,
|
||||||
}
|
});
|
||||||
.log();
|
|
||||||
StatusCode::OK
|
StatusCode::OK
|
||||||
} else if request.event_type == "select" {
|
} else if request.event_type == "select" {
|
||||||
let is_line = params
|
let is_line = params
|
||||||
.get("select_kind")
|
.get("select_kind")
|
||||||
.map(|x| x == "line")
|
.map(|x| x == "line")
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
events::Event::Select {
|
logger.log(&Event::Select {
|
||||||
completion_id: &request.completion_id,
|
completion_id: &request.completion_id,
|
||||||
choice_index: request.choice_index,
|
choice_index: request.choice_index,
|
||||||
kind: if is_line {
|
kind: if is_line {
|
||||||
|
|
@ -53,8 +43,7 @@ pub async fn log_event(
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
},
|
},
|
||||||
}
|
});
|
||||||
.log();
|
|
||||||
StatusCode::OK
|
StatusCode::OK
|
||||||
} else {
|
} else {
|
||||||
StatusCode::BAD_REQUEST
|
StatusCode::BAD_REQUEST
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::{extract::State, Json};
|
||||||
|
|
||||||
|
use crate::services::health;
|
||||||
|
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
path = "/v1/health",
|
||||||
|
tag = "v1",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn health(State(state): State<Arc<health::HealthState>>) -> Json<health::HealthState> {
|
||||||
|
Json(state.as_ref().clone())
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
mod chat;
|
||||||
|
mod completions;
|
||||||
|
mod events;
|
||||||
|
mod health;
|
||||||
|
mod search;
|
||||||
|
|
||||||
|
pub use chat::*;
|
||||||
|
pub use completions::*;
|
||||||
|
pub use events::*;
|
||||||
|
pub use health::*;
|
||||||
|
pub use search::*;
|
||||||
|
|
@ -10,7 +10,7 @@ use serde::Deserialize;
|
||||||
use tracing::{instrument, warn};
|
use tracing::{instrument, warn};
|
||||||
use utoipa::IntoParams;
|
use utoipa::IntoParams;
|
||||||
|
|
||||||
use crate::api::{CodeSearch, CodeSearchError, SearchResponse};
|
use crate::api::code::{CodeSearch, CodeSearchError, SearchResponse};
|
||||||
|
|
||||||
#[derive(Deserialize, IntoParams)]
|
#[derive(Deserialize, IntoParams)]
|
||||||
pub struct SearchQuery {
|
pub struct SearchQuery {
|
||||||
|
|
@ -1,9 +1,3 @@
|
||||||
mod chat;
|
|
||||||
mod completions;
|
|
||||||
mod events;
|
|
||||||
mod health;
|
|
||||||
mod search;
|
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
fs,
|
fs,
|
||||||
net::{Ipv4Addr, SocketAddr},
|
net::{Ipv4Addr, SocketAddr},
|
||||||
|
|
@ -23,15 +17,10 @@ use tracing::info;
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
use self::health::HealthState;
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{Hit, HitDocument, SearchResponse},
|
api::{self},
|
||||||
fatal,
|
fatal, routes,
|
||||||
services::{
|
services::{chat, completion, event::create_event_logger, health, model},
|
||||||
chat::ChatService,
|
|
||||||
completions::CompletionService,
|
|
||||||
model::{load_text_generation, PromptInfo},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
|
@ -51,24 +40,24 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
servers(
|
servers(
|
||||||
(url = "/", description = "Server"),
|
(url = "/", description = "Server"),
|
||||||
),
|
),
|
||||||
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
|
paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search),
|
||||||
components(schemas(
|
components(schemas(
|
||||||
events::LogEventRequest,
|
api::event::LogEventRequest,
|
||||||
crate::services::completions::CompletionRequest,
|
completion::CompletionRequest,
|
||||||
crate::services::completions::CompletionResponse,
|
completion::CompletionResponse,
|
||||||
crate::services::completions::Segments,
|
completion::Segments,
|
||||||
crate::services::completions::Choice,
|
completion::Choice,
|
||||||
crate::services::completions::Snippet,
|
completion::Snippet,
|
||||||
crate::services::completions::DebugOptions,
|
completion::DebugOptions,
|
||||||
crate::services::completions::DebugData,
|
completion::DebugData,
|
||||||
crate::services::chat::ChatCompletionRequest,
|
chat::ChatCompletionRequest,
|
||||||
crate::services::chat::Message,
|
chat::Message,
|
||||||
crate::services::chat::ChatCompletionChunk,
|
chat::ChatCompletionChunk,
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
health::Version,
|
health::Version,
|
||||||
SearchResponse,
|
api::code::SearchResponse,
|
||||||
Hit,
|
api::code::Hit,
|
||||||
HitDocument
|
api::code::HitDocument
|
||||||
))
|
))
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
@ -174,25 +163,31 @@ async fn load_model(args: &ServeArgs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
|
let logger = Arc::new(create_event_logger());
|
||||||
let code = Arc::new(crate::services::code::create_code_search());
|
let code = Arc::new(crate::services::code::create_code_search());
|
||||||
let completion_state = {
|
let completion_state = {
|
||||||
let (
|
let (
|
||||||
engine,
|
engine,
|
||||||
PromptInfo {
|
model::PromptInfo {
|
||||||
prompt_template, ..
|
prompt_template, ..
|
||||||
},
|
},
|
||||||
) = load_text_generation(&args.model, &args.device, args.parallelism).await;
|
) = model::load_text_generation(&args.model, &args.device, args.parallelism).await;
|
||||||
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
let state = completion::CompletionService::new(
|
||||||
|
engine.clone(),
|
||||||
|
code.clone(),
|
||||||
|
logger.clone(),
|
||||||
|
prompt_template,
|
||||||
|
);
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_state = if let Some(chat_model) = &args.chat_model {
|
let chat_state = if let Some(chat_model) = &args.chat_model {
|
||||||
let (engine, PromptInfo { chat_template, .. }) =
|
let (engine, model::PromptInfo { chat_template, .. }) =
|
||||||
load_text_generation(chat_model, &args.device, args.parallelism).await;
|
model::load_text_generation(chat_model, &args.device, args.parallelism).await;
|
||||||
let Some(chat_template) = chat_template else {
|
let Some(chat_template) = chat_template else {
|
||||||
panic!("Chat model requires specifying prompt template");
|
panic!("Chat model requires specifying prompt template");
|
||||||
};
|
};
|
||||||
let state = ChatService::new(engine, chat_template);
|
let state = chat::ChatService::new(engine, chat_template);
|
||||||
Some(Arc::new(state))
|
Some(Arc::new(state))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
@ -200,17 +195,24 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
|
|
||||||
let mut routers = vec![];
|
let mut routers = vec![];
|
||||||
|
|
||||||
let health_state = Arc::new(health::HealthState::new(args));
|
let health_state = Arc::new(health::HealthState::new(
|
||||||
|
&args.model,
|
||||||
|
args.chat_model.as_deref(),
|
||||||
|
&args.device,
|
||||||
|
));
|
||||||
routers.push({
|
routers.push({
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/v1/events", routing::post(events::log_event))
|
|
||||||
.route(
|
.route(
|
||||||
"/v1/health",
|
"/v1/events",
|
||||||
routing::post(health::health).with_state(health_state.clone()),
|
routing::post(routes::log_event).with_state(logger),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
"/v1/health",
|
"/v1/health",
|
||||||
routing::get(health::health).with_state(health_state),
|
routing::post(routes::health).with_state(health_state.clone()),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/v1/health",
|
||||||
|
routing::get(routes::health).with_state(health_state),
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -218,7 +220,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
routing::post(completions::completions).with_state(completion_state),
|
routing::post(routes::completions).with_state(completion_state),
|
||||||
)
|
)
|
||||||
.layer(TimeoutLayer::new(Duration::from_secs(
|
.layer(TimeoutLayer::new(Duration::from_secs(
|
||||||
config.server.completion_timeout,
|
config.server.completion_timeout,
|
||||||
|
|
@ -229,7 +231,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
routers.push({
|
routers.push({
|
||||||
Router::new().route(
|
Router::new().route(
|
||||||
"/v1beta/chat/completions",
|
"/v1beta/chat/completions",
|
||||||
routing::post(chat::completions).with_state(chat_state),
|
routing::post(routes::chat_completions).with_state(chat_state),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -237,7 +239,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
routers.push({
|
routers.push({
|
||||||
Router::new().route(
|
Router::new().route(
|
||||||
"/v1beta/search",
|
"/v1beta/search",
|
||||||
routing::get(search::search).with_state(code),
|
routing::get(routes::search).with_state(code),
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -250,7 +252,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_heartbeat(args: &ServeArgs) {
|
fn start_heartbeat(args: &ServeArgs) {
|
||||||
let state = HealthState::new(args);
|
let state = health::HealthState::new(&args.model, args.chat_model.as_deref(), &args.device);
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
usage::capture("ServeHealth", &state).await;
|
usage::capture("ServeHealth", &state).await;
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ use tantivy::{
|
||||||
use tokio::{sync::Mutex, time::sleep};
|
use tokio::{sync::Mutex, time::sleep};
|
||||||
use tracing::{debug, log::info};
|
use tracing::{debug, log::info};
|
||||||
|
|
||||||
use crate::api::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse};
|
use crate::api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse};
|
||||||
|
|
||||||
struct CodeSearchImpl {
|
struct CodeSearchImpl {
|
||||||
reader: IndexReader,
|
reader: IndexReader,
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,19 @@
|
||||||
mod completions_prompt;
|
mod completion_prompt;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_common::{events, languages::get_language};
|
use tabby_common::languages::get_language;
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use crate::api::CodeSearch;
|
use crate::api::{
|
||||||
|
self,
|
||||||
|
code::CodeSearch,
|
||||||
|
event::{Event, EventLogger},
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum CompletionError {
|
pub enum CompletionError {
|
||||||
|
|
@ -95,9 +99,9 @@ pub struct Segments {
|
||||||
suffix: Option<String>,
|
suffix: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Segments> for events::Segments {
|
impl From<Segments> for api::event::Segments {
|
||||||
fn from(val: Segments) -> Self {
|
fn from(val: Segments) -> Self {
|
||||||
events::Segments {
|
Self {
|
||||||
prefix: val.prefix,
|
prefix: val.prefix,
|
||||||
suffix: val.suffix,
|
suffix: val.suffix,
|
||||||
}
|
}
|
||||||
|
|
@ -157,18 +161,21 @@ pub struct DebugData {
|
||||||
|
|
||||||
pub struct CompletionService {
|
pub struct CompletionService {
|
||||||
engine: Arc<dyn TextGeneration>,
|
engine: Arc<dyn TextGeneration>,
|
||||||
prompt_builder: completions_prompt::PromptBuilder,
|
logger: Arc<dyn EventLogger>,
|
||||||
|
prompt_builder: completion_prompt::PromptBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CompletionService {
|
impl CompletionService {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
engine: Arc<dyn TextGeneration>,
|
engine: Arc<dyn TextGeneration>,
|
||||||
code: Arc<dyn CodeSearch>,
|
code: Arc<dyn CodeSearch>,
|
||||||
|
logger: Arc<dyn EventLogger>,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)),
|
prompt_builder: completion_prompt::PromptBuilder::new(prompt_template, Some(code)),
|
||||||
|
logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -226,18 +233,17 @@ impl CompletionService {
|
||||||
let text = self.engine.generate(&prompt, options).await;
|
let text = self.engine.generate(&prompt, options).await;
|
||||||
let segments = segments.map(|s| s.into());
|
let segments = segments.map(|s| s.into());
|
||||||
|
|
||||||
events::Event::Completion {
|
self.logger.log(&Event::Completion {
|
||||||
completion_id: &completion_id,
|
completion_id: &completion_id,
|
||||||
language: &language,
|
language: &language,
|
||||||
prompt: &prompt,
|
prompt: &prompt,
|
||||||
segments: &segments,
|
segments: &segments,
|
||||||
choices: vec![events::Choice {
|
choices: vec![api::event::Choice {
|
||||||
index: 0,
|
index: 0,
|
||||||
text: &text,
|
text: &text,
|
||||||
}],
|
}],
|
||||||
user: request.user.as_deref(),
|
user: request.user.as_deref(),
|
||||||
}
|
});
|
||||||
.log();
|
|
||||||
|
|
||||||
let debug_data = request
|
let debug_data = request
|
||||||
.debug_options
|
.debug_options
|
||||||
|
|
@ -8,7 +8,7 @@ use textdistance::Algorithm;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
use super::{Segments, Snippet};
|
use super::{Segments, Snippet};
|
||||||
use crate::api::{CodeSearch, CodeSearchError};
|
use crate::api::code::{CodeSearch, CodeSearchError};
|
||||||
|
|
||||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
||||||
|
|
@ -7,17 +7,20 @@ use std::{
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
use tabby_common::path;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
sync::mpsc::{unbounded_channel, UnboundedSender},
|
sync::mpsc::{unbounded_channel, UnboundedSender},
|
||||||
time::{self},
|
time::{self},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::api::event::{Event, EventLogger};
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref WRITER: UnboundedSender<String> = {
|
static ref WRITER: UnboundedSender<String> = {
|
||||||
let (tx, mut rx) = unbounded_channel::<String>();
|
let (tx, mut rx) = unbounded_channel::<String>();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let events_dir = crate::path::events_dir();
|
let events_dir = path::events_dir();
|
||||||
std::fs::create_dir_all(events_dir.as_path()).ok();
|
std::fs::create_dir_all(events_dir.as_path()).ok();
|
||||||
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
|
|
@ -53,45 +56,7 @@ lazy_static! {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
struct EventService;
|
||||||
pub struct Choice<'a> {
|
|
||||||
pub index: u32,
|
|
||||||
pub text: &'a str,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum SelectKind {
|
|
||||||
Line,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum Event<'a> {
|
|
||||||
View {
|
|
||||||
completion_id: &'a str,
|
|
||||||
choice_index: u32,
|
|
||||||
},
|
|
||||||
Select {
|
|
||||||
completion_id: &'a str,
|
|
||||||
choice_index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
kind: Option<SelectKind>,
|
|
||||||
},
|
|
||||||
Completion {
|
|
||||||
completion_id: &'a str,
|
|
||||||
language: &'a str,
|
|
||||||
prompt: &'a str,
|
|
||||||
segments: &'a Option<Segments>,
|
|
||||||
choices: Vec<Choice<'a>>,
|
|
||||||
user: Option<&'a str>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct Segments {
|
|
||||||
pub prefix: String,
|
|
||||||
pub suffix: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct Log<'a> {
|
struct Log<'a> {
|
||||||
|
|
@ -99,11 +64,11 @@ struct Log<'a> {
|
||||||
event: &'a Event<'a>,
|
event: &'a Event<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Event<'_> {
|
impl EventLogger for EventService {
|
||||||
pub fn log(&self) {
|
fn log(&self, e: &Event) {
|
||||||
let content = serdeconv::to_json_string(&Log {
|
let content = serdeconv::to_json_string(&Log {
|
||||||
ts: timestamp(),
|
ts: timestamp(),
|
||||||
event: self,
|
event: e,
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -119,3 +84,7 @@ fn timestamp() -> u128 {
|
||||||
.expect("Time went backwards")
|
.expect("Time went backwards")
|
||||||
.as_millis()
|
.as_millis()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn create_event_logger() -> impl EventLogger {
|
||||||
|
EventService
|
||||||
|
}
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
use std::{env::consts::ARCH, sync::Arc};
|
use std::env::consts::ARCH;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{extract::State, Json};
|
|
||||||
use nvml_wrapper::Nvml;
|
use nvml_wrapper::Nvml;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sysinfo::{CpuExt, System, SystemExt};
|
use sysinfo::{CpuExt, System, SystemExt};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
use crate::serve::Device;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
pub struct HealthState {
|
pub struct HealthState {
|
||||||
model: String,
|
model: String,
|
||||||
|
|
@ -21,7 +22,7 @@ pub struct HealthState {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HealthState {
|
impl HealthState {
|
||||||
pub fn new(args: &super::ServeArgs) -> Self {
|
pub fn new(model: &str, chat_model: Option<&str>, device: &Device) -> Self {
|
||||||
let (cpu_info, cpu_count) = read_cpu_info();
|
let (cpu_info, cpu_count) = read_cpu_info();
|
||||||
|
|
||||||
let cuda_devices = match read_cuda_devices() {
|
let cuda_devices = match read_cuda_devices() {
|
||||||
|
|
@ -30,9 +31,9 @@ impl HealthState {
|
||||||
};
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model: args.model.clone(),
|
model: model.to_owned(),
|
||||||
chat_model: args.chat_model.clone(),
|
chat_model: chat_model.map(|x| x.to_owned()),
|
||||||
device: args.device.to_string(),
|
device: device.to_string(),
|
||||||
arch: ARCH.to_string(),
|
arch: ARCH.to_string(),
|
||||||
cpu_info,
|
cpu_info,
|
||||||
cpu_count,
|
cpu_count,
|
||||||
|
|
@ -90,15 +91,3 @@ impl Version {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[utoipa::path(
|
|
||||||
get,
|
|
||||||
path = "/v1/health",
|
|
||||||
tag = "v1",
|
|
||||||
responses(
|
|
||||||
(status = 200, description = "Success", body = HealthState, content_type = "application/json"),
|
|
||||||
)
|
|
||||||
)]
|
|
||||||
pub async fn health(State(state): State<Arc<HealthState>>) -> Json<HealthState> {
|
|
||||||
Json(state.as_ref().clone())
|
|
||||||
}
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
pub mod chat;
|
pub mod chat;
|
||||||
pub mod code;
|
pub mod code;
|
||||||
pub mod completions;
|
pub mod completion;
|
||||||
|
pub mod event;
|
||||||
|
pub mod health;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue