feat: connect prompt rewriting part (#517)

* feat: enable /v1beta/search if index is available

* make prompt rewriting work

* update

* fix test

* fix api doc
wsxiaoys-patch-1
Meng Zhang 2023-10-06 17:29:24 -07:00 committed by GitHub
parent 8497fb1372
commit d85a7892d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 145 additions and 163 deletions

View File

@ -13,19 +13,10 @@ pub struct Config {
#[serde(default)] #[serde(default)]
pub repositories: Vec<Repository>, pub repositories: Vec<Repository>,
#[serde(default)]
pub experimental: Experimental,
#[serde(default)] #[serde(default)]
pub swagger: SwaggerConfig, pub swagger: SwaggerConfig,
} }
#[derive(Serialize, Deserialize, Default)]
pub struct Experimental {
#[serde(default = "default_as_false")]
pub enable_prompt_rewrite: bool,
}
#[derive(Serialize, Deserialize, Default)] #[derive(Serialize, Deserialize, Default)]
pub struct SwaggerConfig { pub struct SwaggerConfig {
pub server_url: Option<String>, pub server_url: Option<String>,
@ -64,10 +55,6 @@ impl Repository {
} }
} }
fn default_as_false() -> bool {
false
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Config; use super::Config;

View File

@ -3,7 +3,7 @@ mod tests {
use std::fs::create_dir_all; use std::fs::create_dir_all;
use tabby_common::{ use tabby_common::{
config::{Config, Experimental, Repository, SwaggerConfig}, config::{Config, Repository, SwaggerConfig},
path::set_tabby_root, path::set_tabby_root,
}; };
use temp_testdir::*; use temp_testdir::*;
@ -21,7 +21,6 @@ mod tests {
git_url: "https://github.com/TabbyML/interview-questions".to_owned(), git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
}], }],
swagger: SwaggerConfig { server_url: None }, swagger: SwaggerConfig { server_url: None },
experimental: Experimental::default(),
}; };
config.save(); config.save();

View File

