refactor: extract CompletionState -> CompletionService (#773)
* refactor: extract CompletionState -> CompletionService
* fix comment
* Update README.md: cmake is preinstalled in Ubuntu / Debian
* fix compile error
* format files
* format files
---------
Co-authored-by: darknight <illuminating.me@gmail.com>
extract-routes
parent
6f1a3039b0
commit
4359b0cc4b
|
|
@ -1,108 +1,10 @@
|
|||
mod prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{extract::State, Json};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::{events, languages::get_language};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
use tracing::{instrument, warn};
|
||||
|
||||
use crate::api::CodeSearch;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"language": "python",
|
||||
"segments": {
|
||||
"prefix": "def fib(n):\n ",
|
||||
"suffix": "\n return fib(n - 1) + fib(n - 2)"
|
||||
}
|
||||
}))]
|
||||
pub struct CompletionRequest {
|
||||
/// Language identifier, full list is maintained at
|
||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||
#[schema(example = "python")]
|
||||
language: Option<String>,
|
||||
|
||||
/// When segments are set, the `prompt` is ignored during the inference.
|
||||
segments: Option<Segments>,
|
||||
|
||||
/// A unique identifier representing your end-user, which can help Tabby to monitor & generating
|
||||
/// reports.
|
||||
user: Option<String>,
|
||||
|
||||
debug_options: Option<DebugOptions>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct DebugOptions {
|
||||
/// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
|
||||
///
|
||||
/// This is useful for certain requests that aim to test the tabby's e2e quality.
|
||||
raw_prompt: Option<String>,
|
||||
|
||||
/// When true, returns `snippets` in `debug_data`.
|
||||
#[serde(default = "default_false")]
|
||||
return_snippets: bool,
|
||||
|
||||
/// When true, returns `prompt` in `debug_data`.
|
||||
#[serde(default = "default_false")]
|
||||
return_prompt: bool,
|
||||
|
||||
/// When true, disable retrieval augmented code completion.
|
||||
#[serde(default = "default_false")]
|
||||
disable_retrieval_augmented_code_completion: bool,
|
||||
}
|
||||
|
||||
fn default_false() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Segments {
|
||||
/// Content that appears before the cursor in the editor window.
|
||||
prefix: String,
|
||||
|
||||
/// Content that appears after the cursor in the editor window.
|
||||
suffix: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Choice {
|
||||
index: u32,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Snippet {
|
||||
filepath: String,
|
||||
body: String,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"id": "string",
|
||||
"choices": [ { "index": 0, "text": "string" } ]
|
||||
}))]
|
||||
pub struct CompletionResponse {
|
||||
id: String,
|
||||
choices: Vec<Choice>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
debug_data: Option<DebugData>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct DebugData {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
snippets: Option<Vec<Snippet>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
prompt: Option<String>,
|
||||
}
|
||||
use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService};
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
|
|
@ -117,106 +19,14 @@ pub struct DebugData {
|
|||
)]
|
||||
#[instrument(skip(state, request))]
|
||||
pub async fn completions(
|
||||
State(state): State<Arc<CompletionState>>,
|
||||
State(state): State<Arc<CompletionService>>,
|
||||
Json(request): Json<CompletionRequest>,
|
||||
) -> Result<Json<CompletionResponse>, StatusCode> {
|
||||
let language = request.language.unwrap_or("unknown".to_string());
|
||||
let options = TextGenerationOptionsBuilder::default()
|
||||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.language(get_language(&language))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let (prompt, segments, snippets) = if let Some(prompt) = request
|
||||
.debug_options
|
||||
.as_ref()
|
||||
.and_then(|x| x.raw_prompt.clone())
|
||||
{
|
||||
(prompt, None, vec![])
|
||||
} else if let Some(segments) = request.segments {
|
||||
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
||||
let (prompt, snippets) =
|
||||
build_prompt(&state, &request.debug_options, &language, &segments).await;
|
||||
(prompt, Some(segments), snippets)
|
||||
} else {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
};
|
||||
debug!("PROMPT: {}", prompt);
|
||||
|
||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||
let text = state.engine.generate(&prompt, options).await;
|
||||
|
||||
let segments = segments.map(|x| tabby_common::events::Segments {
|
||||
prefix: x.prefix,
|
||||
suffix: x.suffix,
|
||||
});
|
||||
|
||||
events::Event::Completion {
|
||||
completion_id: &completion_id,
|
||||
language: &language,
|
||||
prompt: &prompt,
|
||||
segments: &segments,
|
||||
choices: vec![events::Choice {
|
||||
index: 0,
|
||||
text: &text,
|
||||
}],
|
||||
user: request.user.as_deref(),
|
||||
}
|
||||
.log();
|
||||
|
||||
let debug_data = request
|
||||
.debug_options
|
||||
.as_ref()
|
||||
.map(|debug_options| DebugData {
|
||||
snippets: debug_options.return_snippets.then_some(snippets),
|
||||
prompt: debug_options.return_prompt.then_some(prompt),
|
||||
});
|
||||
|
||||
Ok(Json(CompletionResponse {
|
||||
id: completion_id,
|
||||
choices: vec![Choice { index: 0, text }],
|
||||
debug_data,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn build_prompt(
|
||||
state: &Arc<CompletionState>,
|
||||
debug_options: &Option<DebugOptions>,
|
||||
language: &str,
|
||||
segments: &Segments,
|
||||
) -> (String, Vec<Snippet>) {
|
||||
let snippets = if !debug_options
|
||||
.as_ref()
|
||||
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
|
||||
{
|
||||
state.prompt_builder.collect(language, segments).await
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
(
|
||||
state
|
||||
.prompt_builder
|
||||
.build(language, segments.clone(), &snippets),
|
||||
snippets,
|
||||
)
|
||||
}
|
||||
|
||||
pub struct CompletionState {
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
prompt_builder: prompt::PromptBuilder,
|
||||
}
|
||||
|
||||
impl CompletionState {
|
||||
pub fn new(
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
prompt_template: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(code)),
|
||||
match state.generate(&request).await {
|
||||
Ok(resp) => Ok(Json(resp)),
|
||||
Err(err) => {
|
||||
warn!("{}", err);
|
||||
Err(StatusCode::BAD_REQUEST)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ use self::{
|
|||
use crate::{
|
||||
api::{Hit, HitDocument, SearchResponse},
|
||||
fatal,
|
||||
services::chat::ChatService,
|
||||
services::{chat::ChatService, completions::CompletionService},
|
||||
};
|
||||
|
||||
#[derive(OpenApi)]
|
||||
|
|
@ -54,13 +54,13 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
|
||||
components(schemas(
|
||||
events::LogEventRequest,
|
||||
completions::CompletionRequest,
|
||||
completions::CompletionResponse,
|
||||
completions::Segments,
|
||||
completions::Choice,
|
||||
completions::Snippet,
|
||||
completions::DebugOptions,
|
||||
completions::DebugData,
|
||||
crate::services::completions::CompletionRequest,
|
||||
crate::services::completions::CompletionResponse,
|
||||
crate::services::completions::Segments,
|
||||
crate::services::completions::Choice,
|
||||
crate::services::completions::Snippet,
|
||||
crate::services::completions::DebugOptions,
|
||||
crate::services::completions::DebugData,
|
||||
crate::services::chat::ChatCompletionRequest,
|
||||
crate::services::chat::Message,
|
||||
crate::services::chat::ChatCompletionChunk,
|
||||
|
|
@ -182,8 +182,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
prompt_template, ..
|
||||
},
|
||||
) = create_engine(&args.model, args).await;
|
||||
let state =
|
||||
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
|
||||
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
||||
Arc::new(state)
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
mod prompt;
|
||||
mod chat_prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use chat_prompt::ChatPromptBuilder;
|
||||
use futures::stream::BoxStream;
|
||||
use prompt::ChatPromptBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::languages::EMPTY_LANGUAGE;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
|
|
@ -47,26 +47,22 @@ impl ChatService {
|
|||
}
|
||||
}
|
||||
|
||||
fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
|
||||
let mut builder = TextGenerationOptionsBuilder::default();
|
||||
|
||||
builder
|
||||
fn text_generation_options() -> TextGenerationOptions {
|
||||
TextGenerationOptionsBuilder::default()
|
||||
.max_input_length(2048)
|
||||
.max_decoding_length(1920)
|
||||
.language(&EMPTY_LANGUAGE)
|
||||
.sampling_temperature(0.1);
|
||||
|
||||
(
|
||||
self.prompt_builder.build(&request.messages),
|
||||
builder.build().unwrap(),
|
||||
)
|
||||
.sampling_temperature(0.1)
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
) -> BoxStream<ChatCompletionChunk> {
|
||||
let (prompt, options) = self.parse_request(request);
|
||||
let prompt = self.prompt_builder.build(&request.messages);
|
||||
let options = Self::text_generation_options();
|
||||
debug!("PROMPT: {}", prompt);
|
||||
let s = stream! {
|
||||
for await content in self.engine.generate_stream(&prompt, options).await {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,256 @@
|
|||
mod completions_prompt;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::{events, languages::get_language};
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||
use thiserror::Error;
|
||||
use tracing::debug;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::api::CodeSearch;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CompletionError {
|
||||
#[error("empty prompt from completion request")]
|
||||
EmptyPrompt,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"language": "python",
|
||||
"segments": {
|
||||
"prefix": "def fib(n):\n ",
|
||||
"suffix": "\n return fib(n - 1) + fib(n - 2)"
|
||||
}
|
||||
}))]
|
||||
pub struct CompletionRequest {
|
||||
/// Language identifier, full list is maintained at
|
||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||
#[schema(example = "python")]
|
||||
language: Option<String>,
|
||||
|
||||
/// When segments are set, the `prompt` is ignored during the inference.
|
||||
segments: Option<Segments>,
|
||||
|
||||
/// A unique identifier representing your end-user, which can help Tabby to monitor & generating
|
||||
/// reports.
|
||||
user: Option<String>,
|
||||
|
||||
debug_options: Option<DebugOptions>,
|
||||
}
|
||||
|
||||
impl CompletionRequest {
|
||||
/// Returns the language info or "unknown" if not specified.
|
||||
fn language_or_unknown(&self) -> String {
|
||||
self.language.clone().unwrap_or("unknown".to_string())
|
||||
}
|
||||
|
||||
/// Returns the raw prompt if specified.
|
||||
fn raw_prompt(&self) -> Option<String> {
|
||||
self.debug_options
|
||||
.as_ref()
|
||||
.and_then(|x| x.raw_prompt.clone())
|
||||
}
|
||||
|
||||
/// Returns true if retrieval augmented code completion is disabled.
|
||||
fn disable_retrieval_augmented_code_completion(&self) -> bool {
|
||||
self.debug_options
|
||||
.as_ref()
|
||||
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct DebugOptions {
|
||||
/// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
|
||||
///
|
||||
/// This is useful for certain requests that aim to test the tabby's e2e quality.
|
||||
raw_prompt: Option<String>,
|
||||
|
||||
/// When true, returns `snippets` in `debug_data`.
|
||||
#[serde(default = "default_false")]
|
||||
return_snippets: bool,
|
||||
|
||||
/// When true, returns `prompt` in `debug_data`.
|
||||
#[serde(default = "default_false")]
|
||||
return_prompt: bool,
|
||||
|
||||
/// When true, disable retrieval augmented code completion.
|
||||
#[serde(default = "default_false")]
|
||||
disable_retrieval_augmented_code_completion: bool,
|
||||
}
|
||||
|
||||
/// Serde default provider for the boolean flags in `DebugOptions`; always `false`.
fn default_false() -> bool {
    bool::default()
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Segments {
|
||||
/// Content that appears before the cursor in the editor window.
|
||||
prefix: String,
|
||||
|
||||
/// Content that appears after the cursor in the editor window.
|
||||
suffix: Option<String>,
|
||||
}
|
||||
|
||||
impl From<Segments> for events::Segments {
|
||||
fn from(val: Segments) -> Self {
|
||||
events::Segments {
|
||||
prefix: val.prefix,
|
||||
suffix: val.suffix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Choice {
|
||||
index: u32,
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl Choice {
|
||||
pub fn new(text: String) -> Self {
|
||||
Self { index: 0, text }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct Snippet {
|
||||
filepath: String,
|
||||
body: String,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
"id": "string",
|
||||
"choices": [ { "index": 0, "text": "string" } ]
|
||||
}))]
|
||||
pub struct CompletionResponse {
|
||||
id: String,
|
||||
choices: Vec<Choice>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
debug_data: Option<DebugData>,
|
||||
}
|
||||
|
||||
impl CompletionResponse {
|
||||
pub fn new(id: String, choices: Vec<Choice>, debug_data: Option<DebugData>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
choices,
|
||||
debug_data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
pub struct DebugData {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
snippets: Option<Vec<Snippet>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub struct CompletionService {
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
prompt_builder: completions_prompt::PromptBuilder,
|
||||
}
|
||||
|
||||
impl CompletionService {
|
||||
pub fn new(
|
||||
engine: Arc<dyn TextGeneration>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
prompt_template: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_snippets(
|
||||
&self,
|
||||
language: &str,
|
||||
segments: &Segments,
|
||||
disable_retrieval_augmented_code_completion: bool,
|
||||
) -> Vec<Snippet> {
|
||||
if !disable_retrieval_augmented_code_completion {
|
||||
self.prompt_builder.collect(language, segments).await
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
fn text_generation_options(language: &str) -> TextGenerationOptions {
|
||||
TextGenerationOptionsBuilder::default()
|
||||
.max_input_length(1024 + 512)
|
||||
.max_decoding_length(128)
|
||||
.sampling_temperature(0.1)
|
||||
.language(get_language(language))
|
||||
.build()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
request: &CompletionRequest,
|
||||
) -> Result<CompletionResponse, CompletionError> {
|
||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||
let language = request.language_or_unknown();
|
||||
let options = Self::text_generation_options(language.as_str());
|
||||
|
||||
let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() {
|
||||
(prompt, None, vec![])
|
||||
} else if let Some(segments) = request.segments.clone() {
|
||||
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
||||
let snippets = self
|
||||
.build_snippets(
|
||||
&language,
|
||||
&segments,
|
||||
request.disable_retrieval_augmented_code_completion(),
|
||||
)
|
||||
.await;
|
||||
let prompt = self
|
||||
.prompt_builder
|
||||
.build(&language, segments.clone(), &snippets);
|
||||
(prompt, Some(segments), snippets)
|
||||
} else {
|
||||
return Err(CompletionError::EmptyPrompt);
|
||||
};
|
||||
debug!("PROMPT: {}", prompt);
|
||||
|
||||
let text = self.engine.generate(&prompt, options).await;
|
||||
let segments = segments.map(|s| s.into());
|
||||
|
||||
events::Event::Completion {
|
||||
completion_id: &completion_id,
|
||||
language: &language,
|
||||
prompt: &prompt,
|
||||
segments: &segments,
|
||||
choices: vec![events::Choice {
|
||||
index: 0,
|
||||
text: &text,
|
||||
}],
|
||||
user: request.user.as_deref(),
|
||||
}
|
||||
.log();
|
||||
|
||||
let debug_data = request
|
||||
.debug_options
|
||||
.as_ref()
|
||||
.map(|debug_options| DebugData {
|
||||
snippets: debug_options.return_snippets.then_some(snippets),
|
||||
prompt: debug_options.return_prompt.then_some(prompt),
|
||||
});
|
||||
|
||||
Ok(CompletionResponse::new(
|
||||
completion_id,
|
||||
vec![Choice::new(text)],
|
||||
debug_data,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod chat;
|
||||
pub mod code;
|
||||
pub mod completions;
|
||||
|
|
|
|||
Loading…
Reference in New Issue