From a159c2358d8468812cf8e9c2e4c69a8ab04b0ff6 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 28 Sep 2023 16:58:17 -0700 Subject: [PATCH] refactor: move generate / generate_stream to /v1beta (#487) --- crates/tabby/src/serve/generate.rs | 20 +++++++++++++------- crates/tabby/src/serve/mod.rs | 13 ++++++------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/crates/tabby/src/serve/generate.rs b/crates/tabby/src/serve/generate.rs index 4dc2f8a..cfebe9a 100644 --- a/crates/tabby/src/serve/generate.rs +++ b/crates/tabby/src/serve/generate.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use async_stream::stream; use axum::{extract::State, response::IntoResponse, Json}; use axum_streams::StreamBodyAs; +use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; use tracing::instrument; @@ -21,7 +22,7 @@ impl GenerateState { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct GenerateRequest { #[schema( - example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\ndef" + example = "# Dijkstra'\''s shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" )] prompt: String, } @@ -33,10 +34,10 @@ pub struct GenerateResponse { #[utoipa::path( post, - path = "/v1/generate", + path = "/v1beta/generate", request_body = GenerateRequest, operation_id = "generate", - tag = "v1", + tag = "v1beta", responses( (status = 200, description = "Success", body = GenerateResponse, content_type = "application/json"), ) @@ -54,10 +55,10 @@ pub async fn generate( #[utoipa::path( post, - path = "/v1/generate_stream", + path = "/v1beta/generate_stream", request_body = GenerateRequest, operation_id = "generate_stream", - tag = "v1", + tag = "v1beta", responses( (status = 200, description = "Success", body = GenerateResponse, content_type = "application/jsonstream"), ) @@ -77,11 +78,16 @@ pub async fn generate_stream( StreamBodyAs::json_nl(s) } +lazy_static! { + static ref STOP_WORDS: Vec<&'static str> = vec!["\n\n",]; +} + fn build_options(_request: &GenerateRequest) -> TextGenerationOptions { TextGenerationOptionsBuilder::default() - .max_input_length(2048) - .max_decoding_length(usize::MAX) + .max_input_length(1024) + .max_decoding_length(1024) .sampling_temperature(0.1) + .stop_words(&STOP_WORDS) .build() .unwrap() } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 51deaf4..50052c4 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -159,9 +159,8 @@ pub async fn main(config: &Config, args: &ServeArgs) { let doc = add_localhost_server(ApiDoc::openapi(), args.port); let doc = add_proxy_server(doc, config.swagger.server_url.clone()); - let app = Router::new() + let app = api_router(args, config) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)) - .nest("/v1", api_router(args, config)) .fallback(fallback()); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); @@ -178,24 +177,24 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router { let (engine, prompt_template) = create_engine(args); let engine = Arc::new(engine); Router::new() - .route("/events", routing::post(events::log_event)) + .route("/v1/events", routing::post(events::log_event)) .route( - "/health", + "/v1/health", routing::post(health::health).with_state(Arc::new(health::HealthState::new(args))), ) .route( - "/completions", + "/v1/completions", routing::post(completions::completion).with_state(Arc::new( completions::CompletionState::new(engine.clone(), prompt_template, config), )), ) .route( - "/generate", + "/v1beta/generate", routing::post(generate::generate) .with_state(Arc::new(generate::GenerateState::new(engine.clone()))), ) .route( - "/generate_stream", + "/v1beta/generate_stream", routing::post(generate::generate_stream) .with_state(Arc::new(generate::GenerateState::new(engine.clone()))), )