From d85a7892d18011b28285309cdf12440283d45ccf Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 6 Oct 2023 17:29:24 -0700 Subject: [PATCH] feat: connect prompt rewriting part (#517) * feat: enable /v1beta/search if index is available * make prompt rewriting work * update * fix test * fix api doc --- crates/tabby-common/src/config.rs | 13 -- .../tabby-scheduler/tests/integration_test.rs | 3 +- crates/tabby/src/serve/completions.rs | 10 +- crates/tabby/src/serve/completions/prompt.rs | 144 +++++++----------- crates/tabby/src/serve/mod.rs | 40 +++-- crates/tabby/src/serve/search.rs | 98 +++++++----- 6 files changed, 145 insertions(+), 163 deletions(-) diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index 281ce22..c1cd255 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -13,19 +13,10 @@ pub struct Config { #[serde(default)] pub repositories: Vec, - #[serde(default)] - pub experimental: Experimental, - #[serde(default)] pub swagger: SwaggerConfig, } -#[derive(Serialize, Deserialize, Default)] -pub struct Experimental { - #[serde(default = "default_as_false")] - pub enable_prompt_rewrite: bool, -} - #[derive(Serialize, Deserialize, Default)] pub struct SwaggerConfig { pub server_url: Option, @@ -64,10 +55,6 @@ impl Repository { } } -fn default_as_false() -> bool { - false -} - #[cfg(test)] mod tests { use super::Config; diff --git a/crates/tabby-scheduler/tests/integration_test.rs b/crates/tabby-scheduler/tests/integration_test.rs index 2fcde8d..d4522f1 100644 --- a/crates/tabby-scheduler/tests/integration_test.rs +++ b/crates/tabby-scheduler/tests/integration_test.rs @@ -3,7 +3,7 @@ mod tests { use std::fs::create_dir_all; use tabby_common::{ - config::{Config, Experimental, Repository, SwaggerConfig}, + config::{Config, Repository, SwaggerConfig}, path::set_tabby_root, }; use temp_testdir::*; @@ -21,7 +21,6 @@ mod tests { git_url: "https://github.com/TabbyML/interview-questions".to_owned(), }], swagger: SwaggerConfig { server_url: None }, - experimental: Experimental::default(), }; config.save(); diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 1334b87..8bfc16e 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -6,12 +6,13 @@ use std::sync::Arc; use axum::{extract::State, Json}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use tabby_common::{config::Config, events}; +use tabby_common::events; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; use self::languages::get_stop_words; +use super::search::IndexServer; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ @@ -127,15 +128,12 @@ pub struct CompletionState { impl CompletionState { pub fn new( engine: Arc>, + index_server: Option>, prompt_template: Option, - config: &Config, ) -> Self { Self { engine, - prompt_builder: prompt::PromptBuilder::new( - prompt_template, - config.experimental.enable_prompt_rewrite, - ), + prompt_builder: prompt::PromptBuilder::new(prompt_template, index_server), } } } diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index 0380317..ff99a86 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -1,41 +1,32 @@ -use std::collections::HashMap; +use std::{collections::HashMap, env, sync::Arc}; -use anyhow::Result; use lazy_static::lazy_static; use strfmt::strfmt; -use tabby_common::path::index_dir; -use tantivy::{ - collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher, -}; use tracing::{info, warn}; use super::Segments; +use crate::serve::search::IndexServer; static MAX_SNIPPETS_TO_FETCH: usize = 20; -static MAX_SNIPPET_PER_NAME: u32 = 1; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512; pub struct PromptBuilder { prompt_template: Option, - index: Option, + index_server: Option>, } impl PromptBuilder { - pub fn new(prompt_template: Option, enable_prompt_rewrite: bool) -> Self { - let index = if enable_prompt_rewrite { - info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ..."); - let index = IndexState::new(); - if let Err(err) = &index { - warn!("Failed to open index in {:?}: {:?}", index_dir(), err); - } - index.ok() + pub fn new(prompt_template: Option, index_server: Option>) -> Self { + let index_server = if env::var("TABBY_ENABLE_PROMPT_REWRITE").is_ok() { + info!("Prompt rewriting is enabled..."); + index_server } else { None }; PromptBuilder { prompt_template, - index, + index_server, } } @@ -53,8 +44,8 @@ impl PromptBuilder { } fn rewrite(&self, language: &str, segments: Segments) -> Segments { - if let Some(index) = &self.index { - rewrite_with_index(index, language, segments) + if let Some(index_server) = &self.index_server { + rewrite_with_index(index_server, language, segments) } else { segments } @@ -74,8 +65,12 @@ fn get_default_suffix(suffix: Option) -> String { } } -fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments { - let snippets = collect_snippets(index, language, &segments.prefix); +fn rewrite_with_index( + index_server: &Arc, + language: &str, + segments: Segments, +) -> Segments { + let snippets = collect_snippets(index_server, language, &segments.prefix); if snippets.is_empty() { segments } else { @@ -85,11 +80,18 @@ fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> } fn build_prefix(language: &str, prefix: &str, snippets: Vec) -> String { + if snippets.is_empty() { + return prefix.to_owned(); + } + let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap(); - let mut lines: Vec = vec![format!( - "Below are some relevant {} snippets found in the repository:", - language - )]; + let mut lines: Vec = vec![ + format!( + "Below are some relevant {} snippets found in the repository:", + language + ), + "".to_owned(), + ]; let mut count_characters = 0; for (i, snippet) in snippets.iter().enumerate() { @@ -102,60 +104,51 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec) -> String { lines.push(line.to_owned()); } + if i < snippets.len() - 1 { + lines.push("".to_owned()); + } count_characters += snippet.len(); } let commented_lines: Vec = lines .iter() - .map(|x| format!("{} {}", comment_char, x)) + .map(|x| { + if x.is_empty() { + comment_char.to_string() + } else { + format!("{} {}", comment_char, x) + } + }) .collect(); let comments = commented_lines.join("\n"); format!("{}\n{}", comments, prefix) } -fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec { +fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec { let mut ret = Vec::new(); let sanitized_text = sanitize_text(text); if sanitized_text.is_empty() { return ret; } - let query_text = format!( - "language:{} AND kind:call AND ({})", - language, sanitized_text - ); - let query = match index.query_parser.parse_query(&query_text) { - Ok(query) => query, + let query_text = format!("language:{} AND ({})", language, sanitized_text); + + let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) { + Ok(serp) => serp, Err(err) => { - warn!("Failed to parse query: {}", err); + warn!("Failed to search query: {}", err); return ret; } }; - let top_docs = index - .searcher - .search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH)) - .unwrap(); + for hit in serp.hits { + let body = hit.doc.body; - let mut names: HashMap = HashMap::new(); - for (_score, doc_address) in top_docs { - let doc = index.searcher.doc(doc_address).unwrap(); - let name = doc - .get_first(index.field_name) - .and_then(|x| x.as_text()) - .unwrap(); - let count = *names.get(name).unwrap_or(&0); - - // Max 1 snippet per identifier. - if count >= MAX_SNIPPET_PER_NAME { + if text.contains(&body) { + // Exclude snippets already in the context window. continue; } - let body = doc - .get_first(index.field_body) - .and_then(|x| x.as_text()) - .unwrap(); - names.insert(name.to_owned(), count + 1); ret.push(body.to_owned()); } @@ -172,41 +165,9 @@ fn sanitize_text(text: &str) -> String { tokens.join(" ") } -struct IndexState { - searcher: Searcher, - query_parser: QueryParser, - field_name: Field, - field_body: Field, -} - -impl IndexState { - fn new() -> Result { - let index = Index::open_in_dir(index_dir())?; - let reader = index - .reader_builder() - .reload_policy(ReloadPolicy::OnCommit) - .try_into()?; - let field_name = index.schema().get_field("name")?; - let field_body = index.schema().get_field("body")?; - let query_parser = QueryParser::for_index(&index, vec![field_body]); - Ok(Self { - searcher: reader.searcher(), - query_parser, - field_name, - field_body, - }) - } -} - lazy_static! { - static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = HashMap::from([ - ("python", "#"), - ("rust", "//"), - ("javascript-typescript", "//"), - ("go", "//"), - ("java", "//"), - ("lua", "--"), - ]); + static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = + HashMap::from([("python", "#"), ("rust", "//"),]); } #[cfg(test)] @@ -222,7 +183,7 @@ mod tests { }; // Init prompt builder with prompt rewrite disabled. - PromptBuilder::new(prompt_template, false) + PromptBuilder::new(prompt_template, None) } #[test] @@ -379,14 +340,19 @@ def this_is_prefix():\n"; let expected_built_prefix = "\ # Below are some relevant python snippets found in the repository: +# # == Snippet 1 == # res_1 = invoke_function_1(n) +# # == Snippet 2 == # res_2 = invoke_function_2(n) +# # == Snippet 3 == # res_3 = invoke_function_3(n) +# # == Snippet 4 == # res_4 = invoke_function_4(n) +# # == Snippet 5 == # res_5 = invoke_function_5(n) ''' diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 896bab7..5152c2e 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -22,7 +22,7 @@ use tabby_common::{ use tabby_download::Downloader; use tokio::time::sleep; use tower_http::cors::CorsLayer; -use tracing::{info, warn}; +use tracing::{debug, info, warn}; use utoipa::{openapi::ServerBuilder, OpenApi}; use utoipa_swagger_ui::SwaggerUi; @@ -62,6 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi chat::ChatCompletionChunk, health::HealthState, health::Version, + search::SearchResponse, + search::Hit, + search::HitDocument )) )] struct ApiDoc; @@ -92,10 +95,6 @@ pub struct ServeArgs { #[clap(long)] chat_model: Option, - /// When set to `true`, the search API route will be enabled. - #[clap(long, default_value_t = false)] - enable_search: bool, - #[clap(long, default_value_t = 8080)] port: u16, @@ -144,7 +143,7 @@ pub async fn main(config: &Config, args: &ServeArgs) { doc.override_doc(args, &config.swagger); let app = Router::new() - .merge(api_router(args, config)) + .merge(api_router(args)) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)) .fallback(fallback()); @@ -165,7 +164,15 @@ pub async fn main(config: &Config, args: &ServeArgs) { .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } -fn api_router(args: &ServeArgs, config: &Config) -> Router { +fn api_router(args: &ServeArgs) -> Router { + let index_server = match IndexServer::load() { + Ok(index_server) => Some(Arc::new(index_server)), + Err(err) => { + debug!("Load index failed due to `{}`", err); + None + } + }; + let completion_state = { let ( engine, @@ -174,7 +181,11 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router { }, ) = create_engine(&args.model, args); let engine = Arc::new(engine); - let state = completions::CompletionState::new(engine.clone(), prompt_template, config); + let state = completions::CompletionState::new( + engine.clone(), + index_server.clone(), + prompt_template, + ); Arc::new(state) }; @@ -201,19 +212,20 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router { routing::post(completions::completions).with_state(completion_state), ); - let router = if args.enable_search { + let router = if let Some(chat_state) = chat_state { router.route( - "/v1beta/search", - routing::get(search::search).with_state(Arc::new(IndexServer::new())), + "/v1beta/chat/completions", + routing::post(chat::completions).with_state(chat_state), ) } else { router }; - let router = if let Some(chat_state) = chat_state { + let router = if let Some(index_server) = index_server { + info!("Index is ready, enabling /v1beta/search API route"); router.route( - "/v1beta/chat/completions", - routing::post(chat::completions).with_state(chat_state), + "/v1beta/search", + routing::get(search::search).with_state(index_server), ) } else { router diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs index 4aed1a7..9e97fd1 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/serve/search.rs @@ -11,11 +11,11 @@ use tabby_common::{index::IndexExt, path}; use tantivy::{ collector::{Count, TopDocs}, query::QueryParser, - schema::{Field, FieldType, NamedFieldDocument, Schema}, - DocAddress, Document, Index, IndexReader, Score, + schema::Field, + DocAddress, Document, Index, IndexReader, }; use tracing::instrument; -use utoipa::IntoParams; +use utoipa::{IntoParams, ToSchema}; #[derive(Deserialize, IntoParams)] pub struct SearchQuery { @@ -29,18 +29,27 @@ pub struct SearchQuery { offset: Option, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct SearchResponse { - q: String, - num_hits: usize, - hits: Vec, + pub num_hits: usize, + pub hits: Vec, } -#[derive(Serialize)] +#[derive(Serialize, ToSchema)] pub struct Hit { - score: Score, - doc: NamedFieldDocument, - id: u32, + pub score: f32, + pub doc: HitDocument, + pub id: u32, +} + +#[derive(Serialize, ToSchema)] +pub struct HitDocument { + pub body: String, + pub filepath: String, + pub git_url: String, + pub kind: String, + pub language: String, + pub name: String, } #[utoipa::path( @@ -50,7 +59,7 @@ pub struct Hit { operation_id = "search", tag = "v1beta", responses( - (status = 200, description = "Success" , content_type = "application/json"), + (status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"), (status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"), ) )] @@ -73,40 +82,41 @@ pub async fn search( pub struct IndexServer { reader: IndexReader, query_parser: QueryParser, - schema: Schema, + + field_body: Field, + field_filepath: Field, + field_git_url: Field, + field_kind: Field, + field_language: Field, + field_name: Field, } impl IndexServer { - pub fn new() -> Self { - Self::load().expect("Failed to load code state") - } - - fn load() -> Result { + pub fn load() -> Result { let index = Index::open_in_dir(path::index_dir())?; index.register_tokenizer(); let schema = index.schema(); - let default_fields: Vec = schema - .fields() - .filter(|&(_, field_entry)| match field_entry.field_type() { - FieldType::Str(ref text_field_options) => { - text_field_options.get_indexing_options().is_some() - } - _ => false, - }) - .map(|(field, _)| field) - .collect(); + let field_body = schema.get_field("body").unwrap(); let query_parser = - QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone()); - let reader = index.reader()?; + QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone()); + let reader = index + .reader_builder() + .reload_policy(tantivy::ReloadPolicy::OnCommit) + .try_into()?; Ok(Self { reader, query_parser, - schema, + field_body, + field_filepath: schema.get_field("filepath").unwrap(), + field_git_url: schema.get_field("git_url").unwrap(), + field_kind: schema.get_field("kind").unwrap(), + field_language: schema.get_field("language").unwrap(), + field_name: schema.get_field("name").unwrap(), }) } - fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result { + pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result { let query = self .query_parser .parse_query(q) @@ -127,18 +137,28 @@ impl IndexServer { }) .collect() }; - Ok(SearchResponse { - q: q.to_owned(), - num_hits, - hits, - }) + Ok(SearchResponse { num_hits, hits }) } - fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit { + fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit { Hit { score, - doc: self.schema.to_named_doc(&doc), + doc: HitDocument { + body: get_field(&doc, self.field_body), + filepath: get_field(&doc, self.field_filepath), + git_url: get_field(&doc, self.field_git_url), + kind: get_field(&doc, self.field_kind), + name: get_field(&doc, self.field_name), + language: get_field(&doc, self.field_language), + }, id: doc_address.doc_id, } } } + +fn get_field(doc: &Document, field: Field) -> String { + doc.get_first(field) + .and_then(|x| x.as_text()) + .unwrap() + .to_owned() +}