@ -6,12 +6,13 @@ use std::sync::Arc;
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::{config::Config, events}; use tabby_common::events;
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use utoipa::ToSchema; use utoipa::ToSchema;
use self::languages::get_stop_words; use self::languages::get_stop_words;
use super::search::IndexServer;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({ #[schema(example=json!({
@ -127,15 +128,12 @@ pub struct CompletionState {
impl CompletionState { impl CompletionState {
pub fn new( pub fn new(
engine: Arc<Box<dyn TextGeneration>>, engine: Arc<Box<dyn TextGeneration>>,
index_server: Option<Arc<IndexServer>>,
prompt_template: Option<String>, prompt_template: Option<String>,
config: &Config,
) -> Self { ) -> Self {
Self { Self {
engine, engine,
prompt_builder: prompt::PromptBuilder::new( prompt_builder: prompt::PromptBuilder::new(prompt_template, index_server),
prompt_template,
config.experimental.enable_prompt_rewrite,
),
} }
} }
} }

View File

@ -1,41 +1,32 @@
use std::collections::HashMap; use std::{collections::HashMap, env, sync::Arc};
use anyhow::Result;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use strfmt::strfmt; use strfmt::strfmt;
use tabby_common::path::index_dir;
use tantivy::{
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
};
use tracing::{info, warn}; use tracing::{info, warn};
use super::Segments; use super::Segments;
use crate::serve::search::IndexServer;
static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_PER_NAME: u32 = 1;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
pub struct PromptBuilder { pub struct PromptBuilder {
prompt_template: Option<String>, prompt_template: Option<String>,
index: Option<IndexState>, index_server: Option<Arc<IndexServer>>,
} }
impl PromptBuilder { impl PromptBuilder {
pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self { pub fn new(prompt_template: Option<String>, index_server: Option<Arc<IndexServer>>) -> Self {
let index = if enable_prompt_rewrite { let index_server = if env::var("TABBY_ENABLE_PROMPT_REWRITE").is_ok() {
info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ..."); info!("Prompt rewriting is enabled...");
let index = IndexState::new(); index_server
if let Err(err) = &index {
warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
}
index.ok()
} else { } else {
None None
}; };
PromptBuilder { PromptBuilder {
prompt_template, prompt_template,
index, index_server,
} }
} }
@ -53,8 +44,8 @@ impl PromptBuilder {
} }
fn rewrite(&self, language: &str, segments: Segments) -> Segments { fn rewrite(&self, language: &str, segments: Segments) -> Segments {
if let Some(index) = &self.index { if let Some(index_server) = &self.index_server {
rewrite_with_index(index, language, segments) rewrite_with_index(index_server, language, segments)
} else { } else {
segments segments
} }
@ -74,8 +65,12 @@ fn get_default_suffix(suffix: Option<String>) -> String {
} }
} }
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments { fn rewrite_with_index(
let snippets = collect_snippets(index, language, &segments.prefix); index_server: &Arc<IndexServer>,
language: &str,
segments: Segments,
) -> Segments {
let snippets = collect_snippets(index_server, language, &segments.prefix);
if snippets.is_empty() { if snippets.is_empty() {
segments segments
} else { } else {
@ -85,11 +80,18 @@ fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) ->
} }
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String { fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
if snippets.is_empty() {
return prefix.to_owned();
}
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap(); let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
let mut lines: Vec<String> = vec![format!( let mut lines: Vec<String> = vec![
"Below are some relevant {} snippets found in the repository:", format!(
language "Below are some relevant {} snippets found in the repository:",
)]; language
),
"".to_owned(),
];
let mut count_characters = 0; let mut count_characters = 0;
for (i, snippet) in snippets.iter().enumerate() { for (i, snippet) in snippets.iter().enumerate() {
@ -102,60 +104,51 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
lines.push(line.to_owned()); lines.push(line.to_owned());
} }
if i < snippets.len() - 1 {
lines.push("".to_owned());
}
count_characters += snippet.len(); count_characters += snippet.len();
} }
let commented_lines: Vec<String> = lines let commented_lines: Vec<String> = lines
.iter() .iter()
.map(|x| format!("{} {}", comment_char, x)) .map(|x| {
if x.is_empty() {
comment_char.to_string()
} else {
format!("{} {}", comment_char, x)
}
})
.collect(); .collect();
let comments = commented_lines.join("\n"); let comments = commented_lines.join("\n");
format!("{}\n{}", comments, prefix) format!("{}\n{}", comments, prefix)
} }
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> { fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<String> {
let mut ret = Vec::new(); let mut ret = Vec::new();
let sanitized_text = sanitize_text(text); let sanitized_text = sanitize_text(text);
if sanitized_text.is_empty() { if sanitized_text.is_empty() {
return ret; return ret;
} }
let query_text = format!( let query_text = format!("language:{} AND ({})", language, sanitized_text);
"language:{} AND kind:call AND ({})",
language, sanitized_text let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) {
); Ok(serp) => serp,
let query = match index.query_parser.parse_query(&query_text) {
Ok(query) => query,
Err(err) => { Err(err) => {
warn!("Failed to parse query: {}", err); warn!("Failed to search query: {}", err);
return ret; return ret;
} }
}; };
let top_docs = index for hit in serp.hits {
.searcher let body = hit.doc.body;
.search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
.unwrap();
let mut names: HashMap<String, u32> = HashMap::new(); if text.contains(&body) {
for (_score, doc_address) in top_docs { // Exclude snippets already in the context window.
let doc = index.searcher.doc(doc_address).unwrap();
let name = doc
.get_first(index.field_name)
.and_then(|x| x.as_text())
.unwrap();
let count = *names.get(name).unwrap_or(&0);
// Max 1 snippet per identifier.
if count >= MAX_SNIPPET_PER_NAME {
continue; continue;
} }
let body = doc
.get_first(index.field_body)
.and_then(|x| x.as_text())
.unwrap();
names.insert(name.to_owned(), count + 1);
ret.push(body.to_owned()); ret.push(body.to_owned());
} }
@ -172,41 +165,9 @@ fn sanitize_text(text: &str) -> String {
tokens.join(" ") tokens.join(" ")
} }
struct IndexState {
searcher: Searcher,
query_parser: QueryParser,
field_name: Field,
field_body: Field,
}
impl IndexState {
fn new() -> Result<IndexState> {
let index = Index::open_in_dir(index_dir())?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()?;
let field_name = index.schema().get_field("name")?;
let field_body = index.schema().get_field("body")?;
let query_parser = QueryParser::for_index(&index, vec![field_body]);
Ok(Self {
searcher: reader.searcher(),
query_parser,
field_name,
field_body,
})
}
}
lazy_static! { lazy_static! {
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = HashMap::from([ static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
("python", "#"), HashMap::from([("python", "#"), ("rust", "//"),]);
("rust", "//"),
("javascript-typescript", "//"),
("go", "//"),
("java", "//"),
("lua", "--"),
]);
} }
#[cfg(test)] #[cfg(test)]
@ -222,7 +183,7 @@ mod tests {
}; };
// Init prompt builder with prompt rewrite disabled. // Init prompt builder with prompt rewrite disabled.
PromptBuilder::new(prompt_template, false) PromptBuilder::new(prompt_template, None)
} }
#[test] #[test]
@ -379,14 +340,19 @@ def this_is_prefix():\n";
let expected_built_prefix = "\ let expected_built_prefix = "\
# Below are some relevant python snippets found in the repository: # Below are some relevant python snippets found in the repository:
#
# == Snippet 1 == # == Snippet 1 ==
# res_1 = invoke_function_1(n) # res_1 = invoke_function_1(n)
#
# == Snippet 2 == # == Snippet 2 ==
# res_2 = invoke_function_2(n) # res_2 = invoke_function_2(n)
#
# == Snippet 3 == # == Snippet 3 ==
# res_3 = invoke_function_3(n) # res_3 = invoke_function_3(n)
#
# == Snippet 4 == # == Snippet 4 ==
# res_4 = invoke_function_4(n) # res_4 = invoke_function_4(n)
#
# == Snippet 5 == # == Snippet 5 ==
# res_5 = invoke_function_5(n) # res_5 = invoke_function_5(n)
''' '''

