refactor: move generate / generate_stream to /v1beta (#487)

release-0.2
Meng Zhang 2023-09-28 16:58:17 -07:00 committed by GitHub
parent 56b7b850af
commit a159c2358d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 14 deletions

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use async_stream::stream; use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs; use axum_streams::StreamBodyAs;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument; use tracing::instrument;
@ -21,7 +22,7 @@ impl GenerateState {
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct GenerateRequest { pub struct GenerateRequest {
#[schema( #[schema(
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef" example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
)] )]
prompt: String, prompt: String,
} }
@ -33,10 +34,10 @@ pub struct GenerateResponse {
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/generate", path = "/v1beta/generate",
request_body = GenerateRequest, request_body = GenerateRequest,
operation_id = "generate", operation_id = "generate",
tag = "v1", tag = "v1beta",
responses( responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"), (status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
) )
@ -54,10 +55,10 @@ pub async fn generate(
#[utoipa::path( #[utoipa::path(
post, post,
path = "/v1/generate_stream", path = "/v1beta/generate_stream",
request_body = GenerateRequest, request_body = GenerateRequest,
operation_id = "generate_stream", operation_id = "generate_stream",
tag = "v1", tag = "v1beta",
responses( responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"), (status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
) )
@ -77,11 +78,16 @@ pub async fn generate_stream(
StreamBodyAs::json_nl(s) StreamBodyAs::json_nl(s)
} }
lazy_static! {
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",];
}
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions { fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
TextGenerationOptionsBuilder::default() TextGenerationOptionsBuilder::default()
.max_input_length(2048) .max_input_length(1024)
.max_decoding_length(usize::MAX) .max_decoding_length(1024)
.sampling_temperature(0.1) .sampling_temperature(0.1)
.stop_words(&STOP_WORDS)
.build() .build()
.unwrap() .unwrap()
} }

View File

@ -159,9 +159,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
let doc = add_localhost_server(ApiDoc::openapi(), args.port); let doc = add_localhost_server(ApiDoc::openapi(), args.port);
let doc = add_proxy_server(doc, config.swagger.server_url.clone()); let doc = add_proxy_server(doc, config.swagger.server_url.clone());
let app = Router::new() let app = api_router(args, config)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.nest("/v1", api_router(args, config))
.fallback(fallback()); .fallback(fallback());
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
@ -178,24 +177,24 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
let (engine, prompt_template) = create_engine(args); let (engine, prompt_template) = create_engine(args);
let engine = Arc::new(engine); let engine = Arc::new(engine);
Router::new() Router::new()
.route("/events", routing::post(events::log_event)) .route("/v1/events", routing::post(events::log_event))
.route( .route(
"/health", "/v1/health",
routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))), routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))),
) )
.route( .route(
"/completions", "/v1/completions",
routing::post(completions::completion).with_state(Arc::new( routing::post(completions::completion).with_state(Arc::new(
completions::CompletionState::new(engine.clone(), prompt_template, config), completions::CompletionState::new(engine.clone(), prompt_template, config),
)), )),
) )
.route( .route(
"/generate", "/v1beta/generate",
routing::post(generate::generate) routing::post(generate::generate)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))), .with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
) )
.route( .route(
"/generate_stream", "/v1beta/generate_stream",
routing::post(generate::generate_stream) routing::post(generate::generate_stream)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))), .with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
) )