feat: support prompt rewriting (#295)

* refactor: extract PromptBuilder

* feat: load tantivy index in prompt builder

* integrate with searcher

* add enable_prompt_rewrite to control rewrite behavior

* nit docs

* limit 1 snippet per identifier

* extract magic numbers
sweep/improve-logging-information
Meng Zhang 2023-07-13 17:05:41 +08:00 committed by GitHub
parent 207559b0a2
commit 4388fd0050
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 242 additions and 28 deletions

2
Cargo.lock generated
View File

@ -2776,6 +2776,7 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
name = "tabby"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-tracing-opentelemetry",
"clap",
@ -2794,6 +2795,7 @@ dependencies = [
"tabby-common",
"tabby-download",
"tabby-scheduler",
"tantivy",
"tokio",
"tower",
"tower-http 0.4.0",

View File

@ -24,3 +24,4 @@ tracing = "0.1"
tracing-subscriber = "0.3"
anyhow = "1.0.71"
serde-jsonlines = "0.4.0"
tantivy = "0.19.2"

View File

@ -8,9 +8,15 @@ use serde::Deserialize;
use crate::path::{config_file, repositories_dir};
#[derive(Deserialize)]
#[derive(Deserialize, Default)]
pub struct Config {
pub repositories: Vec<Repository>,
pub experimental: Experimental,
}
#[derive(Deserialize, Default)]
pub struct Experimental {
pub enable_prompt_rewrite: bool,
}
impl Config {

View File

@ -10,7 +10,7 @@ anyhow = { workspace = true }
filenamify = "0.1.0"
job_scheduler = "1.2.1"
tabby-common = { path = "../tabby-common" }
tantivy = "0.19.2"
tantivy = { workspace = true }
tracing = { workspace = true }
tree-sitter-javascript = "0.20.0"
tree-sitter-tags = "0.20.2"

View File

@ -59,7 +59,7 @@ pub async fn scheduler(now: bool) -> Result<()> {
#[cfg(test)]
mod tests {
use tabby_common::{
config::{Config, Repository},
config::{Config, Experimental, Repository},
path::set_tabby_root,
};
use temp_testdir::*;
@ -76,6 +76,7 @@ mod tests {
repositories: vec![Repository {
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
}],
experimental: Experimental::default(),
};
repository::sync_repositories(&config).unwrap();

View File

@ -30,6 +30,8 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
axum-tracing-opentelemetry = "0.10.0"
tracing-opentelemetry = "0.18.0"
tantivy = { workspace = true }
anyhow = { workspace = true }
[dependencies.uuid]

View File

@ -8,6 +8,7 @@ use opentelemetry::{
KeyValue,
};
use opentelemetry_otlp::WithExportConfig;
use tabby_common::config::Config;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
#[derive(Parser)]
@ -47,8 +48,10 @@ async fn main() {
let cli = Cli::parse();
init_logging(cli.otlp_endpoint);
let config = Config::load().unwrap_or(Config::default());
match &cli.command {
Commands::Serve(args) => serve::main(args).await,
Commands::Serve(args) => serve::main(&config, args).await,
Commands::Download(args) => download::main(args).await,
#[cfg(feature = "scheduler")]
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)

View File

@ -1,3 +1,6 @@
mod languages;
mod prompt;
use std::{path::Path, sync::Arc};
use axum::{extract::State, Json};
@ -6,16 +9,13 @@ use ctranslate2_bindings::{
};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use strfmt::{strfmt, strfmt_builder};
use tabby_common::{events, path::ModelDir};
use tracing::instrument;
use tabby_common::{config::Config, events, path::ModelDir};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use self::languages::get_stop_words;
use crate::fatal;
mod languages;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"language": "python",
@ -86,23 +86,20 @@ pub async fn completion(
.build()
.unwrap();
let prompt = if let Some(Segments { prefix, suffix }) = request.segments {
if let (Some(prompt_template), Some(suffix)) = (&state.prompt_template, suffix) {
if !suffix.is_empty() {
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
} else {
prefix
}
} else {
// If there's no prompt template, just use prefix.
prefix
}
let segments = if let Some(segments) = request.segments {
segments
} else if let Some(prompt) = request.prompt {
prompt
Segments {
prefix: prompt,
suffix: None,
}
} else {
return Err(StatusCode::BAD_REQUEST);
};
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let prompt = state.prompt_builder.build(&language, segments);
debug!("PROMPT: {}", prompt);
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.inference(&prompt, options).await;
@ -126,11 +123,11 @@ pub async fn completion(
pub struct CompletionState {
engine: TextInferenceEngine,
prompt_template: Option<String>,
prompt_builder: prompt::PromptBuilder,
}
impl CompletionState {
pub fn new(args: &crate::serve::ServeArgs) -> Self {
pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self {
let model_dir = get_model_dir(&args.model);
let metadata = read_metadata(&model_dir);
@ -149,7 +146,10 @@ impl CompletionState {
let engine = TextInferenceEngine::create(options);
Self {
engine,
prompt_template: metadata.prompt_template,
prompt_builder: prompt::PromptBuilder::new(
metadata.prompt_template,
config.experimental.enable_prompt_rewrite,
),
}
}
}

View File

@ -0,0 +1,198 @@
use std::collections::HashMap;
use anyhow::{anyhow, Result};
use lazy_static::lazy_static;
use strfmt::strfmt;
use tabby_common::path::index_dir;
use tantivy::{
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
};
use tracing::{info, warn};
use super::Segments;
/// Maximum number of candidate documents requested from the index per query.
static MAX_SNIPPETS_TO_FETCH: usize = 20;
/// Maximum number of snippets kept per distinct identifier name.
static MAX_SNIPPET_PER_NAME: u32 = 1;
/// Character budget for snippet text inserted into the prompt.
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024;
/// Builds the final completion prompt from request segments, optionally
/// enriching it with snippets retrieved from the local tantivy index.
pub struct PromptBuilder {
    /// Optional FIM template with `{prefix}` / `{suffix}` placeholders.
    prompt_template: Option<String>,
    /// `Some` only when the experimental rewrite feature is enabled and
    /// the on-disk index opened successfully.
    index: Option<IndexState>,
}
impl PromptBuilder {
    /// Creates a builder.
    ///
    /// When `enable_prompt_rewrite` is set, attempts to open the tantivy
    /// index on disk; an open failure is logged and rewriting is silently
    /// disabled, so the builder still works without an index.
    pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self {
        let index = if enable_prompt_rewrite {
            info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ...");
            let index = IndexState::new();
            if let Err(err) = &index {
                warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
            }
            index.ok()
        } else {
            None
        };

        PromptBuilder {
            prompt_template,
            index,
        }
    }

    /// Renders `prefix`/`suffix` through the FIM template, or returns the
    /// bare prefix when no template is configured.
    fn build_prompt(&self, prefix: String, suffix: String) -> String {
        if let Some(prompt_template) = &self.prompt_template {
            // The template ships with the model metadata; a malformed
            // template is a packaging bug, hence unwrap.
            strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
        } else {
            prefix
        }
    }

    /// Builds the final prompt: first rewrites the segments against the
    /// index (when available), then applies the FIM template. A missing
    /// suffix is treated as an empty string.
    pub fn build(&self, language: &str, segments: Segments) -> String {
        let segments = self.rewrite(language, segments);
        // `unwrap_or_default` replaces the previous duplicated
        // `build_prompt` call for the `None` suffix case.
        self.build_prompt(segments.prefix, segments.suffix.unwrap_or_default())
    }

    /// Augments `segments` with indexed snippets when an index is loaded;
    /// otherwise returns them unchanged.
    fn rewrite(&self, language: &str, segments: Segments) -> Segments {
        if let Some(index) = &self.index {
            rewrite_with_index(index, language, segments)
        } else {
            segments
        }
    }
}
/// Prepends relevant indexed snippets (rendered as comments) to the prefix.
/// Leaves the segments untouched when no snippet matches.
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments {
    let snippets = collect_snippets(index, language, &segments.prefix);
    if snippets.is_empty() {
        return segments;
    }
    let prefix = build_prefix(language, &segments.prefix, snippets);
    Segments { prefix, ..segments }
}
/// Formats `snippets` as a line-commented header block and prepends it to
/// `prefix`.
///
/// Snippets are included in order until adding the next one would exceed
/// the `MAX_SNIPPET_CHARS_IN_PROMPT` character budget. Languages without a
/// known line-comment leader return the prefix unchanged instead of
/// panicking.
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
    // Without a comment leader we cannot safely inject snippet text into
    // source code, so skip rewriting entirely (the old `unwrap` panicked
    // here for any language other than python).
    let comment_char = match LANGUAGE_LINE_COMMENT_CHAR.get(language) {
        Some(comment_char) => comment_char,
        None => return prefix.to_owned(),
    };

    let mut lines: Vec<String> = vec![format!(
        "Below are some relevant {} snippets found in the repository:",
        language
    )];

    let mut count_characters = 0;
    for (i, snippet) in snippets.iter().enumerate() {
        if count_characters + snippet.len() > MAX_SNIPPET_CHARS_IN_PROMPT {
            break;
        }
        // Charge each snippet against the budget exactly once. The previous
        // version also accumulated every line's length inside the loop,
        // double-counting and roughly halving the effective budget.
        count_characters += snippet.len();

        lines.push(format!("== Snippet {} ==", i + 1));
        lines.extend(snippet.lines().map(|line| line.to_owned()));
    }

    let comments = lines
        .iter()
        .map(|x| format!("{} {}", comment_char, x))
        .collect::<Vec<_>>()
        .join("\n");
    format!("{}\n{}", comments, prefix)
}
/// Queries the index for `call`-kind documents in `language` matching the
/// identifiers found in `text`, returning their bodies deduplicated to at
/// most `MAX_SNIPPET_PER_NAME` snippets per identifier name.
///
/// All index failures are logged and degrade to returning what was
/// collected so far — prompt rewriting must never take down a completion
/// request.
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> {
    let mut ret = Vec::new();

    let sanitized_text = sanitize_text(text);
    if sanitized_text.is_empty() {
        return ret;
    }

    let query_text = format!(
        "language:{} AND kind:call AND ({})",
        language, sanitized_text
    );

    let query = match index.query_parser.parse_query(&query_text) {
        Ok(query) => query,
        Err(err) => {
            warn!("Failed to parse query: {}", err);
            return ret;
        }
    };

    // Treat a failed search like a failed parse: log and fall back to the
    // unmodified prompt rather than panicking the request handler (was
    // `unwrap` before).
    let top_docs = match index
        .searcher
        .search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
    {
        Ok(top_docs) => top_docs,
        Err(err) => {
            warn!("Failed to search query: {}", err);
            return ret;
        }
    };

    let mut names: HashMap<String, u32> = HashMap::new();
    for (_score, doc_address) in top_docs {
        let doc = match index.searcher.doc(doc_address) {
            Ok(doc) => doc,
            Err(err) => {
                warn!("Failed to fetch document: {}", err);
                continue;
            }
        };

        // Skip malformed documents (missing or non-text fields) instead of
        // panicking.
        let name = match doc.get_first(index.field_name).and_then(|x| x.as_text()) {
            Some(name) => name,
            None => continue,
        };

        // Entry API: one hash lookup instead of a `get` + `insert` pair.
        // Max `MAX_SNIPPET_PER_NAME` snippet(s) per identifier.
        let count = names.entry(name.to_owned()).or_insert(0);
        if *count >= MAX_SNIPPET_PER_NAME {
            continue;
        }

        let body = match doc.get_first(index.field_body).and_then(|x| x.as_text()) {
            Some(body) => body,
            None => continue,
        };

        *count += 1;
        ret.push(body.to_owned());
    }

    ret
}
/// Replaces every character that is neither an ASCII digit nor a letter
/// with a space, yielding a whitespace-separated identifier list safe to
/// splice into a tantivy query string.
///
/// The previous `split(' ')` + `join(" ")` round-trip was an identity
/// transform (splitting on a single space preserves empty tokens, and the
/// join restores them), so it has been dropped.
fn sanitize_text(text: &str) -> String {
    text.replace(|c: char| !c.is_ascii_digit() && !c.is_alphabetic(), " ")
}
/// An opened tantivy index plus the cached schema fields needed for
/// snippet lookup.
struct IndexState {
    searcher: Searcher,
    query_parser: QueryParser,
    /// Schema field `"name"` — the identifier a snippet is indexed under.
    field_name: Field,
    /// Schema field `"body"` — the snippet text inserted into the prompt.
    field_body: Field,
}
impl IndexState {
    /// Opens the on-disk tantivy index and resolves the schema fields
    /// required for snippet retrieval.
    ///
    /// # Errors
    /// Fails when the index directory cannot be opened, the reader cannot
    /// be built, or a required field is missing from the schema.
    fn new() -> Result<IndexState> {
        let index = Index::open_in_dir(index_dir())?;
        let reader = index
            .reader_builder()
            .reload_policy(ReloadPolicy::OnCommit)
            .try_into()?;
        // `ok_or_else` avoids constructing the error eagerly, and naming
        // the missing field makes a schema mismatch diagnosable (the two
        // messages were previously identical).
        let field_name = index
            .schema()
            .get_field("name")
            .ok_or_else(|| anyhow!("Index doesn't have required field `name`"))?;
        let field_body = index
            .schema()
            .get_field("body")
            .ok_or_else(|| anyhow!("Index doesn't have required field `body`"))?;
        let query_parser = QueryParser::for_index(&index, vec![field_name]);
        Ok(Self {
            searcher: reader.searcher(),
            query_parser,
            field_name,
            field_body,
        })
    }
}
lazy_static! {
    /// Line-comment leader per language, used when rendering snippets into
    /// the prompt. Only Python is covered today, so prompt rewriting
    /// effectively supports Python sources only.
    static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
        HashMap::from([("python", "#")]);
}

View File

@ -10,6 +10,7 @@ use std::{
use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args;
use tabby_common::config::Config;
use tower_http::cors::CorsLayer;
use tracing::info;
use utoipa::OpenApi;
@ -107,7 +108,7 @@ pub struct ServeArgs {
compute_type: ComputeType,
}
pub async fn main(args: &ServeArgs) {
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
// Ensure model exists.
@ -123,7 +124,7 @@ pub async fn main(args: &ServeArgs) {
let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.nest("/v1", api_router(args))
.nest("/v1", api_router(args, config))
.fallback(fallback());
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
@ -134,7 +135,7 @@ pub async fn main(args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}
fn api_router(args: &ServeArgs) -> Router {
fn api_router(args: &ServeArgs, config: &Config) -> Router {
Router::new()
.route("/events", routing::post(events::log_event))
.route(
@ -144,7 +145,7 @@ fn api_router(args: &ServeArgs) -> Router {
.route(
"/completions",
routing::post(completions::completion)
.with_state(Arc::new(completions::CompletionState::new(args))),
.with_state(Arc::new(completions::CompletionState::new(args, config))),
)
.layer(CorsLayer::permissive())
.layer(opentelemetry_tracing_layer())