From 4388fd00500979ce88a7900eb05ca3d813b959c2 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 13 Jul 2023 17:05:41 +0800 Subject: [PATCH] feat: support prompt rewriting (#295) * refactor: extract PromptBuilder * feat: load tantivy index in prompt builder * integrate with searcher * add enable_prompt_rewrite to control rewrite behavior * nit docs * limit 1 snippet per identifier * extract magic numbers --- Cargo.lock | 2 + Cargo.toml | 1 + crates/tabby-common/src/config.rs | 8 +- crates/tabby-scheduler/Cargo.toml | 2 +- crates/tabby-scheduler/src/lib.rs | 3 +- crates/tabby/Cargo.toml | 2 + crates/tabby/src/main.rs | 5 +- crates/tabby/src/serve/completions.rs | 40 ++-- crates/tabby/src/serve/completions/prompt.rs | 198 +++++++++++++++++++ crates/tabby/src/serve/mod.rs | 9 +- 10 files changed, 242 insertions(+), 28 deletions(-) create mode 100644 crates/tabby/src/serve/completions/prompt.rs diff --git a/Cargo.lock b/Cargo.lock index 996076e..2db1fcd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2776,6 +2776,7 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" name = "tabby" version = "0.1.0" dependencies = [ + "anyhow", "axum", "axum-tracing-opentelemetry", "clap", @@ -2794,6 +2795,7 @@ dependencies = [ "tabby-common", "tabby-download", "tabby-scheduler", + "tantivy", "tokio", "tower", "tower-http 0.4.0", diff --git a/Cargo.toml b/Cargo.toml index 28fa131..3a221b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ tracing = "0.1" tracing-subscriber = "0.3" anyhow = "1.0.71" serde-jsonlines = "0.4.0" +tantivy = "0.19.2" diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index f13b932..8b9047f 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -8,9 +8,15 @@ use serde::Deserialize; use crate::path::{config_file, repositories_dir}; -#[derive(Deserialize)] +#[derive(Deserialize, Default)] pub struct Config { pub repositories: Vec, + pub experimental: Experimental, +} + +#[derive(Deserialize, Default)] +pub struct Experimental { + pub enable_prompt_rewrite: bool, } impl Config { diff --git a/crates/tabby-scheduler/Cargo.toml b/crates/tabby-scheduler/Cargo.toml index 96060ff..a21d8e2 100644 --- a/crates/tabby-scheduler/Cargo.toml +++ b/crates/tabby-scheduler/Cargo.toml @@ -10,7 +10,7 @@ anyhow = { workspace = true } filenamify = "0.1.0" job_scheduler = "1.2.1" tabby-common = { path = "../tabby-common" } -tantivy = "0.19.2" +tantivy = { workspace = true } tracing = { workspace = true } tree-sitter-javascript = "0.20.0" tree-sitter-tags = "0.20.2" diff --git a/crates/tabby-scheduler/src/lib.rs b/crates/tabby-scheduler/src/lib.rs index 423392f..5e1aa55 100644 --- a/crates/tabby-scheduler/src/lib.rs +++ b/crates/tabby-scheduler/src/lib.rs @@ -59,7 +59,7 @@ pub async fn scheduler(now: bool) -> Result<()> { #[cfg(test)] mod tests { use tabby_common::{ - config::{Config, Repository}, + config::{Config, Experimental, Repository}, path::set_tabby_root, }; use temp_testdir::*; @@ -76,6 +76,7 @@ mod tests { repositories: vec![Repository { git_url: "https://github.com/TabbyML/interview-questions".to_owned(), }], + experimental: Experimental::default(), }; repository::sync_repositories(&config).unwrap(); diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index be6ee1a..be518e8 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -30,6 +30,8 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry-otlp = "0.11.0" axum-tracing-opentelemetry = "0.10.0" tracing-opentelemetry = "0.18.0" +tantivy = { workspace = true } +anyhow = { workspace = true } [dependencies.uuid] diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 1c2f9c8..6fd37a2 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -8,6 +8,7 @@ use opentelemetry::{ KeyValue, }; use opentelemetry_otlp::WithExportConfig; +use tabby_common::config::Config; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer}; #[derive(Parser)] @@ -47,8 +48,10 @@ async fn main() { let cli = Cli::parse(); init_logging(cli.otlp_endpoint); + let config = Config::load().unwrap_or(Config::default()); + match &cli.command { - Commands::Serve(args) => serve::main(args).await, + Commands::Serve(args) => serve::main(&config, args).await, Commands::Download(args) => download::main(args).await, #[cfg(feature = "scheduler")] Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now) diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index dac13b6..ed6d22e 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -1,3 +1,6 @@ +mod languages; +mod prompt; + use std::{path::Path, sync::Arc}; use axum::{extract::State, Json}; @@ -6,16 +9,13 @@ use ctranslate2_bindings::{ }; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use strfmt::{strfmt, strfmt_builder}; -use tabby_common::{events, path::ModelDir}; -use tracing::instrument; +use tabby_common::{config::Config, events, path::ModelDir}; +use tracing::{debug, instrument}; use utoipa::ToSchema; use self::languages::get_stop_words; use crate::fatal; -mod languages; - #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ "language": "python", @@ -86,23 +86,20 @@ pub async fn completion( .build() .unwrap(); - let prompt = if let Some(Segments { prefix, suffix }) = request.segments { - if let (Some(prompt_template), Some(suffix)) = (&state.prompt_template, suffix) { - if !suffix.is_empty() { - strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap() - } else { - prefix - } - } else { - // If there's no prompt template, just use prefix. - prefix - } + let segments = if let Some(segments) = request.segments { + segments } else if let Some(prompt) = request.prompt { - prompt + Segments { + prefix: prompt, + suffix: None, + } } else { return Err(StatusCode::BAD_REQUEST); }; + debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); + let prompt = state.prompt_builder.build(&language, segments); + debug!("PROMPT: {}", prompt); let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let text = state.engine.inference(&prompt, options).await; @@ -126,11 +123,11 @@ pub async fn completion( pub struct CompletionState { engine: TextInferenceEngine, - prompt_template: Option, + prompt_builder: prompt::PromptBuilder, } impl CompletionState { - pub fn new(args: &crate::serve::ServeArgs) -> Self { + pub fn new(args: &crate::serve::ServeArgs, config: &Config) -> Self { let model_dir = get_model_dir(&args.model); let metadata = read_metadata(&model_dir); @@ -149,7 +146,10 @@ impl CompletionState { let engine = TextInferenceEngine::create(options); Self { engine, - prompt_template: metadata.prompt_template, + prompt_builder: prompt::PromptBuilder::new( + metadata.prompt_template, + config.experimental.enable_prompt_rewrite, + ), } } } diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs new file mode 100644 index 0000000..d29be56 --- /dev/null +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -0,0 +1,198 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, Result}; +use lazy_static::lazy_static; +use strfmt::strfmt; +use tabby_common::path::index_dir; +use tantivy::{ + collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher, +}; +use tracing::{info, warn}; + +use super::Segments; + +static MAX_SNIPPETS_TO_FETCH: usize = 20; +static MAX_SNIPPET_PER_NAME: u32 = 1; +static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 1024; + +pub struct PromptBuilder { + prompt_template: Option, + index: Option, +} + +impl PromptBuilder { + pub fn new(prompt_template: Option, enable_prompt_rewrite: bool) -> Self { + let index = if enable_prompt_rewrite { + info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ..."); + let index = IndexState::new(); + if let Err(err) = &index { + warn!("Failed to open index in {:?}: {:?}", index_dir(), err); + } + index.ok() + } else { + None + }; + + PromptBuilder { + prompt_template, + index, + } + } + + fn build_prompt(&self, prefix: String, suffix: String) -> String { + if let Some(prompt_template) = &self.prompt_template { + strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap() + } else { + prefix + } + } + + pub fn build(&self, language: &str, segments: Segments) -> String { + let segments = self.rewrite(language, segments); + if let Some(suffix) = segments.suffix { + self.build_prompt(segments.prefix, suffix) + } else { + self.build_prompt(segments.prefix, "".to_owned()) + } + } + + fn rewrite(&self, language: &str, segments: Segments) -> Segments { + if let Some(index) = &self.index { + rewrite_with_index(index, language, segments) + } else { + segments + } + } +} + +fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments { + let snippets = collect_snippets(index, language, &segments.prefix); + if snippets.is_empty() { + segments + } else { + let prefix = build_prefix(language, &segments.prefix, snippets); + Segments { prefix, ..segments } + } +} + +fn build_prefix(language: &str, prefix: &str, snippets: Vec) -> String { + let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap(); + let mut lines: Vec = vec![format!( + "Below are some relevant {} snippets found in the repository:", + language + )]; + + let mut count_characters = 0; + for (i, snippet) in snippets.iter().enumerate() { + if count_characters + snippet.len() > MAX_SNIPPET_CHARS_IN_PROMPT { + break; + } + + lines.push(format!("== Snippet {} ==", i + 1)); + for line in snippet.lines() { + lines.push(line.to_owned()); + count_characters += line.len(); + } + + count_characters += snippet.len(); + } + + let commented_lines: Vec = lines + .iter() + .map(|x| format!("{} {}", comment_char, x)) + .collect(); + let comments = commented_lines.join("\n"); + format!("{}\n{}", comments, prefix) +} + +fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec { + let mut ret = Vec::new(); + let sanitized_text = sanitize_text(text); + if sanitized_text.is_empty() { + return ret; + } + + let query_text = format!( + "language:{} AND kind:call AND ({})", + language, sanitized_text + ); + let query = match index.query_parser.parse_query(&query_text) { + Ok(query) => query, + Err(err) => { + warn!("Failed to parse query: {}", err); + return ret; + } + }; + + let top_docs = index + .searcher + .search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH)) + .unwrap(); + + let mut names: HashMap = HashMap::new(); + for (_score, doc_address) in top_docs { + let doc = index.searcher.doc(doc_address).unwrap(); + let name = doc + .get_first(index.field_name) + .and_then(|x| x.as_text()) + .unwrap(); + let count = *names.get(name).unwrap_or(&0); + + // Max 1 snippet per identifier. + if count >= MAX_SNIPPET_PER_NAME { + continue; + } + + let body = doc + .get_first(index.field_body) + .and_then(|x| x.as_text()) + .unwrap(); + names.insert(name.to_owned(), count + 1); + ret.push(body.to_owned()); + } + + ret +} + +fn sanitize_text(text: &str) -> String { + let x = text.replace(|c: char| !c.is_ascii_digit() && !c.is_alphabetic(), " "); + let tokens: Vec<&str> = x.split(' ').collect(); + tokens.join(" ") +} + +struct IndexState { + searcher: Searcher, + query_parser: QueryParser, + field_name: Field, + field_body: Field, +} + +impl IndexState { + fn new() -> Result { + let index = Index::open_in_dir(index_dir())?; + let reader = index + .reader_builder() + .reload_policy(ReloadPolicy::OnCommit) + .try_into()?; + let field_name = index + .schema() + .get_field("name") + .ok_or(anyhow!("Index doesn't have required field"))?; + let field_body = index + .schema() + .get_field("body") + .ok_or(anyhow!("Index doesn't have required field"))?; + let query_parser = QueryParser::for_index(&index, vec![field_name]); + Ok(Self { + searcher: reader.searcher(), + query_parser, + field_name, + field_body, + }) + } +} + +lazy_static! { + static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = + HashMap::from([("python", "#")]); +} diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 4b8d67f..03fff5c 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -10,6 +10,7 @@ use std::{ use axum::{routing, Router, Server}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use clap::Args; +use tabby_common::config::Config; use tower_http::cors::CorsLayer; use tracing::info; use utoipa::OpenApi; @@ -107,7 +108,7 @@ pub struct ServeArgs { compute_type: ComputeType, } -pub async fn main(args: &ServeArgs) { +pub async fn main(config: &Config, args: &ServeArgs) { valid_args(args); // Ensure model exists. @@ -123,7 +124,7 @@ pub async fn main(args: &ServeArgs) { let app = Router::new() .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) - .nest("/v1", api_router(args)) + .nest("/v1", api_router(args, config)) .fallback(fallback()); let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port)); @@ -134,7 +135,7 @@ pub async fn main(args: &ServeArgs) { .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } -fn api_router(args: &ServeArgs) -> Router { +fn api_router(args: &ServeArgs, config: &Config) -> Router { Router::new() .route("/events", routing::post(events::log_event)) .route( @@ -144,7 +145,7 @@ fn api_router(args: &ServeArgs) -> Router { .route( "/completions", routing::post(completions::completion) - .with_state(Arc::new(completions::CompletionState::new(args))), + .with_state(Arc::new(completions::CompletionState::new(args, config))), ) .layer(CorsLayer::permissive()) .layer(opentelemetry_tracing_layer())