refactor: move generate / generate_stream to /v1beta (#487)

release-0.2
Meng Zhang 2023-09-28 16:58:17 -07:00 committed by GitHub
parent 56b7b850af
commit a159c2358d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 14 deletions

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use async_stream::stream;
use axum::{extract::State, response::IntoResponse, Json};
use axum_streams::StreamBodyAs;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use tracing::instrument;
@ -21,7 +22,7 @@ impl GenerateState {
// Request body shared by the `generate` and `generate_stream` endpoints
// (both declare `request_body = GenerateRequest`); it carries a single
// `prompt` string.
//
// NOTE(review): the `'\''` sequence in the schema example looks like a
// shell-quoting artifact (e.g. copied from a single-quoted curl command).
// In a Rust string literal `\'` is just an apostrophe, so the rendered
// example would read `Dijkstra'''s` — confirm and simplify to `Dijkstra's`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct GenerateRequest {
#[schema(
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef"
example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n"
)]
prompt: String,
}
@ -33,10 +34,10 @@ pub struct GenerateResponse {
#[utoipa::path(
post,
path = "/v1/generate",
path = "/v1beta/generate",
request_body = GenerateRequest,
operation_id = "generate",
tag = "v1",
tag = "v1beta",
responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"),
)
@ -54,10 +55,10 @@ pub async fn generate(
#[utoipa::path(
post,
path = "/v1/generate_stream",
path = "/v1beta/generate_stream",
request_body = GenerateRequest,
operation_id = "generate_stream",
tag = "v1",
tag = "v1beta",
responses(
(status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"),
)
@ -77,11 +78,16 @@ pub async fn generate_stream(
StreamBodyAs::json_nl(s)
}
// Stop words handed to the text generation options in `build_options`.
// The list currently holds a single entry: a blank line ("\n\n").
lazy_static! {
static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",];
}
/// Builds the `TextGenerationOptions` for a generate request.
///
/// The request is not consulted (`_request` is unused), so every call
/// produces the same configuration: input/decoding length limits, a sampling
/// temperature of 0.1, and the stop words from `STOP_WORDS`.
///
/// # Panics
///
/// Panics if the builder's `build()` call fails, because the result is
/// unwrapped.
fn build_options(_request: &GenerateRequest) -> TextGenerationOptions {
TextGenerationOptionsBuilder::default()
.max_input_length(2048)
.max_decoding_length(usize::MAX)
.max_input_length(1024)
.max_decoding_length(1024)
.sampling_temperature(0.1)
.stop_words(&STOP_WORDS)
.build()
.unwrap()
}

View File

@ -159,9 +159,8 @@ pub async fn main(config: &Config, args: &ServeArgs) {
let doc = add_localhost_server(ApiDoc::openapi(), args.port);
let doc = add_proxy_server(doc, config.swagger.server_url.clone());
let app = Router::new()
let app = api_router(args, config)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.nest("/v1", api_router(args, config))
.fallback(fallback());
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
@ -178,24 +177,24 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
let (engine, prompt_template) = create_engine(args);
let engine = Arc::new(engine);
Router::new()
.route("/events", routing::post(events::log_event))
.route("/v1/events", routing::post(events::log_event))
.route(
"/health",
"/v1/health",
routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))),
)
.route(
"/completions",
"/v1/completions",
routing::post(completions::completion).with_state(Arc::new(
completions::CompletionState::new(engine.clone(), prompt_template, config),
)),
)
.route(
"/generate",
"/v1beta/generate",
routing::post(generate::generate)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
)
.route(
"/generate_stream",
"/v1beta/generate_stream",
routing::post(generate::generate_stream)
.with_state(Arc::new(generate::GenerateState::new(engine.clone()))),
)