feat: add debug request / response to visualize prompting with source code index (#544)

* feat: logs segments in completion log

* feat: tune prompt format and improve testing

* add debug options for easier visualization of the prompt

* update
r0.3
Meng Zhang 2023-10-12 19:27:52 -07:00 committed by GitHub
parent b9d5c12d96
commit 1ad871e1ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 205 additions and 141 deletions

32
Cargo.lock generated
View File

@ -1931,9 +1931,9 @@ dependencies = [
[[package]]
name = "num-traits"
version = "0.2.15"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
dependencies = [
"autocfg",
]
@ -2504,14 +2504,14 @@ dependencies = [
[[package]]
name = "regex"
version = "1.9.5"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47"
checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87"
dependencies = [
"aho-corasick 1.0.1",
"memchr",
"regex-automata 0.3.8",
"regex-syntax 0.7.5",
"regex-automata 0.4.1",
"regex-syntax 0.8.1",
]
[[package]]
@ -2525,13 +2525,13 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.3.8"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795"
checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b"
dependencies = [
"aho-corasick 1.0.1",
"memchr",
"regex-syntax 0.7.5",
"regex-syntax 0.8.1",
]
[[package]]
@ -2546,6 +2546,12 @@ version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]]
name = "regex-syntax"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33"
[[package]]
name = "reqwest"
version = "0.11.18"
@ -3107,6 +3113,7 @@ dependencies = [
"nvml-wrapper",
"opentelemetry",
"opentelemetry-otlp",
"regex",
"rust-embed 8.0.0",
"serde",
"serde_json",
@ -3119,6 +3126,7 @@ dependencies = [
"tabby-inference",
"tabby-scheduler",
"tantivy",
"textdistance",
"tokio",
"tower",
"tower-http 0.4.0",
@ -3383,6 +3391,12 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "textdistance"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d321c8576c2b47e43953e9cce236550d4cd6af0a6ce518fe084340082ca6037b"
[[package]]
name = "thiserror"
version = "1.0.45"

View File

@ -36,3 +36,4 @@ derive_builder = "0.12.0"
tokenizers = "0.13.4-rc3"
futures = "0.3.28"
async-stream = "0.3.5"
regex = "1.10.0"

View File

@ -56,10 +56,16 @@ pub enum Event<'a> {
completion_id: &'a str,
language: &'a str,
prompt: &'a str,
segments: &'a Segments<'a>,
choices: Vec<Choice<'a>>,
user: Option<&'a str>,
},
}
#[derive(Serialize)]
pub struct Segments<'a> {
pub prefix: &'a str,
pub suffix: Option<&'a str>,
}
#[derive(Serialize)]
struct Log<'a> {

View File

@ -11,5 +11,5 @@ async-trait = { workspace = true }
dashmap = "5.5.3"
derive_builder = "0.12.0"
futures = { workspace = true }
regex = "1.9.5"
regex.workspace = true
tokenizers.workspace = true

View File

@ -40,6 +40,8 @@ futures = { workspace = true }
async-stream = { workspace = true }
axum-streams = { version = "0.9.1", features = ["json"] }
minijinja = { version = "1.0.8", features = ["loader"] }
textdistance = "1.0.2"
regex.workspace = true
[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }

View File

