diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 794b374..c5743f6 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -1,108 +1,10 @@ -mod prompt; - use std::sync::Arc; use axum::{extract::State, Json}; use hyper::StatusCode; -use serde::{Deserialize, Serialize}; -use tabby_common::{events, languages::get_language}; -use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; -use tracing::{debug, instrument}; -use utoipa::ToSchema; +use tracing::{instrument, warn}; -use crate::api::CodeSearch; - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -#[schema(example=json!({ - "language": "python", - "segments": { - "prefix": "def fib(n):\n ", - "suffix": "\n return fib(n - 1) + fib(n - 2)" - } -}))] -pub struct CompletionRequest { - /// Language identifier, full list is maintained at - /// https://code.visualstudio.com/docs/languages/identifiers - #[schema(example = "python")] - language: Option, - - /// When segments are set, the `prompt` is ignored during the inference. - segments: Option, - - /// A unique identifier representing your end-user, which can help Tabby to monitor & generating - /// reports. - user: Option, - - debug_options: Option, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct DebugOptions { - /// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored. - /// - /// This is useful for certain requests that aim to test the tabby's e2e quality. - raw_prompt: Option, - - /// When true, returns `snippets` in `debug_data`. - #[serde(default = "default_false")] - return_snippets: bool, - - /// When true, returns `prompt` in `debug_data`. - #[serde(default = "default_false")] - return_prompt: bool, - - /// When true, disable retrieval augmented code completion. - #[serde(default = "default_false")] - disable_retrieval_augmented_code_completion: bool, -} - -fn default_false() -> bool { - false -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct Segments { - /// Content that appears before the cursor in the editor window. - prefix: String, - - /// Content that appears after the cursor in the editor window. - suffix: Option, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct Choice { - index: u32, - text: String, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct Snippet { - filepath: String, - body: String, - score: f32, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -#[schema(example=json!({ - "id": "string", - "choices": [ { "index": 0, "text": "string" } ] -}))] -pub struct CompletionResponse { - id: String, - choices: Vec, - - #[serde(skip_serializing_if = "Option::is_none")] - debug_data: Option, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct DebugData { - #[serde(skip_serializing_if = "Option::is_none")] - snippets: Option>, - - #[serde(skip_serializing_if = "Option::is_none")] - prompt: Option, -} +use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService}; #[utoipa::path( post, @@ -117,106 +19,14 @@ pub struct DebugData { )] #[instrument(skip(state, request))] pub async fn completions( - State(state): State>, + State(state): State>, Json(request): Json, ) -> Result, StatusCode> { - let language = request.language.unwrap_or("unknown".to_string()); - let options = TextGenerationOptionsBuilder::default() - .max_input_length(1024 + 512) - .max_decoding_length(128) - .sampling_temperature(0.1) - .language(get_language(&language)) - .build() - .unwrap(); - - let (prompt, segments, snippets) = if let Some(prompt) = request - .debug_options - .as_ref() - .and_then(|x| x.raw_prompt.clone()) - { - (prompt, None, vec![]) - } else if let Some(segments) = request.segments { - debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); - let (prompt, snippets) = - build_prompt(&state, &request.debug_options, &language, &segments).await; - (prompt, Some(segments), snippets) - } else { - return Err(StatusCode::BAD_REQUEST); - }; - debug!("PROMPT: {}", prompt); - - let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); - let text = state.engine.generate(&prompt, options).await; - - let segments = segments.map(|x| tabby_common::events::Segments { - prefix: x.prefix, - suffix: x.suffix, - }); - - events::Event::Completion { - completion_id: &completion_id, - language: &language, - prompt: &prompt, - segments: &segments, - choices: vec![events::Choice { - index: 0, - text: &text, - }], - user: request.user.as_deref(), - } - .log(); - - let debug_data = request - .debug_options - .as_ref() - .map(|debug_options| DebugData { - snippets: debug_options.return_snippets.then_some(snippets), - prompt: debug_options.return_prompt.then_some(prompt), - }); - - Ok(Json(CompletionResponse { - id: completion_id, - choices: vec![Choice { index: 0, text }], - debug_data, - })) -} - -async fn build_prompt( - state: &Arc, - debug_options: &Option, - language: &str, - segments: &Segments, -) -> (String, Vec) { - let snippets = if !debug_options - .as_ref() - .is_some_and(|x| x.disable_retrieval_augmented_code_completion) - { - state.prompt_builder.collect(language, segments).await - } else { - vec![] - }; - ( - state - .prompt_builder - .build(language, segments.clone(), &snippets), - snippets, - ) -} - -pub struct CompletionState { - engine: Arc, - prompt_builder: prompt::PromptBuilder, -} - -impl CompletionState { - pub fn new( - engine: Arc, - code: Arc, - prompt_template: Option, - ) -> Self { - Self { - engine, - prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(code)), + match state.generate(&request).await { + Ok(resp) => Ok(Json(resp)), + Err(err) => { + warn!("{}", err); + Err(StatusCode::BAD_REQUEST) } } } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 6e8c56b..6c4c184 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -31,7 +31,7 @@ use self::{ use crate::{ api::{Hit, HitDocument, SearchResponse}, fatal, - services::chat::ChatService, + services::{chat::ChatService, completions::CompletionService}, }; #[derive(OpenApi)] @@ -54,13 +54,13 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi paths(events::log_event, completions::completions, chat::completions, health::health, search::search), components(schemas( events::LogEventRequest, - completions::CompletionRequest, - completions::CompletionResponse, - completions::Segments, - completions::Choice, - completions::Snippet, - completions::DebugOptions, - completions::DebugData, + crate::services::completions::CompletionRequest, + crate::services::completions::CompletionResponse, + crate::services::completions::Segments, + crate::services::completions::Choice, + crate::services::completions::Snippet, + crate::services::completions::DebugOptions, + crate::services::completions::DebugData, crate::services::chat::ChatCompletionRequest, crate::services::chat::Message, crate::services::chat::ChatCompletionChunk, @@ -182,8 +182,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { prompt_template, .. }, ) = create_engine(&args.model, args).await; - let state = - completions::CompletionState::new(engine.clone(), code.clone(), prompt_template); + let state = CompletionService::new(engine.clone(), code.clone(), prompt_template); Arc::new(state) }; diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs index 1f2c91a..37286ad 100644 --- a/crates/tabby/src/services/chat.rs +++ b/crates/tabby/src/services/chat.rs @@ -1,10 +1,10 @@ -mod prompt; +mod chat_prompt; use std::sync::Arc; use async_stream::stream; +use chat_prompt::ChatPromptBuilder; use futures::stream::BoxStream; -use prompt::ChatPromptBuilder; use serde::{Deserialize, Serialize}; use tabby_common::languages::EMPTY_LANGUAGE; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; @@ -47,26 +47,22 @@ impl ChatService { } } - fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) { - let mut builder = TextGenerationOptionsBuilder::default(); - - builder + fn text_generation_options() -> TextGenerationOptions { + TextGenerationOptionsBuilder::default() .max_input_length(2048) .max_decoding_length(1920) .language(&EMPTY_LANGUAGE) - .sampling_temperature(0.1); - - ( - self.prompt_builder.build(&request.messages), - builder.build().unwrap(), - ) + .sampling_temperature(0.1) + .build() + .unwrap() } pub async fn generate( &self, request: &ChatCompletionRequest, ) -> BoxStream { - let (prompt, options) = self.parse_request(request); + let prompt = self.prompt_builder.build(&request.messages); + let options = Self::text_generation_options(); debug!("PROMPT: {}", prompt); let s = stream! { for await content in self.engine.generate_stream(&prompt, options).await { diff --git a/crates/tabby/src/services/chat/prompt.rs b/crates/tabby/src/services/chat/chat_prompt.rs similarity index 100% rename from crates/tabby/src/services/chat/prompt.rs rename to crates/tabby/src/services/chat/chat_prompt.rs diff --git a/crates/tabby/src/services/completions.rs b/crates/tabby/src/services/completions.rs new file mode 100644 index 0000000..80bf2b8 --- /dev/null +++ b/crates/tabby/src/services/completions.rs @@ -0,0 +1,256 @@ +mod completions_prompt; + +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use tabby_common::{events, languages::get_language}; +use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; +use thiserror::Error; +use tracing::debug; +use utoipa::ToSchema; + +use crate::api::CodeSearch; + +#[derive(Error, Debug)] +pub enum CompletionError { + #[error("empty prompt from completion request")] + EmptyPrompt, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +#[schema(example=json!({ + "language": "python", + "segments": { + "prefix": "def fib(n):\n ", + "suffix": "\n return fib(n - 1) + fib(n - 2)" + } +}))] +pub struct CompletionRequest { + /// Language identifier, full list is maintained at + /// https://code.visualstudio.com/docs/languages/identifiers + #[schema(example = "python")] + language: Option, + + /// When segments are set, the `prompt` is ignored during the inference. + segments: Option, + + /// A unique identifier representing your end-user, which can help Tabby to monitor & generating + /// reports. + user: Option, + + debug_options: Option, +} + +impl CompletionRequest { + /// Returns the language info or "unknown" if not specified. + fn language_or_unknown(&self) -> String { + self.language.clone().unwrap_or("unknown".to_string()) + } + + /// Returns the raw prompt if specified. + fn raw_prompt(&self) -> Option { + self.debug_options + .as_ref() + .and_then(|x| x.raw_prompt.clone()) + } + + /// Returns true if retrieval augmented code completion is disabled. + fn disable_retrieval_augmented_code_completion(&self) -> bool { + self.debug_options + .as_ref() + .is_some_and(|x| x.disable_retrieval_augmented_code_completion) + } +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct DebugOptions { + /// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored. + /// + /// This is useful for certain requests that aim to test the tabby's e2e quality. + raw_prompt: Option, + + /// When true, returns `snippets` in `debug_data`. + #[serde(default = "default_false")] + return_snippets: bool, + + /// When true, returns `prompt` in `debug_data`. + #[serde(default = "default_false")] + return_prompt: bool, + + /// When true, disable retrieval augmented code completion. + #[serde(default = "default_false")] + disable_retrieval_augmented_code_completion: bool, +} + +fn default_false() -> bool { + false +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct Segments { + /// Content that appears before the cursor in the editor window. + prefix: String, + + /// Content that appears after the cursor in the editor window. + suffix: Option, +} + +impl From for events::Segments { + fn from(val: Segments) -> Self { + events::Segments { + prefix: val.prefix, + suffix: val.suffix, + } + } +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct Choice { + index: u32, + text: String, +} + +impl Choice { + pub fn new(text: String) -> Self { + Self { index: 0, text } + } +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct Snippet { + filepath: String, + body: String, + score: f32, +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +#[schema(example=json!({ + "id": "string", + "choices": [ { "index": 0, "text": "string" } ] +}))] +pub struct CompletionResponse { + id: String, + choices: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + debug_data: Option, +} + +impl CompletionResponse { + pub fn new(id: String, choices: Vec, debug_data: Option) -> Self { + Self { + id, + choices, + debug_data, + } + } +} + +#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] +pub struct DebugData { + #[serde(skip_serializing_if = "Option::is_none")] + snippets: Option>, + + #[serde(skip_serializing_if = "Option::is_none")] + prompt: Option, +} + +pub struct CompletionService { + engine: Arc, + prompt_builder: completions_prompt::PromptBuilder, +} + +impl CompletionService { + pub fn new( + engine: Arc, + code: Arc, + prompt_template: Option, + ) -> Self { + Self { + engine, + prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)), + } + } + + async fn build_snippets( + &self, + language: &str, + segments: &Segments, + disable_retrieval_augmented_code_completion: bool, + ) -> Vec { + if !disable_retrieval_augmented_code_completion { + self.prompt_builder.collect(language, segments).await + } else { + vec![] + } + } + + fn text_generation_options(language: &str) -> TextGenerationOptions { + TextGenerationOptionsBuilder::default() + .max_input_length(1024 + 512) + .max_decoding_length(128) + .sampling_temperature(0.1) + .language(get_language(language)) + .build() + .unwrap() + } + + pub async fn generate( + &self, + request: &CompletionRequest, + ) -> Result { + let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); + let language = request.language_or_unknown(); + let options = Self::text_generation_options(language.as_str()); + + let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() { + (prompt, None, vec![]) + } else if let Some(segments) = request.segments.clone() { + debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); + let snippets = self + .build_snippets( + &language, + &segments, + request.disable_retrieval_augmented_code_completion(), + ) + .await; + let prompt = self + .prompt_builder + .build(&language, segments.clone(), &snippets); + (prompt, Some(segments), snippets) + } else { + return Err(CompletionError::EmptyPrompt); + }; + debug!("PROMPT: {}", prompt); + + let text = self.engine.generate(&prompt, options).await; + let segments = segments.map(|s| s.into()); + + events::Event::Completion { + completion_id: &completion_id, + language: &language, + prompt: &prompt, + segments: &segments, + choices: vec![events::Choice { + index: 0, + text: &text, + }], + user: request.user.as_deref(), + } + .log(); + + let debug_data = request + .debug_options + .as_ref() + .map(|debug_options| DebugData { + snippets: debug_options.return_snippets.then_some(snippets), + prompt: debug_options.return_prompt.then_some(prompt), + }); + + Ok(CompletionResponse::new( + completion_id, + vec![Choice::new(text)], + debug_data, + )) + } +} diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/services/completions/completions_prompt.rs similarity index 100% rename from crates/tabby/src/serve/completions/prompt.rs rename to crates/tabby/src/services/completions/completions_prompt.rs diff --git a/crates/tabby/src/services/mod.rs b/crates/tabby/src/services/mod.rs index 74476ac..a0bba28 100644 --- a/crates/tabby/src/services/mod.rs +++ b/crates/tabby/src/services/mod.rs @@ -1,2 +1,3 @@ pub mod chat; pub mod code; +pub mod completions;