feat: connect prompt rewriting part (#517)

* feat: enable /v1beta/search if index is available

* make prompt rewriting work

* update

* fix test

* fix api doc
wsxiaoys-patch-1
Meng Zhang 2023-10-06 17:29:24 -07:00 committed by GitHub
parent 8497fb1372
commit d85a7892d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 145 additions and 163 deletions

View File

@ -13,19 +13,10 @@ pub struct Config {
#[serde(default)]
pub repositories: Vec<Repository>,
#[serde(default)]
pub experimental: Experimental,
#[serde(default)]
pub swagger: SwaggerConfig,
}
#[derive(Serialize, Deserialize, Default)]
pub struct Experimental {
#[serde(default = "default_as_false")]
pub enable_prompt_rewrite: bool,
}
#[derive(Serialize, Deserialize, Default)]
pub struct SwaggerConfig {
pub server_url: Option<String>,
@ -64,10 +55,6 @@ impl Repository {
}
}
fn default_as_false() -> bool {
false
}
#[cfg(test)]
mod tests {
use super::Config;

View File

@ -3,7 +3,7 @@ mod tests {
use std::fs::create_dir_all;
use tabby_common::{
config::{Config, Experimental, Repository, SwaggerConfig},
config::{Config, Repository, SwaggerConfig},
path::set_tabby_root,
};
use temp_testdir::*;
@ -21,7 +21,6 @@ mod tests {
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
}],
swagger: SwaggerConfig { server_url: None },
experimental: Experimental::default(),
};
config.save();

View File

@ -6,12 +6,13 @@ use std::sync::Arc;
use axum::{extract::State, Json};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::{config::Config, events};
use tabby_common::events;
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use self::languages::get_stop_words;
use super::search::IndexServer;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
@ -127,15 +128,12 @@ pub struct CompletionState {
impl CompletionState {
pub fn new(
engine: Arc<Box<dyn TextGeneration>>,
index_server: Option<Arc<IndexServer>>,
prompt_template: Option<String>,
config: &Config,
) -> Self {
Self {
engine,
prompt_builder: prompt::PromptBuilder::new(
prompt_template,
config.experimental.enable_prompt_rewrite,
),
prompt_builder: prompt::PromptBuilder::new(prompt_template, index_server),
}
}
}

View File

@ -1,41 +1,32 @@
use std::collections::HashMap;
use std::{collections::HashMap, env, sync::Arc};
use anyhow::Result;
use lazy_static::lazy_static;
use strfmt::strfmt;
use tabby_common::path::index_dir;
use tantivy::{
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
};
use tracing::{info, warn};
use super::Segments;
use crate::serve::search::IndexServer;
static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_PER_NAME: u32 = 1;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
pub struct PromptBuilder {
prompt_template: Option<String>,
index: Option<IndexState>,
index_server: Option<Arc<IndexServer>>,
}
impl PromptBuilder {
pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self {
let index = if enable_prompt_rewrite {
info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ...");
let index = IndexState::new();
if let Err(err) = &index {
warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
}
index.ok()
pub fn new(prompt_template: Option<String>, index_server: Option<Arc<IndexServer>>) -> Self {
let index_server = if env::var("TABBY_ENABLE_PROMPT_REWRITE").is_ok() {
info!("Prompt rewriting is enabled...");
index_server
} else {
None
};
PromptBuilder {
prompt_template,
index,
index_server,
}
}
@ -53,8 +44,8 @@ impl PromptBuilder {
}
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
if let Some(index) = &self.index {
rewrite_with_index(index, language, segments)
if let Some(index_server) = &self.index_server {
rewrite_with_index(index_server, language, segments)
} else {
segments
}
@ -74,8 +65,12 @@ fn get_default_suffix(suffix: Option<String>) -> String {
}
}
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments {
let snippets = collect_snippets(index, language, &segments.prefix);
fn rewrite_with_index(
index_server: &Arc<IndexServer>,
language: &str,
segments: Segments,
) -> Segments {
let snippets = collect_snippets(index_server, language, &segments.prefix);
if snippets.is_empty() {
segments
} else {
@ -85,11 +80,18 @@ fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) ->
}
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
if snippets.is_empty() {
return prefix.to_owned();
}
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
let mut lines: Vec<String> = vec![format!(
"Below are some relevant {} snippets found in the repository:",
language
)];
let mut lines: Vec<String> = vec![
format!(
"Below are some relevant {} snippets found in the repository:",
language
),
"".to_owned(),
];
let mut count_characters = 0;
for (i, snippet) in snippets.iter().enumerate() {
@ -102,60 +104,51 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
lines.push(line.to_owned());
}
if i < snippets.len() - 1 {
lines.push("".to_owned());
}
count_characters += snippet.len();
}
let commented_lines: Vec<String> = lines
.iter()
.map(|x| format!("{} {}", comment_char, x))
.map(|x| {
if x.is_empty() {
comment_char.to_string()
} else {
format!("{} {}", comment_char, x)
}
})
.collect();
let comments = commented_lines.join("\n");
format!("{}\n{}", comments, prefix)
}
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> {
fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<String> {
let mut ret = Vec::new();
let sanitized_text = sanitize_text(text);
if sanitized_text.is_empty() {
return ret;
}
let query_text = format!(
"language:{} AND kind:call AND ({})",
language, sanitized_text
);
let query = match index.query_parser.parse_query(&query_text) {
Ok(query) => query,
let query_text = format!("language:{} AND ({})", language, sanitized_text);
let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) {
Ok(serp) => serp,
Err(err) => {
warn!("Failed to parse query: {}", err);
warn!("Failed to search query: {}", err);
return ret;
}
};
let top_docs = index
.searcher
.search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
.unwrap();
for hit in serp.hits {
let body = hit.doc.body;
let mut names: HashMap<String, u32> = HashMap::new();
for (_score, doc_address) in top_docs {
let doc = index.searcher.doc(doc_address).unwrap();
let name = doc
.get_first(index.field_name)
.and_then(|x| x.as_text())
.unwrap();
let count = *names.get(name).unwrap_or(&0);
// Max 1 snippet per identifier.
if count >= MAX_SNIPPET_PER_NAME {
if text.contains(&body) {
// Exclude snippets already in the context window.
continue;
}
let body = doc
.get_first(index.field_body)
.and_then(|x| x.as_text())
.unwrap();
names.insert(name.to_owned(), count + 1);
ret.push(body.to_owned());
}
@ -172,41 +165,9 @@ fn sanitize_text(text: &str) -> String {
tokens.join(" ")
}
struct IndexState {
searcher: Searcher,
query_parser: QueryParser,
field_name: Field,
field_body: Field,
}
impl IndexState {
fn new() -> Result<IndexState> {
let index = Index::open_in_dir(index_dir())?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommit)
.try_into()?;
let field_name = index.schema().get_field("name")?;
let field_body = index.schema().get_field("body")?;
let query_parser = QueryParser::for_index(&index, vec![field_body]);
Ok(Self {
searcher: reader.searcher(),
query_parser,
field_name,
field_body,
})
}
}
lazy_static! {
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = HashMap::from([
("python", "#"),
("rust", "//"),
("javascript-typescript", "//"),
("go", "//"),
("java", "//"),
("lua", "--"),
]);
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
HashMap::from([("python", "#"), ("rust", "//"),]);
}
#[cfg(test)]
@ -222,7 +183,7 @@ mod tests {
};
// Init prompt builder with prompt rewrite disabled.
PromptBuilder::new(prompt_template, false)
PromptBuilder::new(prompt_template, None)
}
#[test]
@ -379,14 +340,19 @@ def this_is_prefix():\n";
let expected_built_prefix = "\
# Below are some relevant python snippets found in the repository:
#
# == Snippet 1 ==
# res_1 = invoke_function_1(n)
#
# == Snippet 2 ==
# res_2 = invoke_function_2(n)
#
# == Snippet 3 ==
# res_3 = invoke_function_3(n)
#
# == Snippet 4 ==
# res_4 = invoke_function_4(n)
#
# == Snippet 5 ==
# res_5 = invoke_function_5(n)
'''

View File

@ -22,7 +22,7 @@ use tabby_common::{
use tabby_download::Downloader;
use tokio::time::sleep;
use tower_http::cors::CorsLayer;
use tracing::{info, warn};
use tracing::{debug, info, warn};
use utoipa::{openapi::ServerBuilder, OpenApi};
use utoipa_swagger_ui::SwaggerUi;
@ -62,6 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
chat::ChatCompletionChunk,
health::HealthState,
health::Version,
search::SearchResponse,
search::Hit,
search::HitDocument
))
)]
struct ApiDoc;
@ -92,10 +95,6 @@ pub struct ServeArgs {
#[clap(long)]
chat_model: Option<String>,
/// When set to `true`, the search API route will be enabled.
#[clap(long, default_value_t = false)]
enable_search: bool,
#[clap(long, default_value_t = 8080)]
port: u16,
@ -144,7 +143,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
doc.override_doc(args, &config.swagger);
let app = Router::new()
.merge(api_router(args, config))
.merge(api_router(args))
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.fallback(fallback());
@ -165,7 +164,15 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}
fn api_router(args: &ServeArgs, config: &Config) -> Router {
fn api_router(args: &ServeArgs) -> Router {
let index_server = match IndexServer::load() {
Ok(index_server) => Some(Arc::new(index_server)),
Err(err) => {
debug!("Load index failed due to `{}`", err);
None
}
};
let completion_state = {
let (
engine,
@ -174,7 +181,11 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
},
) = create_engine(&args.model, args);
let engine = Arc::new(engine);
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
let state = completions::CompletionState::new(
engine.clone(),
index_server.clone(),
prompt_template,
);
Arc::new(state)
};
@ -201,19 +212,20 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
routing::post(completions::completions).with_state(completion_state),
);
let router = if args.enable_search {
let router = if let Some(chat_state) = chat_state {
router.route(
"/v1beta/search",
routing::get(search::search).with_state(Arc::new(IndexServer::new())),
"/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state),
)
} else {
router
};
let router = if let Some(chat_state) = chat_state {
let router = if let Some(index_server) = index_server {
info!("Index is ready, enabling /v1beta/search API route");
router.route(
"/v1beta/chat/completions",
routing::post(chat::completions).with_state(chat_state),
"/v1beta/search",
routing::get(search::search).with_state(index_server),
)
} else {
router

View File

@ -11,11 +11,11 @@ use tabby_common::{index::IndexExt, path};
use tantivy::{
collector::{Count, TopDocs},
query::QueryParser,
schema::{Field, FieldType, NamedFieldDocument, Schema},
DocAddress, Document, Index, IndexReader, Score,
schema::Field,
DocAddress, Document, Index, IndexReader,
};
use tracing::instrument;
use utoipa::IntoParams;
use utoipa::{IntoParams, ToSchema};
#[derive(Deserialize, IntoParams)]
pub struct SearchQuery {
@ -29,18 +29,27 @@ pub struct SearchQuery {
offset: Option<usize>,
}
#[derive(Serialize)]
#[derive(Serialize, ToSchema)]
pub struct SearchResponse {
q: String,
num_hits: usize,
hits: Vec<Hit>,
pub num_hits: usize,
pub hits: Vec<Hit>,
}
#[derive(Serialize)]
#[derive(Serialize, ToSchema)]
pub struct Hit {
score: Score,
doc: NamedFieldDocument,
id: u32,
pub score: f32,
pub doc: HitDocument,
pub id: u32,
}
#[derive(Serialize, ToSchema)]
pub struct HitDocument {
pub body: String,
pub filepath: String,
pub git_url: String,
pub kind: String,
pub language: String,
pub name: String,
}
#[utoipa::path(
@ -50,7 +59,7 @@ pub struct Hit {
operation_id = "search",
tag = "v1beta",
responses(
(status = 200, description = "Success" , content_type = "application/json"),
(status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
)
)]
@ -73,40 +82,41 @@ pub async fn search(
pub struct IndexServer {
reader: IndexReader,
query_parser: QueryParser,
schema: Schema,
field_body: Field,
field_filepath: Field,
field_git_url: Field,
field_kind: Field,
field_language: Field,
field_name: Field,
}
impl IndexServer {
pub fn new() -> Self {
Self::load().expect("Failed to load code state")
}
fn load() -> Result<Self> {
pub fn load() -> Result<Self> {
let index = Index::open_in_dir(path::index_dir())?;
index.register_tokenizer();
let schema = index.schema();
let default_fields: Vec<Field> = schema
.fields()
.filter(|&(_, field_entry)| match field_entry.field_type() {
FieldType::Str(ref text_field_options) => {
text_field_options.get_indexing_options().is_some()
}
_ => false,
})
.map(|(field, _)| field)
.collect();
let field_body = schema.get_field("body").unwrap();
let query_parser =
QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
let reader = index.reader()?;
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
let reader = index
.reader_builder()
.reload_policy(tantivy::ReloadPolicy::OnCommit)
.try_into()?;
Ok(Self {
reader,
query_parser,
schema,
field_body,
field_filepath: schema.get_field("filepath").unwrap(),
field_git_url: schema.get_field("git_url").unwrap(),
field_kind: schema.get_field("kind").unwrap(),
field_language: schema.get_field("language").unwrap(),
field_name: schema.get_field("name").unwrap(),
})
}
fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
let query = self
.query_parser
.parse_query(q)
@ -127,18 +137,28 @@ impl IndexServer {
})
.collect()
};
Ok(SearchResponse {
q: q.to_owned(),
num_hits,
hits,
})
Ok(SearchResponse { num_hits, hits })
}
fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit {
fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit {
Hit {
score,
doc: self.schema.to_named_doc(&doc),
doc: HitDocument {
body: get_field(&doc, self.field_body),
filepath: get_field(&doc, self.field_filepath),
git_url: get_field(&doc, self.field_git_url),
kind: get_field(&doc, self.field_kind),
name: get_field(&doc, self.field_name),
language: get_field(&doc, self.field_language),
},
id: doc_address.doc_id,
}
}
}
fn get_field(doc: &Document, field: Field) -> String {
doc.get_first(field)
.and_then(|x| x.as_text())
.unwrap()
.to_owned()
}