refactor: extract CompletionState -> CompletionService (#773)
* refactor: extract CompletionState -> CompletionService
* fix comment
* Update README.md cmake is preinstalled in ubuntu / debian
* fix compile error
* format files
* format files

---------

Co-authored-by: darknight <illuminating.me@gmail.com>

extract-routes
parent
6f1a3039b0
commit
4359b0cc4b
|
|
@ -1,108 +1,10 @@
|
||||||
mod prompt;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use tracing::{instrument, warn};
|
||||||
use tabby_common::{events, languages::get_language};
|
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
|
||||||
use tracing::{debug, instrument};
|
|
||||||
use utoipa::ToSchema;
|
|
||||||
|
|
||||||
use crate::api::CodeSearch;
|
use crate::services::completions::{CompletionRequest, CompletionResponse, CompletionService};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
#[schema(example=json!({
|
|
||||||
"language": "python",
|
|
||||||
"segments": {
|
|
||||||
"prefix": "def fib(n):\n ",
|
|
||||||
"suffix": "\n return fib(n - 1) + fib(n - 2)"
|
|
||||||
}
|
|
||||||
}))]
|
|
||||||
pub struct CompletionRequest {
|
|
||||||
/// Language identifier, full list is maintained at
|
|
||||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
|
||||||
#[schema(example = "python")]
|
|
||||||
language: Option<String>,
|
|
||||||
|
|
||||||
/// When segments are set, the `prompt` is ignored during the inference.
|
|
||||||
segments: Option<Segments>,
|
|
||||||
|
|
||||||
/// A unique identifier representing your end-user, which can help Tabby to monitor & generating
|
|
||||||
/// reports.
|
|
||||||
user: Option<String>,
|
|
||||||
|
|
||||||
debug_options: Option<DebugOptions>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct DebugOptions {
|
|
||||||
/// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
|
|
||||||
///
|
|
||||||
/// This is useful for certain requests that aim to test the tabby's e2e quality.
|
|
||||||
raw_prompt: Option<String>,
|
|
||||||
|
|
||||||
/// When true, returns `snippets` in `debug_data`.
|
|
||||||
#[serde(default = "default_false")]
|
|
||||||
return_snippets: bool,
|
|
||||||
|
|
||||||
/// When true, returns `prompt` in `debug_data`.
|
|
||||||
#[serde(default = "default_false")]
|
|
||||||
return_prompt: bool,
|
|
||||||
|
|
||||||
/// When true, disable retrieval augmented code completion.
|
|
||||||
#[serde(default = "default_false")]
|
|
||||||
disable_retrieval_augmented_code_completion: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_false() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct Segments {
|
|
||||||
/// Content that appears before the cursor in the editor window.
|
|
||||||
prefix: String,
|
|
||||||
|
|
||||||
/// Content that appears after the cursor in the editor window.
|
|
||||||
suffix: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct Choice {
|
|
||||||
index: u32,
|
|
||||||
text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct Snippet {
|
|
||||||
filepath: String,
|
|
||||||
body: String,
|
|
||||||
score: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
#[schema(example=json!({
|
|
||||||
"id": "string",
|
|
||||||
"choices": [ { "index": 0, "text": "string" } ]
|
|
||||||
}))]
|
|
||||||
pub struct CompletionResponse {
|
|
||||||
id: String,
|
|
||||||
choices: Vec<Choice>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
debug_data: Option<DebugData>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
|
||||||
pub struct DebugData {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
snippets: Option<Vec<Snippet>>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
prompt: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
|
|
@ -117,106 +19,14 @@ pub struct DebugData {
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, request))]
|
#[instrument(skip(state, request))]
|
||||||
pub async fn completions(
|
pub async fn completions(
|
||||||
State(state): State<Arc<CompletionState>>,
|
State(state): State<Arc<CompletionService>>,
|
||||||
Json(request): Json<CompletionRequest>,
|
Json(request): Json<CompletionRequest>,
|
||||||
) -> Result<Json<CompletionResponse>, StatusCode> {
|
) -> Result<Json<CompletionResponse>, StatusCode> {
|
||||||
let language = request.language.unwrap_or("unknown".to_string());
|
match state.generate(&request).await {
|
||||||
let options = TextGenerationOptionsBuilder::default()
|
Ok(resp) => Ok(Json(resp)),
|
||||||
.max_input_length(1024 + 512)
|
Err(err) => {
|
||||||
.max_decoding_length(128)
|
warn!("{}", err);
|
||||||
.sampling_temperature(0.1)
|
Err(StatusCode::BAD_REQUEST)
|
||||||
.language(get_language(&language))
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let (prompt, segments, snippets) = if let Some(prompt) = request
|
|
||||||
.debug_options
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|x| x.raw_prompt.clone())
|
|
||||||
{
|
|
||||||
(prompt, None, vec![])
|
|
||||||
} else if let Some(segments) = request.segments {
|
|
||||||
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
|
||||||
let (prompt, snippets) =
|
|
||||||
build_prompt(&state, &request.debug_options, &language, &segments).await;
|
|
||||||
(prompt, Some(segments), snippets)
|
|
||||||
} else {
|
|
||||||
return Err(StatusCode::BAD_REQUEST);
|
|
||||||
};
|
|
||||||
debug!("PROMPT: {}", prompt);
|
|
||||||
|
|
||||||
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
|
||||||
let text = state.engine.generate(&prompt, options).await;
|
|
||||||
|
|
||||||
let segments = segments.map(|x| tabby_common::events::Segments {
|
|
||||||
prefix: x.prefix,
|
|
||||||
suffix: x.suffix,
|
|
||||||
});
|
|
||||||
|
|
||||||
events::Event::Completion {
|
|
||||||
completion_id: &completion_id,
|
|
||||||
language: &language,
|
|
||||||
prompt: &prompt,
|
|
||||||
segments: &segments,
|
|
||||||
choices: vec![events::Choice {
|
|
||||||
index: 0,
|
|
||||||
text: &text,
|
|
||||||
}],
|
|
||||||
user: request.user.as_deref(),
|
|
||||||
}
|
|
||||||
.log();
|
|
||||||
|
|
||||||
let debug_data = request
|
|
||||||
.debug_options
|
|
||||||
.as_ref()
|
|
||||||
.map(|debug_options| DebugData {
|
|
||||||
snippets: debug_options.return_snippets.then_some(snippets),
|
|
||||||
prompt: debug_options.return_prompt.then_some(prompt),
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(Json(CompletionResponse {
|
|
||||||
id: completion_id,
|
|
||||||
choices: vec![Choice { index: 0, text }],
|
|
||||||
debug_data,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn build_prompt(
|
|
||||||
state: &Arc<CompletionState>,
|
|
||||||
debug_options: &Option<DebugOptions>,
|
|
||||||
language: &str,
|
|
||||||
segments: &Segments,
|
|
||||||
) -> (String, Vec<Snippet>) {
|
|
||||||
let snippets = if !debug_options
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
|
|
||||||
{
|
|
||||||
state.prompt_builder.collect(language, segments).await
|
|
||||||
} else {
|
|
||||||
vec![]
|
|
||||||
};
|
|
||||||
(
|
|
||||||
state
|
|
||||||
.prompt_builder
|
|
||||||
.build(language, segments.clone(), &snippets),
|
|
||||||
snippets,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct CompletionState {
|
|
||||||
engine: Arc<dyn TextGeneration>,
|
|
||||||
prompt_builder: prompt::PromptBuilder,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionState {
|
|
||||||
pub fn new(
|
|
||||||
engine: Arc<dyn TextGeneration>,
|
|
||||||
code: Arc<dyn CodeSearch>,
|
|
||||||
prompt_template: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
engine,
|
|
||||||
prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(code)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ use self::{
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{Hit, HitDocument, SearchResponse},
|
api::{Hit, HitDocument, SearchResponse},
|
||||||
fatal,
|
fatal,
|
||||||
services::chat::ChatService,
|
services::{chat::ChatService, completions::CompletionService},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
|
@ -54,13 +54,13 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
|
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
|
||||||
components(schemas(
|
components(schemas(
|
||||||
events::LogEventRequest,
|
events::LogEventRequest,
|
||||||
completions::CompletionRequest,
|
crate::services::completions::CompletionRequest,
|
||||||
completions::CompletionResponse,
|
crate::services::completions::CompletionResponse,
|
||||||
completions::Segments,
|
crate::services::completions::Segments,
|
||||||
completions::Choice,
|
crate::services::completions::Choice,
|
||||||
completions::Snippet,
|
crate::services::completions::Snippet,
|
||||||
completions::DebugOptions,
|
crate::services::completions::DebugOptions,
|
||||||
completions::DebugData,
|
crate::services::completions::DebugData,
|
||||||
crate::services::chat::ChatCompletionRequest,
|
crate::services::chat::ChatCompletionRequest,
|
||||||
crate::services::chat::Message,
|
crate::services::chat::Message,
|
||||||
crate::services::chat::ChatCompletionChunk,
|
crate::services::chat::ChatCompletionChunk,
|
||||||
|
|
@ -182,8 +182,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
prompt_template, ..
|
prompt_template, ..
|
||||||
},
|
},
|
||||||
) = create_engine(&args.model, args).await;
|
) = create_engine(&args.model, args).await;
|
||||||
let state =
|
let state = CompletionService::new(engine.clone(), code.clone(), prompt_template);
|
||||||
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
|
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
mod prompt;
|
mod chat_prompt;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_stream::stream;
|
use async_stream::stream;
|
||||||
|
use chat_prompt::ChatPromptBuilder;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use prompt::ChatPromptBuilder;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_common::languages::EMPTY_LANGUAGE;
|
use tabby_common::languages::EMPTY_LANGUAGE;
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||||
|
|
@ -47,26 +47,22 @@ impl ChatService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_request(&self, request: &ChatCompletionRequest) -> (String, TextGenerationOptions) {
|
fn text_generation_options() -> TextGenerationOptions {
|
||||||
let mut builder = TextGenerationOptionsBuilder::default();
|
TextGenerationOptionsBuilder::default()
|
||||||
|
|
||||||
builder
|
|
||||||
.max_input_length(2048)
|
.max_input_length(2048)
|
||||||
.max_decoding_length(1920)
|
.max_decoding_length(1920)
|
||||||
.language(&EMPTY_LANGUAGE)
|
.language(&EMPTY_LANGUAGE)
|
||||||
.sampling_temperature(0.1);
|
.sampling_temperature(0.1)
|
||||||
|
.build()
|
||||||
(
|
.unwrap()
|
||||||
self.prompt_builder.build(&request.messages),
|
|
||||||
builder.build().unwrap(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn generate(
|
pub async fn generate(
|
||||||
&self,
|
&self,
|
||||||
request: &ChatCompletionRequest,
|
request: &ChatCompletionRequest,
|
||||||
) -> BoxStream<ChatCompletionChunk> {
|
) -> BoxStream<ChatCompletionChunk> {
|
||||||
let (prompt, options) = self.parse_request(request);
|
let prompt = self.prompt_builder.build(&request.messages);
|
||||||
|
let options = Self::text_generation_options();
|
||||||
debug!("PROMPT: {}", prompt);
|
debug!("PROMPT: {}", prompt);
|
||||||
let s = stream! {
|
let s = stream! {
|
||||||
for await content in self.engine.generate_stream(&prompt, options).await {
|
for await content in self.engine.generate_stream(&prompt, options).await {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,256 @@
|
||||||
|
mod completions_prompt;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tabby_common::{events, languages::get_language};
|
||||||
|
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tracing::debug;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
use crate::api::CodeSearch;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum CompletionError {
|
||||||
|
#[error("empty prompt from completion request")]
|
||||||
|
EmptyPrompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
#[schema(example=json!({
|
||||||
|
"language": "python",
|
||||||
|
"segments": {
|
||||||
|
"prefix": "def fib(n):\n ",
|
||||||
|
"suffix": "\n return fib(n - 1) + fib(n - 2)"
|
||||||
|
}
|
||||||
|
}))]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
/// Language identifier, full list is maintained at
|
||||||
|
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||||
|
#[schema(example = "python")]
|
||||||
|
language: Option<String>,
|
||||||
|
|
||||||
|
/// When segments are set, the `prompt` is ignored during the inference.
|
||||||
|
segments: Option<Segments>,
|
||||||
|
|
||||||
|
/// A unique identifier representing your end-user, which can help Tabby to monitor & generating
|
||||||
|
/// reports.
|
||||||
|
user: Option<String>,
|
||||||
|
|
||||||
|
debug_options: Option<DebugOptions>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionRequest {
|
||||||
|
/// Returns the language info or "unknown" if not specified.
|
||||||
|
fn language_or_unknown(&self) -> String {
|
||||||
|
self.language.clone().unwrap_or("unknown".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the raw prompt if specified.
|
||||||
|
fn raw_prompt(&self) -> Option<String> {
|
||||||
|
self.debug_options
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|x| x.raw_prompt.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if retrieval augmented code completion is disabled.
|
||||||
|
fn disable_retrieval_augmented_code_completion(&self) -> bool {
|
||||||
|
self.debug_options
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct DebugOptions {
|
||||||
|
/// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
|
||||||
|
///
|
||||||
|
/// This is useful for certain requests that aim to test the tabby's e2e quality.
|
||||||
|
raw_prompt: Option<String>,
|
||||||
|
|
||||||
|
/// When true, returns `snippets` in `debug_data`.
|
||||||
|
#[serde(default = "default_false")]
|
||||||
|
return_snippets: bool,
|
||||||
|
|
||||||
|
/// When true, returns `prompt` in `debug_data`.
|
||||||
|
#[serde(default = "default_false")]
|
||||||
|
return_prompt: bool,
|
||||||
|
|
||||||
|
/// When true, disable retrieval augmented code completion.
|
||||||
|
#[serde(default = "default_false")]
|
||||||
|
disable_retrieval_augmented_code_completion: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serde default provider for the boolean flags in `DebugOptions`: a flag
/// missing from the payload deserializes as `false`.
fn default_false() -> bool {
    false
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct Segments {
|
||||||
|
/// Content that appears before the cursor in the editor window.
|
||||||
|
prefix: String,
|
||||||
|
|
||||||
|
/// Content that appears after the cursor in the editor window.
|
||||||
|
suffix: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Segments> for events::Segments {
|
||||||
|
fn from(val: Segments) -> Self {
|
||||||
|
events::Segments {
|
||||||
|
prefix: val.prefix,
|
||||||
|
suffix: val.suffix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct Choice {
|
||||||
|
index: u32,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Choice {
|
||||||
|
pub fn new(text: String) -> Self {
|
||||||
|
Self { index: 0, text }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct Snippet {
|
||||||
|
filepath: String,
|
||||||
|
body: String,
|
||||||
|
score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
#[schema(example=json!({
|
||||||
|
"id": "string",
|
||||||
|
"choices": [ { "index": 0, "text": "string" } ]
|
||||||
|
}))]
|
||||||
|
pub struct CompletionResponse {
|
||||||
|
id: String,
|
||||||
|
choices: Vec<Choice>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
debug_data: Option<DebugData>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionResponse {
|
||||||
|
pub fn new(id: String, choices: Vec<Choice>, debug_data: Option<DebugData>) -> Self {
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
choices,
|
||||||
|
debug_data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
|
pub struct DebugData {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
snippets: Option<Vec<Snippet>>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
prompt: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CompletionService {
|
||||||
|
engine: Arc<dyn TextGeneration>,
|
||||||
|
prompt_builder: completions_prompt::PromptBuilder,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompletionService {
|
||||||
|
pub fn new(
|
||||||
|
engine: Arc<dyn TextGeneration>,
|
||||||
|
code: Arc<dyn CodeSearch>,
|
||||||
|
prompt_template: Option<String>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
engine,
|
||||||
|
prompt_builder: completions_prompt::PromptBuilder::new(prompt_template, Some(code)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn build_snippets(
|
||||||
|
&self,
|
||||||
|
language: &str,
|
||||||
|
segments: &Segments,
|
||||||
|
disable_retrieval_augmented_code_completion: bool,
|
||||||
|
) -> Vec<Snippet> {
|
||||||
|
if !disable_retrieval_augmented_code_completion {
|
||||||
|
self.prompt_builder.collect(language, segments).await
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn text_generation_options(language: &str) -> TextGenerationOptions {
|
||||||
|
TextGenerationOptionsBuilder::default()
|
||||||
|
.max_input_length(1024 + 512)
|
||||||
|
.max_decoding_length(128)
|
||||||
|
.sampling_temperature(0.1)
|
||||||
|
.language(get_language(language))
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
request: &CompletionRequest,
|
||||||
|
) -> Result<CompletionResponse, CompletionError> {
|
||||||
|
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
|
||||||
|
let language = request.language_or_unknown();
|
||||||
|
let options = Self::text_generation_options(language.as_str());
|
||||||
|
|
||||||
|
let (prompt, segments, snippets) = if let Some(prompt) = request.raw_prompt() {
|
||||||
|
(prompt, None, vec![])
|
||||||
|
} else if let Some(segments) = request.segments.clone() {
|
||||||
|
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
||||||
|
let snippets = self
|
||||||
|
.build_snippets(
|
||||||
|
&language,
|
||||||
|
&segments,
|
||||||
|
request.disable_retrieval_augmented_code_completion(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let prompt = self
|
||||||
|
.prompt_builder
|
||||||
|
.build(&language, segments.clone(), &snippets);
|
||||||
|
(prompt, Some(segments), snippets)
|
||||||
|
} else {
|
||||||
|
return Err(CompletionError::EmptyPrompt);
|
||||||
|
};
|
||||||
|
debug!("PROMPT: {}", prompt);
|
||||||
|
|
||||||
|
let text = self.engine.generate(&prompt, options).await;
|
||||||
|
let segments = segments.map(|s| s.into());
|
||||||
|
|
||||||
|
events::Event::Completion {
|
||||||
|
completion_id: &completion_id,
|
||||||
|
language: &language,
|
||||||
|
prompt: &prompt,
|
||||||
|
segments: &segments,
|
||||||
|
choices: vec![events::Choice {
|
||||||
|
index: 0,
|
||||||
|
text: &text,
|
||||||
|
}],
|
||||||
|
user: request.user.as_deref(),
|
||||||
|
}
|
||||||
|
.log();
|
||||||
|
|
||||||
|
let debug_data = request
|
||||||
|
.debug_options
|
||||||
|
.as_ref()
|
||||||
|
.map(|debug_options| DebugData {
|
||||||
|
snippets: debug_options.return_snippets.then_some(snippets),
|
||||||
|
prompt: debug_options.return_prompt.then_some(prompt),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(CompletionResponse::new(
|
||||||
|
completion_id,
|
||||||
|
vec![Choice::new(text)],
|
||||||
|
debug_data,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
pub mod chat;
|
pub mod chat;
|
||||||
pub mod code;
|
pub mod code;
|
||||||
|
pub mod completions;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue