refactor: move generate / generate_stream to /v1beta (#487)
parent
56b7b850af
commit
a159c2358d
|
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||
use async_stream::stream;
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use axum_streams::StreamBodyAs;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use tracing::instrument;
|
||||
|
|
@ -21,7 +22,7 @@ impl GenerateState {
|
|||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct GenerateRequest {
|
||||
#[schema(
|
||||
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef"
|
||||
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
|
||||
)]
|
||||
prompt: String,
|
||||
}
|
||||
|
|
@ -33,10 +34,10 @@ pub struct GenerateResponse {
|
|||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/v1/generate",
|
||||
path = "/v1beta/generate",
|
||||
request_body = GenerateRequest,
|
||||
operation_id = "generate",
|
||||
tag = "v1",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
|
||||
)
|
||||
|
|
@ -54,10 +55,10 @@ pub async fn generate(
|
|||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/v1/generate_stream",
|
||||
path = "/v1beta/generate_stream",
|
||||
request_body = GenerateRequest,
|
||||
operation_id = "generate_stream",
|
||||
tag = "v1",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
|
||||
)
|
||||
|
|
@ -77,11 +78,16 @@ pub async fn generate_stream(
|
|||
StreamBodyAs::json_nl(s)
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",];
|
||||
}
|
||||
|
||||
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
|
||||
TextGenerationOptionsBuilder::default()
|
||||
.max_input_length(2048)
|
||||
.max_decoding_length(usize::MAX)
|
||||
.max_input_length(1024)
|
||||
.max_decoding_length(1024)
|
||||
.sampling_temperature(0.1)
|
||||
.stop_words(&STOP_WORDS)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,9 +159,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
|
||||
let doc = add_localhost_server(ApiDoc::openapi(), args.port);
|
||||
let doc = add_proxy_server(doc, config.swagger.server_url.clone());
|
||||
let app = Router::new()
|
||||
let app = api_router(args, config)
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
|
||||
.nest("/v1", api_router(args, config))
|
||||
.fallback(fallback());
|
||||
|
||||
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
|
||||
|
|
@ -178,24 +177,24 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
let (engine, prompt_template) = create_engine(args);
|
||||
let engine = Arc::new(engine);
|
||||
Router::new()
|
||||
.route("/events", routing::post(events::log_event))
|
||||
.route("/v1/events", routing::post(events::log_event))
|
||||
.route(
|
||||
"/health",
|
||||
"/v1/health",
|
||||
routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))),
|
||||
)
|
||||
.route(
|
||||
"/completions",
|
||||
"/v1/completions",
|
||||
routing::post(completions::completion).with_state(Arc::new(
|
||||
completions::CompletionState::new(engine.clone(), prompt_template, config),
|
||||
)),
|
||||
)
|
||||
.route(
|
||||
"/generate",
|
||||
"/v1beta/generate",
|
||||
routing::post(generate::generate)
|
||||
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
||||
)
|
||||
.route(
|
||||
"/generate_stream",
|
||||
"/v1beta/generate_stream",
|
||||
routing::post(generate::generate_stream)
|
||||
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue