refactor: extract CompletionState -> CompletionService (#773)

* refactor: extract CompletionState -> CompletionService

* fix comment

* Update README.md

CMake is preinstalled on Ubuntu / Debian

* fix compile error

* format files

* format files

---------

Co-authored-by: darknight <illuminating.me@gmail.com>
extract-routes
Meng Zhang 2023-11-12 16:14:58 -08:00 committed by GitHub
parent 6f1a3039b0
commit 4359b0cc4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 283 additions and 221 deletions

View File

@ -1,108 +1,10 @@
mod prompt;
use std::sync::Arc;
use axum::{extract::State, Json};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::{events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use tracing::{instrument, warn};
use crate::api::CodeSearch;
// Request payload accepted by the code completion endpoint.
// NOTE: the `///` field docs below are emitted into the OpenAPI schema by
// utoipa, so they are left untouched; reviewer notes use plain `//`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "language": "python",
    "segments": {
        "prefix": "def fib(n):\n ",
        "suffix": "\n return fib(n - 1) + fib(n - 2)"
    }
}))]
pub struct CompletionRequest {
    /// Language identifier, full list is maintained at
    /// https://code.visualstudio.com/docs/languages/identifiers
    #[schema(example = "python")]
    language: Option<String>,

    /// When segments are set, the `prompt` is ignored during the inference.
    segments: Option<Segments>,

    /// A unique identifier representing your end-user, which can help Tabby to monitor & generating
    /// reports.
    user: Option<String>,

