feat(eval): add debug_options.raw_prompt to tabby api for evaluation purpose (#605)

* update eval

* feat: re-purpose `prompt` as raw input to LLM

* move prompt to Debug options

* Update crates/tabby/src/serve/completions.rs
r0.4
Meng Zhang 2023-10-21 13:47:44 -07:00 committed by GitHub
parent 049ebdf9a9
commit 8fca850037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 48 deletions

View File

@ -56,15 +56,15 @@ pub enum Event<'a> {
completion_id: &'a str, completion_id: &'a str,
language: &'a str, language: &'a str,
prompt: &'a str, prompt: &'a str,
segments: &'a Segments<'a>, segments: &'a Option<Segments>,
choices: Vec<Choice<'a>>, choices: Vec<Choice<'a>>,
user: Option<&'a str>, user: Option<&'a str>,
}, },
} }
#[derive(Serialize)] #[derive(Serialize)]
pub struct Segments<'a> { pub struct Segments {
pub prefix: &'a str, pub prefix: String,
pub suffix: Option<&'a str>, pub suffix: Option<String>,
} }
#[derive(Serialize)] #[derive(Serialize)]

View File

@ -21,10 +21,6 @@ use super::search::IndexServer;
} }
}))] }))]
pub struct CompletionRequest { pub struct CompletionRequest {
#[deprecated]
#[schema(example = "def fib(n):")]
prompt: Option<String>,
/// Language identifier, full list is maintained at /// Language identifier, full list is maintained at
/// https://code.visualstudio.com/docs/languages/identifiers /// https://code.visualstudio.com/docs/languages/identifiers
#[schema(example = "python")] #[schema(example = "python")]
@ -42,9 +38,18 @@ pub struct CompletionRequest {
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugOptions { pub struct DebugOptions {
/// When true, returns debug_data in completion response. /// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
///
/// This is useful for certain requests that aim to test the tabby's e2e quality.
raw_prompt: Option<String>,
/// When true, returns `snippets` in `debug_data`.
#[serde(default = "default_false")] #[serde(default = "default_false")]
enabled: bool, return_snippets: bool,
/// When true, returns `prompt` in `debug_data`.
#[serde(default = "default_false")]
return_prompt: bool,
/// When true, disable retrieval augmented code completion. /// When true, disable retrieval augmented code completion.
#[serde(default = "default_false")] #[serde(default = "default_false")]
@ -92,10 +97,11 @@ pub struct CompletionResponse {
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugData { pub struct DebugData {
#[serde(skip_serializing_if = "Vec::is_empty")] #[serde(skip_serializing_if = "Option::is_none")]
snippets: Vec<Snippet>, snippets: Option<Vec<Snippet>>,
prompt: String, #[serde(skip_serializing_if = "Option::is_none")]
prompt: Option<String>,
} }
#[utoipa::path( #[utoipa::path(
@ -123,42 +129,34 @@ pub async fn completions(
.build() .build()
.unwrap(); .unwrap();
let segments = if let Some(segments) = request.segments { let (prompt, segments, snippets) = if let Some(prompt) = request
segments .debug_options
} else if let Some(prompt) = request.prompt { .as_ref()
Segments { .and_then(|x| x.raw_prompt.clone())
prefix: prompt, {
suffix: None, (prompt, None, vec![])
} } else if let Some(segments) = request.segments {
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let (prompt, snippets) = build_prompt(&state, &request.debug_options, &language, &segments);
(prompt, Some(segments), snippets)
} else { } else {
return Err(StatusCode::BAD_REQUEST); return Err(StatusCode::BAD_REQUEST);
}; };
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let snippets = if !request
.debug_options
.as_ref()
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
{
state.prompt_builder.collect(&language, &segments)
} else {
vec![]
};
let prompt = state
.prompt_builder
.build(&language, segments.clone(), &snippets);
debug!("PROMPT: {}", prompt); debug!("PROMPT: {}", prompt);
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.generate(&prompt, options).await; let text = state.engine.generate(&prompt, options).await;
let segments = segments.map(|x| tabby_common::events::Segments {
prefix: x.prefix,
suffix: x.suffix,
});
events::Event::Completion { events::Event::Completion {
completion_id: &completion_id, completion_id: &completion_id,
language: &language, language: &language,
prompt: &prompt, prompt: &prompt,
segments: &tabby_common::events::Segments { segments: &segments,
prefix: &segments.prefix,
suffix: segments.suffix.as_deref(),
},
choices: vec![events::Choice { choices: vec![events::Choice {
index: 0, index: 0,
text: &text, text: &text,
@ -167,19 +165,43 @@ pub async fn completions(
} }
.log(); .log();
let debug_data = DebugData { snippets, prompt }; let debug_data = request
.debug_options
.as_ref()
.map(|debug_options| DebugData {
snippets: debug_options.return_snippets.then_some(snippets),
prompt: debug_options.return_prompt.then_some(prompt),
});
Ok(Json(CompletionResponse { Ok(Json(CompletionResponse {
id: completion_id, id: completion_id,
choices: vec![Choice { index: 0, text }], choices: vec![Choice { index: 0, text }],
debug_data: if request.debug_options.is_some_and(|x| x.enabled) { debug_data,
Some(debug_data)
} else {
None
},
})) }))
} }
/// Builds the completion prompt for `segments` and returns it together with
/// the snippets that were fed into it.
///
/// Snippet collection is skipped (an empty list is used) when `debug_options`
/// is present and sets `disable_retrieval_augmented_code_completion`.
fn build_prompt(
    state: &Arc<CompletionState>,
    debug_options: &Option<DebugOptions>,
    language: &str,
    segments: &Segments,
) -> (String, Vec<Snippet>) {
    let retrieval_disabled = debug_options
        .as_ref()
        .is_some_and(|opts| opts.disable_retrieval_augmented_code_completion);

    let snippets = if retrieval_disabled {
        vec![]
    } else {
        state.prompt_builder.collect(language, segments)
    };

    // The builder takes ownership of the segments, so hand it a copy.
    let prompt = state
        .prompt_builder
        .build(language, segments.clone(), &snippets);

    (prompt, snippets)
}
pub struct CompletionState { pub struct CompletionState {
engine: Arc<Box<dyn TextGeneration>>, engine: Arc<Box<dyn TextGeneration>>,
prompt_builder: prompt::PromptBuilder, prompt_builder: prompt::PromptBuilder,

View File

@ -82,10 +82,10 @@ class Model:
@method() @method()
async def complete(self, language: str, prompt: str): async def complete(self, language: str, prompt: str):
from tabby_client.api.v1 import completion from tabby_client.api.v1 import completion
from tabby_client.models import CompletionRequest, CompletionResponse, Segments from tabby_client.models import CompletionRequest, DebugOptions, CompletionResponse, Segments
request = CompletionRequest( request = CompletionRequest(
language=language, prompt=prompt language=language, debug_options=DebugOptions(raw_prompt=prompt)
) )
resp: CompletionResponse = await completion.asyncio(client=self.client, json_body=request) resp: CompletionResponse = await completion.asyncio(client=self.client, json_body=request)
return resp.choices[0].text return resp.choices[0].text

View File

@ -11,7 +11,7 @@ language = st.text_input("Language", "rust")
query = st.text_area("Query", "to_owned") query = st.text_area("Query", "to_owned")
if query: if query:
r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug_options=dict(enabled=True))) r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug_options=dict(return_snippets=True, return_prompt=True)))
json = r.json() json = r.json()
debug = json["debug_data"] debug = json["debug_data"]
snippets = debug.get("snippets", []) snippets = debug.get("snippets", [])
@ -25,4 +25,4 @@ if query:
for x in snippets: for x in snippets:
st.write(f"**{x['filepath']}**: {x['score']}") st.write(f"**{x['filepath']}**: {x['score']}")
st.write(f"Length: {len(x['body'])}") st.write(f"Length: {len(x['body'])}")
st.code(x['body']) st.code(x['body'])