@ -23,6 +23,7 @@ use super::search::IndexServer;
}
}))]
pub struct CompletionRequest {
#[deprecated]
#[schema(example = "def fib(n):")]
prompt: Option<String>,
@ -37,6 +38,14 @@ pub struct CompletionRequest {
// A unique identifier representing your end-user, which can help Tabby to monitor & generating
// reports.
user: Option<String>,
debug: Option<DebugRequest>,
}
#[derive(Serialize, ToSchema, Deserialize, Clone, Debug)]
pub struct DebugRequest {
// When true, returns debug_data in completion response.
enabled: bool,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@ -55,9 +64,31 @@ pub struct Choice {
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Snippet {
filepath: String,
body: String,
score: f32,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"id": "string",
"choices": [ { "index": 0, "text": "string" } ]
}))]
pub struct CompletionResponse {
id: String,
choices: Vec<Choice>,
#[serde(skip_serializing_if = "Option::is_none")]
debug_data: Option<DebugData>,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugData {
#[serde(skip_serializing_if = "Vec::is_empty")]
snippets: Vec<Snippet>,
prompt: String,
}
#[utoipa::path(
@ -97,7 +128,10 @@ pub async fn completions(
};
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let prompt = state.prompt_builder.build(&language, segments);
let snippets = state.prompt_builder.collect(&language, &segments);
let prompt = state
.prompt_builder
.build(&language, segments.clone(), &snippets);
debug!("PROMPT: {}", prompt);
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.generate(&prompt, options).await;
@ -106,6 +140,10 @@ pub async fn completions(
completion_id: &completion_id,
language: &language,
prompt: &prompt,
segments: &tabby_common::events::Segments {
prefix: &segments.prefix,
suffix: segments.suffix.as_deref(),
},
choices: vec![events::Choice {
index: 0,
text: &text,
@ -114,9 +152,16 @@ pub async fn completions(
}
.log();
let debug_data = DebugData { snippets, prompt };
Ok(Json(CompletionResponse {
id: completion_id,
choices: vec![Choice { index: 0, text }],
debug_data: if request.debug.is_some_and(|x| x.enabled) {
Some(debug_data)
} else {
None
},
}))
}

View File

@ -1,14 +1,17 @@
use std::sync::Arc;
use lazy_static::lazy_static;
use regex::Regex;
use strfmt::strfmt;
use textdistance::Algorithm;
use tracing::warn;
use super::Segments;
use super::{Segments, Snippet};
use crate::serve::{completions::languages::get_language, search::IndexServer};
static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
static SNIPPET_SCORE_THRESHOLD: f32 = 5.0;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder {
prompt_template: Option<String>,
@ -23,47 +26,36 @@ impl PromptBuilder {
}
}
fn build_prompt(&self, prefix: String, suffix: String) -> String {
if let Some(prompt_template) = &self.prompt_template {
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
} else {
prefix
}
fn build_prompt(&self, prefix: String, suffix: Option<String>) -> String {
let Some(suffix) = suffix else {
return prefix;
};
let Some(prompt_template) = &self.prompt_template else {
return prefix;
};
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
}
pub fn build(&self, language: &str, segments: Segments) -> String {
let segments = self.rewrite(language, segments);
self.build_prompt(segments.prefix, get_default_suffix(segments.suffix))
}
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
pub fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
if let Some(index_server) = &self.index_server {
rewrite_with_index(index_server, language, segments)
collect_snippets(index_server, language, &segments.prefix)
} else {
segments
vec![]
}
}
}
fn get_default_suffix(suffix: Option<String>) -> String {
if suffix.is_none() {
return "\n".to_owned();
}
let suffix = suffix.unwrap();
if suffix.is_empty() {
"\n".to_owned()
} else {
suffix
pub fn build(&self, language: &str, segments: Segments, snippets: &[Snippet]) -> String {
let segments = rewrite_with_snippets(language, segments, snippets);
self.build_prompt(
segments.prefix,
segments.suffix.filter(|x| !x.trim_end().is_empty()),
)
}
}
fn rewrite_with_index(
index_server: &Arc<IndexServer>,
language: &str,
segments: Segments,
) -> Segments {
let snippets = collect_snippets(index_server, language, &segments.prefix);
fn rewrite_with_snippets(language: &str, segments: Segments, snippets: &[Snippet]) -> Segments {
if snippets.is_empty() {
segments
} else {
@ -72,35 +64,23 @@ fn rewrite_with_index(
}
}
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
if snippets.is_empty() {
return prefix.to_owned();
}
let comment_char = get_language(language).line_comment;
let mut lines: Vec<String> = vec![
format!(
"Below are some relevant {} snippets found in the repository:",
language
),
"".to_owned(),
];
let mut lines: Vec<String> = vec![];
let mut count_characters = 0;
for (i, snippet) in snippets.iter().enumerate() {
if count_characters + snippet.len() > MAX_SNIPPET_CHARS_IN_PROMPT {
break;
}
lines.push(format!("== Snippet {} ==", i + 1));
for line in snippet.lines() {
lines.push(format!("Path: {}", snippet.filepath));
for line in snippet.body.lines() {
lines.push(line.to_owned());
}
if i < snippets.len() - 1 {
lines.push("".to_owned());
}
count_characters += snippet.len();
}
let commented_lines: Vec<String> = lines
@ -117,9 +97,10 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
format!("{}\n{}", comments, prefix)
}
fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<String> {
fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<Snippet> {
let mut ret = Vec::new();
let sanitized_text = sanitize_text(text);
let tokens = tokenize_text(text);
let sanitized_text = tokens.join(" ");
if sanitized_text.is_empty() {
return ret;
}
@ -134,35 +115,49 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
}
};
let mut count_characters = 0;
for hit in serp.hits {
if hit.score < SNIPPET_SCORE_THRESHOLD {
let body = hit.doc.body;
let body_tokens = tokenize_text(&body);
if count_characters + body.len() > MAX_SNIPPET_CHARS_IN_PROMPT {
break;
}
let body = hit.doc.body;
let similarity = if body_tokens.len() > tokens.len() {
0.0
} else {
let distance = textdistance::LCSSeq::default()
.for_iter(tokens.iter(), body_tokens.iter())
.val() as f32;
distance / body_tokens.len() as f32
};
if text.contains(&body) {
// Exclude snippets already in the context window.
if similarity > MAX_SIMILARITY_THRESHOLD {
// Exclude snippets presents in context window.
continue;
}
ret.push(body.to_owned());
count_characters += body.len();
ret.push(Snippet {
filepath: hit.doc.filepath,
body,
score: hit.score,
});
}
ret
}
fn sanitize_text(text: &str) -> String {
// Only keep [a-zA-Z0-9-_]
let x = text.replace(
|c: char| !c.is_ascii_digit() && !c.is_alphabetic() && c != '_' && c != '-',
" ",
);
let tokens: Vec<&str> = x
.split(' ')
.filter(|x| *x != "AND" && *x != "NOT" && *x != "OR" && x.len() > 5)
.collect();
tokens.join(" ")
lazy_static! {
static ref TOKENIZER: Regex = Regex::new(r"[^\w]").unwrap();
}
fn tokenize_text(text: &str) -> Vec<&str> {
TOKENIZER
.split(text)
.filter(|s| *s != "AND" && *s != "OR" && *s != "NOT")
.collect()
}
#[cfg(test)]
@ -187,6 +182,7 @@ mod tests {
// Rewrite disabled, so the language doesn't matter.
let language = "python";
let snippets = &vec![];
// Test w/ prefix, w/ suffix.
{
@ -195,7 +191,7 @@ mod tests {
suffix: Some("this is some suffix".into()),
};
assert_eq!(
pb.build(language, segments),
pb.build(language, segments, snippets),
"<PRE> this is some prefix <SUF>this is some suffix <MID>"
);
}
@ -207,8 +203,8 @@ mod tests {
suffix: None,
};
assert_eq!(
pb.build(language, segments),
"<PRE> this is some prefix <SUF>\n <MID>"
pb.build(language, segments, snippets),
"this is some prefix"
);
}
@ -219,8 +215,8 @@ mod tests {
suffix: Some("".into()),
};
assert_eq!(
pb.build(language, segments),
"<PRE> this is some prefix <SUF>\n <MID>"
pb.build(language, segments, snippets),
"this is some prefix"
);
}
@ -231,7 +227,7 @@ mod tests {
suffix: Some("this is some suffix".into()),
};
assert_eq!(
pb.build(language, segments),
pb.build(language, segments, snippets),
"<PRE> <SUF>this is some suffix <MID>"
);
}
@ -242,7 +238,7 @@ mod tests {
prefix: "".into(),
suffix: None,
};
assert_eq!(pb.build(language, segments), "<PRE> <SUF>\n <MID>");
assert_eq!(pb.build(language, segments, snippets), "");
}
// Test w/ emtpy prefix, w/ empty suffix.
@ -251,7 +247,7 @@ mod tests {
prefix: "".into(),
suffix: Some("".into()),
};
assert_eq!(pb.build(language, segments), "<PRE> <SUF>\n <MID>");
assert_eq!(pb.build(language, segments, snippets), "");
}
}
@ -261,6 +257,7 @@ mod tests {
// Rewrite disabled, so the language doesn't matter.
let language = "python";
let snippets = &vec![];
// Test w/ prefix, w/ suffix.
{
@ -268,7 +265,10 @@ mod tests {
prefix: "this is some prefix".into(),
suffix: Some("this is some suffix".into()),
};
assert_eq!(pb.build(language, segments), "this is some prefix");
assert_eq!(
pb.build(language, segments, snippets),
"this is some prefix"
);
}
// Test w/ prefix, w/o suffix.
@ -277,7 +277,10 @@ mod tests {
prefix: "this is some prefix".into(),
suffix: None,
};
assert_eq!(pb.build(language, segments), "this is some prefix");
assert_eq!(
pb.build(language, segments, snippets),
"this is some prefix"
);
}
// Test w/ prefix, w/ empty suffix.
@ -286,7 +289,10 @@ mod tests {
prefix: "this is some prefix".into(),
suffix: Some("".into()),
};
assert_eq!(pb.build(language, segments), "this is some prefix");
assert_eq!(
pb.build(language, segments, snippets),
"this is some prefix"
);
}
// Test w/ empty prefix, w/ suffix.
@ -295,7 +301,7 @@ mod tests {
prefix: "".into(),
suffix: Some("this is some suffix".into()),
};
assert_eq!(pb.build(language, segments), "");
assert_eq!(pb.build(language, segments, snippets), "");
}
// Test w/ empty prefix, w/o suffix.
@ -304,7 +310,7 @@ mod tests {
prefix: "".into(),
suffix: None,
};
assert_eq!(pb.build(language, segments), "");
assert_eq!(pb.build(language, segments, snippets), "");
}
// Test w/ empty prefix, w/ empty suffix.
@ -313,18 +319,28 @@ mod tests {
prefix: "".into(),
suffix: Some("".into()),
};
assert_eq!(pb.build(language, segments), "");
assert_eq!(pb.build(language, segments, snippets), "");
}
}
#[test]
fn test_build_prefix_readable() {
let snippets = vec![
"res_1 = invoke_function_1(n)".to_string(),
"res_2 = invoke_function_2(n)".to_string(),
"res_3 = invoke_function_3(n)".to_string(),
"res_4 = invoke_function_4(n)".to_string(),
"res_5 = invoke_function_5(n)".to_string(),
Snippet {
filepath: "a1.py".to_owned(),
body: "res_1 = invoke_function_1(n)".to_owned(),
score: 1.0,
},
Snippet {
filepath: "a2.py".to_owned(),
body: "res_2 = invoke_function_2(n)".to_owned(),
score: 1.0,
},
Snippet {
filepath: "a3.py".to_owned(),
body: "res_3 = invoke_function_3(n)".to_owned(),
score: 1.0,
},
];
let prefix = "\
@ -334,53 +350,22 @@ Use some invoke_function to do some job.
def this_is_prefix():\n";
let expected_built_prefix = "\
# Below are some relevant python snippets found in the repository:
#
# == Snippet 1 ==
# Path: a1.py
# res_1 = invoke_function_1(n)
#
# == Snippet 2 ==
# Path: a2.py
# res_2 = invoke_function_2(n)
#
# == Snippet 3 ==
# Path: a3.py
# res_3 = invoke_function_3(n)
#
# == Snippet 4 ==
# res_4 = invoke_function_4(n)
#
# == Snippet 5 ==
# res_5 = invoke_function_5(n)
'''
Use some invoke_function to do some job.
'''
def this_is_prefix():\n";
assert_eq!(
build_prefix("python", prefix, snippets),
build_prefix("python", prefix, &snippets),
expected_built_prefix
);
}
#[test]
fn test_build_prefix_count_chars() {
let snippets_expected = 4;
let snippet_payload = "a".repeat(MAX_SNIPPET_CHARS_IN_PROMPT / snippets_expected);
let mut snippets = vec![];
for _ in 0..snippets_expected + 1 {
snippets.push(snippet_payload.clone());
}
let prefix = "def this_is_prefix():\n";
let generated_prompt = build_prefix("python", prefix, snippets);
for i in 0..snippets_expected + 1 {
let st = format!("# == Snippet {} ==", i + 1);
if i < snippets_expected {
assert!(generated_prompt.contains(&st));
} else {
assert!(!generated_prompt.contains(&st));
}
}
}
}

View File

@ -57,6 +57,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
completions::CompletionResponse,
completions::Segments,
completions::Choice,
completions::DebugRequest,
completions::DebugData,
completions::Snippet,
chat::ChatCompletionRequest,
chat::Message,
chat::ChatCompletionChunk,

View File

@ -8,13 +8,21 @@ st.set_page_config(layout="wide")
language = st.text_input("Language", "rust")
query = st.text_area("Query", "get")
tokens = re.findall(r"\w+", query)
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
query = st.text_area("Query", "to_owned")
if query:
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
hits = r.json()["hits"]
for x in hits:
st.write(x)
r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug=dict(enabled=True)))
json = r.json()
debug = json["debug_data"]
snippets = debug.get("snippets", [])
st.write("Prompt")
st.code(debug["prompt"])
st.write("Completion")
st.code(json["choices"][0]["text"])
for x in snippets:
st.write(f"**{x['filepath']}**: {x['score']}")
st.write(f"Length: {len(x['body'])}")
st.code(x['body'])