View File

@ -22,7 +22,7 @@ use tabby_common::{
use tabby_download::Downloader; use tabby_download::Downloader;
use tokio::time::sleep; use tokio::time::sleep;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use tracing::{info, warn}; use tracing::{debug, info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi}; use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
@ -62,6 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
chat::ChatCompletionChunk, chat::ChatCompletionChunk,
health::HealthState, health::HealthState,
health::Version, health::Version,
search::SearchResponse,
search::Hit,
search::HitDocument
)) ))
)] )]
struct ApiDoc; struct ApiDoc;
@ -92,10 +95,6 @@ pub struct ServeArgs {
#[clap(long)] #[clap(long)]
chat_model: Option<String>, chat_model: Option<String>,
/// When set to `true`, the search API route will be enabled.
#[clap(long, default_value_t = false)]
enable_search: bool,
#[clap(long, default_value_t = 8080)] #[clap(long, default_value_t = 8080)]
port: u16, port: u16,
@ -144,7 +143,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
doc.override_doc(args, &config.swagger); doc.override_doc(args, &config.swagger);
let app = Router::new() let app = Router::new()
.merge(api_router(args, config)) .merge(api_router(args))
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc)) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.fallback(fallback()); .fallback(fallback());
@ -165,7 +164,15 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
} }
fn api_router(args: &ServeArgs, config: &Config) -> Router { fn api_router(args: &ServeArgs) -> Router {
let index_server = match IndexServer::load() {
Ok(index_server) => Some(Arc::new(index_server)),
Err(err) => {
debug!("Load index failed due to `{}`", err);
None
}
};
let completion_state = { let completion_state = {
let ( let (
engine, engine,
@ -174,7 +181,11 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
}, },
) = create_engine(&args.model, args); ) = create_engine(&args.model, args);
let engine = Arc::new(engine); let engine = Arc::new(engine);
let state = completions::CompletionState::new(engine.clone(), prompt_template, config); let state = completions::CompletionState::new(
engine.clone(),
index_server.clone(),
prompt_template,
);
Arc::new(state) Arc::new(state)
}; };
@ -201,19 +212,20 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
routing::post(completions::completions).with_state(completion_state), routing::post(completions::completions).with_state(completion_state),
); );
let router = if args.enable_search { let router = if let Some(chat_state) = chat_state {
router.route( router.route(
"/v1beta/search", "/v1beta/chat/completions",
routing::get(search::search).with_state(Arc::new(IndexServer::new())), routing::post(chat::completions).with_state(chat_state),
) )
} else { } else {
router router
}; };
let router = if let Some(chat_state) = chat_state { let router = if let Some(index_server) = index_server {
info!("Index is ready, enabling /v1beta/search API route");
router.route( router.route(
"/v1beta/chat/completions", "/v1beta/search",
routing::post(chat::completions).with_state(chat_state), routing::get(search::search).with_state(index_server),
) )
} else { } else {
router router

View File

@ -11,11 +11,11 @@ use tabby_common::{index::IndexExt, path};
use tantivy::{ use tantivy::{
collector::{Count, TopDocs}, collector::{Count, TopDocs},
query::QueryParser, query::QueryParser,
schema::{Field, FieldType, NamedFieldDocument, Schema}, schema::Field,
DocAddress, Document, Index, IndexReader, Score, DocAddress, Document, Index, IndexReader,
}; };
use tracing::instrument; use tracing::instrument;
use utoipa::IntoParams; use utoipa::{IntoParams, ToSchema};
#[derive(Deserialize, IntoParams)] #[derive(Deserialize, IntoParams)]
pub struct SearchQuery { pub struct SearchQuery {
@ -29,18 +29,27 @@ pub struct SearchQuery {
offset: Option<usize>, offset: Option<usize>,
} }
#[derive(Serialize)] #[derive(Serialize, ToSchema)]
pub struct SearchResponse { pub struct SearchResponse {
q: String, pub num_hits: usize,
num_hits: usize, pub hits: Vec<Hit>,
hits: Vec<Hit>,
} }
#[derive(Serialize)] #[derive(Serialize, ToSchema)]
pub struct Hit { pub struct Hit {
score: Score, pub score: f32,
doc: NamedFieldDocument, pub doc: HitDocument,
id: u32, pub id: u32,
}
#[derive(Serialize, ToSchema)]
pub struct HitDocument {
pub body: String,
pub filepath: String,
pub git_url: String,
pub kind: String,
pub language: String,
pub name: String,
} }
#[utoipa::path( #[utoipa::path(
@ -50,7 +59,7 @@ pub struct Hit {
operation_id = "search", operation_id = "search",
tag = "v1beta", tag = "v1beta",
responses( responses(
(status = 200, description = "Success" , content_type = "application/json"), (status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"), (status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
) )
)] )]
@ -73,40 +82,41 @@ pub async fn search(
pub struct IndexServer { pub struct IndexServer {
reader: IndexReader, reader: IndexReader,
query_parser: QueryParser, query_parser: QueryParser,
schema: Schema,
field_body: Field,
field_filepath: Field,
field_git_url: Field,
field_kind: Field,
field_language: Field,
field_name: Field,
} }
impl IndexServer { impl IndexServer {
pub fn new() -> Self { pub fn load() -> Result<Self> {
Self::load().expect("Failed to load code state")
}
fn load() -> Result<Self> {
let index = Index::open_in_dir(path::index_dir())?; let index = Index::open_in_dir(path::index_dir())?;
index.register_tokenizer(); index.register_tokenizer();
let schema = index.schema(); let schema = index.schema();
let default_fields: Vec<Field> = schema let field_body = schema.get_field("body").unwrap();
.fields()
.filter(|&(_, field_entry)| match field_entry.field_type() {
FieldType::Str(ref text_field_options) => {
text_field_options.get_indexing_options().is_some()
}
_ => false,
})
.map(|(field, _)| field)
.collect();
let query_parser = let query_parser =
QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone()); QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
let reader = index.reader()?; let reader = index
.reader_builder()
.reload_policy(tantivy::ReloadPolicy::OnCommit)
.try_into()?;
Ok(Self { Ok(Self {
reader, reader,
query_parser, query_parser,
schema, field_body,
field_filepath: schema.get_field("filepath").unwrap(),
field_git_url: schema.get_field("git_url").unwrap(),
field_kind: schema.get_field("kind").unwrap(),
field_language: schema.get_field("language").unwrap(),
field_name: schema.get_field("name").unwrap(),
}) })
} }
fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> { pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
let query = self let query = self
.query_parser .query_parser
.parse_query(q) .parse_query(q)
@ -127,18 +137,28 @@ impl IndexServer {
}) })
.collect() .collect()
}; };
Ok(SearchResponse { Ok(SearchResponse { num_hits, hits })
q: q.to_owned(),
num_hits,
hits,
})
} }
fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit { fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit {
Hit { Hit {
score, score,
doc: self.schema.to_named_doc(&doc), doc: HitDocument {
body: get_field(&doc, self.field_body),
filepath: get_field(&doc, self.field_filepath),
git_url: get_field(&doc, self.field_git_url),
kind: get_field(&doc, self.field_kind),
name: get_field(&doc, self.field_name),
language: get_field(&doc, self.field_language),
},
id: doc_address.doc_id, id: doc_address.doc_id,
} }
} }
} }
fn get_field(doc: &Document, field: Field) -> String {
doc.get_first(field)
.and_then(|x| x.as_text())
.unwrap()
.to_owned()
}