    // Debug/evaluation switches; see `DebugOptions`. Absent for normal traffic.
    debug_options: Option<DebugOptions>,
}
// Debug-only knobs for a completion request; intended for e2e quality testing
// rather than production clients.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugOptions {
    /// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
    ///
    /// This is useful for certain requests that aim to test the tabby's e2e quality.
    raw_prompt: Option<String>,

    /// When true, returns `snippets` in `debug_data`.
    #[serde(default = "default_false")]
    return_snippets: bool,

    /// When true, returns `prompt` in `debug_data`.
    #[serde(default = "default_false")]
    return_prompt: bool,

    /// When true, disable retrieval augmented code completion.
    #[serde(default = "default_false")]
    disable_retrieval_augmented_code_completion: bool,
}
/// Serde default helper: fields annotated `#[serde(default = "default_false")]`
/// fall back to `false` when omitted from the request payload.
fn default_false() -> bool {
    bool::default()
}
// Editor context surrounding the cursor, used to build the completion prompt.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Segments {
    /// Content that appears before the cursor in the editor window.
    prefix: String,
    /// Content that appears after the cursor in the editor window.
    suffix: Option<String>,
}
// One generated completion candidate returned to the client.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Choice {
    // Position of this candidate within the response's `choices` list.
    index: u32,
    // The generated completion text.
    text: String,
}
// A retrieved code snippet used for retrieval augmented completion; only
// surfaced to clients when `DebugOptions::return_snippets` is set.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Snippet {
    // Path of the file the snippet was taken from.
    filepath: String,
    // The snippet text itself.
    body: String,
    // Retrieval score — presumably the code-search relevance score; TODO confirm
    // against the search backend.
    score: f32,
}
// Response payload for the completion endpoint.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "id": "string",
    "choices": [ { "index": 0, "text": "string" } ]
}))]
pub struct CompletionResponse {
    // Unique id of this completion (the handler formats it as "cmpl-<uuid>").
    id: String,
    // Generated candidates; the handler currently emits a single choice.
    choices: Vec<Choice>,
    // Present only when the request carried `debug_options`; omitted from JSON otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_data: Option<DebugData>,
}
// Extra diagnostics echoed back when requested via `DebugOptions`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugData {
    // Retrieved snippets; populated only when `return_snippets` was true.
    #[serde(skip_serializing_if = "Option::is_none")]
    snippets: Option<Vec<Snippet>>,
    // Final prompt sent to the engine; populated only when `return_prompt` was true.
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt: Option<String>,
}
use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService};
#[utoipa::path(
post,
@ -117,106 +19,14 @@ pub struct DebugData {
)]
#[instrument(skip(state, request))]
pub async fn completions(
State(state): State<Arc<CompletionState>>,
State(state): State<Arc<CompletionService>>,
Json(request): Json<CompletionRequest>,
) -> Result<Json<CompletionResponse>, StatusCode> {
let language = request.language.unwrap_or("unknown".to_string());
let options = TextGenerationOptionsBuilder::default()
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.language(get_language(&language))
.build()
.unwrap();
let (prompt, segments, snippets) = if let Some(prompt) = request
.debug_options
.as_ref()
.and_then(|x| x.raw_prompt.clone())
{
(prompt, None, vec![])
} else if let Some(segments) = request.segments {
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let (prompt, snippets) =
build_prompt(&state, &request.debug_options, &language, &segments).await;
(prompt, Some(segments), snippets)
} else {
return Err(StatusCode::BAD_REQUEST);
};
debug!("PROMPT: {}", prompt);
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.generate(&prompt, options).await;
let segments = segments.map(|x| tabby_common::events::Segments {
prefix: x.prefix,
suffix: x.suffix,
});
events::Event::Completion {
completion_id: &completion_id,
language: &language,
prompt: &prompt,
segments: &segments,
choices: vec![events::Choice {
index: 0,
text: &text,
}],
user: request.user.as_deref(),
}
.log();
let debug_data = request
.debug_options
.as_ref()
.map(|debug_options| DebugData {
snippets: debug_options.return_snippets.then_some(snippets),
prompt: debug_options.return_prompt.then_some(prompt),
});
Ok(Json(CompletionResponse {
id: completion_id,
choices: vec![Choice { index: 0, text }],
debug_data,
}))
}
async fn build_prompt(
state: &Arc<CompletionState>,
debug_options: &Option<DebugOptions>,
language: &str,
segments: &Segments,
) -> (String, Vec<Snippet>) {
let snippets = if !debug_options
.as_ref()
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
{
state.prompt_builder.collect(language, segments).await
} else {
vec![]
};
(
state
.prompt_builder
.build(language, segments.clone(), &snippets),
snippets,
)
}
pub struct CompletionState {
engine: Arc<dyn TextGeneration>,
prompt_builder: prompt::PromptBuilder,
}
impl CompletionState {
pub fn new(
engine: Arc<dyn TextGeneration>,
code: Arc<dyn CodeSearch>,
prompt_template: Option<String>,
) -> Self {
Self {
engine,
prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(code)),
match state.generate(&request).await {
Ok(resp) => Ok(Json(resp)),
Err(err) => {
warn!("{}", err);
Err(StatusCode::BAD_REQUEST)
}
}
}

View File

@ -31,7 +31,7 @@ use self::{
use crate::{
api::{Hit, HitDocument, SearchResponse},
fatal,
services::chat::ChatService,
services::{chat::ChatService, completions::CompletionService},
};
#[derive(OpenApi)]
@ -54,13 +54,13 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
components(schemas(
events::LogEventRequest,
completions::CompletionRequest,
completions::CompletionResponse,
completions::Segments,
completions::Choice,
completions::Snippet,
completions::DebugOptions,
completions::DebugData,
crate::services::completions::CompletionRequest,
crate::services::completions::CompletionResponse,
crate::services::completions::Segments,
crate::services::completions::Choice,
crate::services::completions::Snippet,
crate::services::completions::DebugOptions,
crate::services::completions::DebugData,
crate::services::chat::ChatCompletionRequest,
crate::services::chat::Message,
crate::services::chat::ChatCompletionChunk,
@ -182,8 +182,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
prompt_template, ..
},
) = create_engine(&args.model, args).await;
let state =
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
Arc::new(state)
};

View File

@ -1,10 +1,10 @@
mod prompt;
mod chat_prompt;
use std::sync::Arc;
use async_stream::stream;
use chat_prompt::ChatPromptBuilder;
use futures::stream::BoxStream;
use prompt::ChatPromptBuilder;
use serde::{Deserialize, Serialize};
use tabby_common::languages::EMPTY_LANGUAGE;
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
@ -47,26 +47,22 @@ impl ChatService {
}
}
fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
let mut builder = TextGenerationOptionsBuilder::default();
builder
fn text_generation_options() -> TextGenerationOptions {
TextGenerationOptionsBuilder::default()
.max_input_length(2048)
.max_decoding_length(1920)
.language(&EMPTY_LANGUAGE)
.sampling_temperature(0.1);
(
self.prompt_builder.build(&request.messages),
builder.build().unwrap(),
)
.sampling_temperature(0.1)
.build()
.unwrap()
}
pub async fn generate(
&self,
request: &ChatCompletionRequest,
) -> BoxStream<ChatCompletionChunk> {
let (prompt, options) = self.parse_request(request);
let prompt = self.prompt_builder.build(&request.messages);
let options = Self::text_generation_options();
debug!("PROMPT: {}", prompt);
let s = stream! {
for await content in self.engine.generate_stream(&prompt, options).await {

View File

@ -0,0 +1,256 @@
mod completions_prompt;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tabby_common::{events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
use thiserror::Error;
use tracing::debug;
use utoipa::ToSchema;
use crate::api::CodeSearch;
// Failure modes of `CompletionService::generate`; callers decide how to map
// them onto transport-level errors.
#[derive(Error, Debug)]
pub enum CompletionError {
    // Neither `segments` nor a debug `raw_prompt` was supplied, so no prompt
    // could be built for the inference engine.
    #[error("empty prompt from completion request")]
    EmptyPrompt,
}
// Request payload accepted by the code completion endpoint.
// NOTE: the `///` field docs below are emitted into the OpenAPI schema by
// utoipa, so they are left untouched; reviewer notes use plain `//`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "language": "python",
    "segments": {
        "prefix": "def fib(n):\n ",
        "suffix": "\n return fib(n - 1) + fib(n - 2)"
    }
}))]
pub struct CompletionRequest {
    /// Language identifier, full list is maintained at
    /// https://code.visualstudio.com/docs/languages/identifiers
    #[schema(example = "python")]
    language: Option<String>,

    /// When segments are set, the `prompt` is ignored during the inference.
    segments: Option<Segments>,

    /// A unique identifier representing your end-user, which can help Tabby to monitor & generating
    /// reports.
    user: Option<String>,

    // Debug/evaluation switches; see `DebugOptions`. Absent for normal traffic.
    debug_options: Option<DebugOptions>,
}
impl CompletionRequest {
    /// Returns the request language, or "unknown" when none was provided.
    fn language_or_unknown(&self) -> String {
        match &self.language {
            Some(language) => language.clone(),
            None => "unknown".to_string(),
        }
    }

    /// Returns the debug raw prompt when one was supplied.
    fn raw_prompt(&self) -> Option<String> {
        self.debug_options.as_ref()?.raw_prompt.clone()
    }

    /// Returns true when the caller explicitly disabled retrieval augmented
    /// code completion through the debug options.
    fn disable_retrieval_augmented_code_completion(&self) -> bool {
        self.debug_options
            .as_ref()
            .map_or(false, |opts| opts.disable_retrieval_augmented_code_completion)
    }
}
// Debug-only knobs for a completion request; intended for e2e quality testing
// rather than production clients.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugOptions {
    /// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
    ///
    /// This is useful for certain requests that aim to test the tabby's e2e quality.
    raw_prompt: Option<String>,

    /// When true, returns `snippets` in `debug_data`.
    #[serde(default = "default_false")]
    return_snippets: bool,

    /// When true, returns `prompt` in `debug_data`.
    #[serde(default = "default_false")]
    return_prompt: bool,

    /// When true, disable retrieval augmented code completion.
    #[serde(default = "default_false")]
    disable_retrieval_augmented_code_completion: bool,
}
/// Serde default helper: fields annotated `#[serde(default = "default_false")]`
/// fall back to `false` when omitted from the request payload.
fn default_false() -> bool {
    bool::default()
}
// Editor context surrounding the cursor, used to build the completion prompt.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Segments {
    /// Content that appears before the cursor in the editor window.
    prefix: String,
    /// Content that appears after the cursor in the editor window.
    suffix: Option<String>,
}
impl From<Segments> for events::Segments {
fn from(val: Segments) -> Self {
events::Segments {
prefix: val.prefix,
suffix: val.suffix,
}
}
}
// One generated completion candidate returned to the client.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Choice {
    // Position of this candidate within the response's `choices` list.
    index: u32,
    // The generated completion text.
    text: String,
}
impl Choice {
pub fn new(text: String) -> Self {
Self { index: 0, text }
}
}
// A retrieved code snippet used for retrieval augmented completion; only
// surfaced to clients when `DebugOptions::return_snippets` is set.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Snippet {
    // Path of the file the snippet was taken from.
    filepath: String,
    // The snippet text itself.
    body: String,
    // Retrieval score — presumably the code-search relevance score; TODO confirm
    // against the search backend.
    score: f32,
}
// Response payload for the completion endpoint.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
    "id": "string",
    "choices": [ { "index": 0, "text": "string" } ]
}))]
pub struct CompletionResponse {
    // Unique id of this completion (`generate` formats it as "cmpl-<uuid>").
    id: String,
    // Generated candidates; the service currently emits a single choice.
    choices: Vec<Choice>,
    // Present only when the request carried `debug_options`; omitted from JSON otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    debug_data: Option<DebugData>,
}
impl CompletionResponse {
pub fn new(id: String, choices: Vec<Choice>, debug_data: Option<DebugData>) -> Self {
Self {
id,
choices,
debug_data,
}
}
}
// Extra diagnostics echoed back when requested via `DebugOptions`.
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugData {
    // Retrieved snippets; populated only when `return_snippets` was true.
    #[serde(skip_serializing_if = "Option::is_none")]
    snippets: Option<Vec<Snippet>>,
    // Final prompt sent to the engine; populated only when `return_prompt` was true.
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt: Option<String>,
}
// Service encapsulating retrieval augmented code completion: it builds a
// prompt (optionally enriched with code-search snippets), runs the text
// generation engine, and logs a completion event.
pub struct CompletionService {
    // Underlying text generation engine.
    engine: Arc<dyn TextGeneration>,
    // Builds prompts from editor segments and retrieved snippets.
    prompt_builder: completions_prompt::PromptBuilder,
}
impl CompletionService {
    /// Creates a service backed by `engine`, using `code` for snippet
    /// retrieval and an optional prompt template for prompt construction.
    pub fn new(
        engine: Arc<dyn TextGeneration>,
        code: Arc<dyn CodeSearch>,
        prompt_template: Option<String>,
    ) -> Self {
        Self {
            engine,
            prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)),
        }
    }

    /// Collects retrieval-augmented snippets for `segments`, or returns an
    /// empty list when the caller disabled retrieval via debug options.
    async fn build_snippets(
        &self,
        language: &str,
        segments: &Segments,
        disable_retrieval_augmented_code_completion: bool,
    ) -> Vec<Snippet> {
        if !disable_retrieval_augmented_code_completion {
            self.prompt_builder.collect(language, segments).await
        } else {
            vec![]
        }
    }

    /// Fixed decoding options for code completion (input/decoding budgets and
    /// low sampling temperature). The builder is expected to be infallible
    /// here since every field it needs is set — NOTE(review): confirm the
    /// builder has no other required fields before relying on the `unwrap`.
    fn text_generation_options(language: &str) -> TextGenerationOptions {
        TextGenerationOptionsBuilder::default()
            .max_input_length(1024 + 512)
            .max_decoding_length(128)
            .sampling_temperature(0.1)
            .language(get_language(language))
            .build()
            .unwrap()
    }

    /// Runs one completion end to end: resolves the prompt (a debug raw
    /// prompt takes precedence over segments), generates text, logs a
    /// completion event, and optionally echoes snippets/prompt back as
    /// debug data.
    ///
    /// # Errors
    /// Returns `CompletionError::EmptyPrompt` when the request carries
    /// neither a raw prompt nor segments.
    pub async fn generate(
        &self,
        request: &CompletionRequest,
    ) -> Result<CompletionResponse, CompletionError> {
        let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
        let language = request.language_or_unknown();
        let options = Self::text_generation_options(language.as_str());

        // A debug raw prompt bypasses prompt building entirely; otherwise the
        // prompt is assembled from the editor segments plus retrieved snippets.
        let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() {
            (prompt, None, vec![])
        } else if let Some(segments) = request.segments.clone() {
            debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
            let snippets = self
                .build_snippets(
                    &language,
                    &segments,
                    request.disable_retrieval_augmented_code_completion(),
                )
                .await;
            let prompt = self
                .prompt_builder
                .build(&language, segments.clone(), &snippets);
            (prompt, Some(segments), snippets)
        } else {
            return Err(CompletionError::EmptyPrompt);
        };
        debug!("PROMPT: {}", prompt);

        let text = self.engine.generate(&prompt, options).await;

        // Convert to the event-log segment type before recording the event.
        let segments = segments.map(|s| s.into());
        events::Event::Completion {
            completion_id: &completion_id,
            language: &language,
            prompt: &prompt,
            segments: &segments,
            choices: vec![events::Choice {
                index: 0,
                text: &text,
            }],
            user: request.user.as_deref(),
        }
        .log();

        // Echo snippets/prompt back only when the respective debug flag is set.
        let debug_data = request
            .debug_options
            .as_ref()
            .map(|debug_options| DebugData {
                snippets: debug_options.return_snippets.then_some(snippets),
                prompt: debug_options.return_prompt.then_some(prompt),
            });

        Ok(CompletionResponse::new(
            completion_id,
            vec![Choice::new(text)],
            debug_data,
        ))
    }
}

View File

@ -1,2 +1,3 @@
pub mod chat;
pub mod code;
pub mod completions;