feat: connect prompt rewriting part (#517)
* feat: enable /v1beta/search if index is available * make prompt rewriting work * update * fix test * fix api docwsxiaoys-patch-1
parent
8497fb1372
commit
d85a7892d1
|
|
@ -13,19 +13,10 @@ pub struct Config {
|
|||
#[serde(default)]
|
||||
pub repositories: Vec<Repository>,
|
||||
|
||||
#[serde(default)]
|
||||
pub experimental: Experimental,
|
||||
|
||||
#[serde(default)]
|
||||
pub swagger: SwaggerConfig,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct Experimental {
|
||||
#[serde(default = "default_as_false")]
|
||||
pub enable_prompt_rewrite: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct SwaggerConfig {
|
||||
pub server_url: Option<String>,
|
||||
|
|
@ -64,10 +55,6 @@ impl Repository {
|
|||
}
|
||||
}
|
||||
|
||||
fn default_as_false() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Config;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ mod tests {
|
|||
use std::fs::create_dir_all;
|
||||
|
||||
use tabby_common::{
|
||||
config::{Config, Experimental, Repository, SwaggerConfig},
|
||||
config::{Config, Repository, SwaggerConfig},
|
||||
path::set_tabby_root,
|
||||
};
|
||||
use temp_testdir::*;
|
||||
|
|
@ -21,7 +21,6 @@ mod tests {
|
|||
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
|
||||
}],
|
||||
swagger: SwaggerConfig { server_url: None },
|
||||
experimental: Experimental::default(),
|
||||
};
|
||||
|
||||
config.save();
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@ use std::sync::Arc;
|
|||
use axum::{extract::State, Json};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::{config::Config, events};
|
||||
use tabby_common::events;
|
||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||
use tracing::{debug, instrument};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use self::languages::get_stop_words;
|
||||
use super::search::IndexServer;
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||
#[schema(example=json!({
|
||||
|
|
@ -127,15 +128,12 @@ pub struct CompletionState {
|
|||
impl CompletionState {
|
||||
pub fn new(
|
||||
engine: Arc<Box<dyn TextGeneration>>,
|
||||
index_server: Option<Arc<IndexServer>>,
|
||||
prompt_template: Option<String>,
|
||||
config: &Config,
|
||||
) -> Self {
|
||||
Self {
|
||||
engine,
|
||||
prompt_builder: prompt::PromptBuilder::new(
|
||||
prompt_template,
|
||||
config.experimental.enable_prompt_rewrite,
|
||||
),
|
||||
prompt_builder: prompt::PromptBuilder::new(prompt_template, index_server),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,41 +1,32 @@
|
|||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use lazy_static::lazy_static;
|
||||
use strfmt::strfmt;
|
||||
use tabby_common::path::index_dir;
|
||||
use tantivy::{
|
||||
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::Segments;
|
||||
use crate::serve::search::IndexServer;
|
||||
|
||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||
static MAX_SNIPPET_PER_NAME: u32 = 1;
|
||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
||||
|
||||
pub struct PromptBuilder {
|
||||
prompt_template: Option<String>,
|
||||
index: Option<IndexState>,
|
||||
index_server: Option<Arc<IndexServer>>,
|
||||
}
|
||||
|
||||
impl PromptBuilder {
|
||||
pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self {
|
||||
let index = if enable_prompt_rewrite {
|
||||
info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ...");
|
||||
let index = IndexState::new();
|
||||
if let Err(err) = &index {
|
||||
warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
|
||||
}
|
||||
index.ok()
|
||||
pub fn new(prompt_template: Option<String>, index_server: Option<Arc<IndexServer>>) -> Self {
|
||||
let index_server = if env::var("TABBY_ENABLE_PROMPT_REWRITE").is_ok() {
|
||||
info!("Prompt rewriting is enabled...");
|
||||
index_server
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
PromptBuilder {
|
||||
prompt_template,
|
||||
index,
|
||||
index_server,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -53,8 +44,8 @@ impl PromptBuilder {
|
|||
}
|
||||
|
||||
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
|
||||
if let Some(index) = &self.index {
|
||||
rewrite_with_index(index, language, segments)
|
||||
if let Some(index_server) = &self.index_server {
|
||||
rewrite_with_index(index_server, language, segments)
|
||||
} else {
|
||||
segments
|
||||
}
|
||||
|
|
@ -74,8 +65,12 @@ fn get_default_suffix(suffix: Option<String>) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments {
|
||||
let snippets = collect_snippets(index, language, &segments.prefix);
|
||||
fn rewrite_with_index(
|
||||
index_server: &Arc<IndexServer>,
|
||||
language: &str,
|
||||
segments: Segments,
|
||||
) -> Segments {
|
||||
let snippets = collect_snippets(index_server, language, &segments.prefix);
|
||||
if snippets.is_empty() {
|
||||
segments
|
||||
} else {
|
||||
|
|
@ -85,11 +80,18 @@ fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) ->
|
|||
}
|
||||
|
||||
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
||||
if snippets.is_empty() {
|
||||
return prefix.to_owned();
|
||||
}
|
||||
|
||||
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
||||
let mut lines: Vec<String> = vec![format!(
|
||||
"Below are some relevant {} snippets found in the repository:",
|
||||
language
|
||||
)];
|
||||
let mut lines: Vec<String> = vec![
|
||||
format!(
|
||||
"Below are some relevant {} snippets found in the repository:",
|
||||
language
|
||||
),
|
||||
"".to_owned(),
|
||||
];
|
||||
|
||||
let mut count_characters = 0;
|
||||
for (i, snippet) in snippets.iter().enumerate() {
|
||||
|
|
@ -102,60 +104,51 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
|||
lines.push(line.to_owned());
|
||||
}
|
||||
|
||||
if i < snippets.len() - 1 {
|
||||
lines.push("".to_owned());
|
||||
}
|
||||
count_characters += snippet.len();
|
||||
}
|
||||
|
||||
let commented_lines: Vec<String> = lines
|
||||
.iter()
|
||||
.map(|x| format!("{} {}", comment_char, x))
|
||||
.map(|x| {
|
||||
if x.is_empty() {
|
||||
comment_char.to_string()
|
||||
} else {
|
||||
format!("{} {}", comment_char, x)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let comments = commented_lines.join("\n");
|
||||
format!("{}\n{}", comments, prefix)
|
||||
}
|
||||
|
||||
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> {
|
||||
fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<String> {
|
||||
let mut ret = Vec::new();
|
||||
let sanitized_text = sanitize_text(text);
|
||||
if sanitized_text.is_empty() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
let query_text = format!(
|
||||
"language:{} AND kind:call AND ({})",
|
||||
language, sanitized_text
|
||||
);
|
||||
let query = match index.query_parser.parse_query(&query_text) {
|
||||
Ok(query) => query,
|
||||
let query_text = format!("language:{} AND ({})", language, sanitized_text);
|
||||
|
||||
let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) {
|
||||
Ok(serp) => serp,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse query: {}", err);
|
||||
warn!("Failed to search query: {}", err);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
let top_docs = index
|
||||
.searcher
|
||||
.search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
|
||||
.unwrap();
|
||||
for hit in serp.hits {
|
||||
let body = hit.doc.body;
|
||||
|
||||
let mut names: HashMap<String, u32> = HashMap::new();
|
||||
for (_score, doc_address) in top_docs {
|
||||
let doc = index.searcher.doc(doc_address).unwrap();
|
||||
let name = doc
|
||||
.get_first(index.field_name)
|
||||
.and_then(|x| x.as_text())
|
||||
.unwrap();
|
||||
let count = *names.get(name).unwrap_or(&0);
|
||||
|
||||
// Max 1 snippet per identifier.
|
||||
if count >= MAX_SNIPPET_PER_NAME {
|
||||
if text.contains(&body) {
|
||||
// Exclude snippets already in the context window.
|
||||
continue;
|
||||
}
|
||||
|
||||
let body = doc
|
||||
.get_first(index.field_body)
|
||||
.and_then(|x| x.as_text())
|
||||
.unwrap();
|
||||
names.insert(name.to_owned(), count + 1);
|
||||
ret.push(body.to_owned());
|
||||
}
|
||||
|
||||
|
|
@ -172,41 +165,9 @@ fn sanitize_text(text: &str) -> String {
|
|||
tokens.join(" ")
|
||||
}
|
||||
|
||||
struct IndexState {
|
||||
searcher: Searcher,
|
||||
query_parser: QueryParser,
|
||||
field_name: Field,
|
||||
field_body: Field,
|
||||
}
|
||||
|
||||
impl IndexState {
|
||||
fn new() -> Result<IndexState> {
|
||||
let index = Index::open_in_dir(index_dir())?;
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(ReloadPolicy::OnCommit)
|
||||
.try_into()?;
|
||||
let field_name = index.schema().get_field("name")?;
|
||||
let field_body = index.schema().get_field("body")?;
|
||||
let query_parser = QueryParser::for_index(&index, vec![field_body]);
|
||||
Ok(Self {
|
||||
searcher: reader.searcher(),
|
||||
query_parser,
|
||||
field_name,
|
||||
field_body,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = HashMap::from([
|
||||
("python", "#"),
|
||||
("rust", "//"),
|
||||
("javascript-typescript", "//"),
|
||||
("go", "//"),
|
||||
("java", "//"),
|
||||
("lua", "--"),
|
||||
]);
|
||||
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
|
||||
HashMap::from([("python", "#"), ("rust", "//"),]);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -222,7 +183,7 @@ mod tests {
|
|||
};
|
||||
|
||||
// Init prompt builder with prompt rewrite disabled.
|
||||
PromptBuilder::new(prompt_template, false)
|
||||
PromptBuilder::new(prompt_template, None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -379,14 +340,19 @@ def this_is_prefix():\n";
|
|||
|
||||
let expected_built_prefix = "\
|
||||
# Below are some relevant python snippets found in the repository:
|
||||
#
|
||||
# == Snippet 1 ==
|
||||
# res_1 = invoke_function_1(n)
|
||||
#
|
||||
# == Snippet 2 ==
|
||||
# res_2 = invoke_function_2(n)
|
||||
#
|
||||
# == Snippet 3 ==
|
||||
# res_3 = invoke_function_3(n)
|
||||
#
|
||||
# == Snippet 4 ==
|
||||
# res_4 = invoke_function_4(n)
|
||||
#
|
||||
# == Snippet 5 ==
|
||||
# res_5 = invoke_function_5(n)
|
||||
'''
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ use tabby_common::{
|
|||
use tabby_download::Downloader;
|
||||
use tokio::time::sleep;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
use utoipa::{openapi::ServerBuilder, OpenApi};
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
|
|
@ -62,6 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
chat::ChatCompletionChunk,
|
||||
health::HealthState,
|
||||
health::Version,
|
||||
search::SearchResponse,
|
||||
search::Hit,
|
||||
search::HitDocument
|
||||
))
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
|
@ -92,10 +95,6 @@ pub struct ServeArgs {
|
|||
#[clap(long)]
|
||||
chat_model: Option<String>,
|
||||
|
||||
/// When set to `true`, the search API route will be enabled.
|
||||
#[clap(long, default_value_t = false)]
|
||||
enable_search: bool,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
|
|
@ -144,7 +143,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
doc.override_doc(args, &config.swagger);
|
||||
|
||||
let app = Router::new()
|
||||
.merge(api_router(args, config))
|
||||
.merge(api_router(args))
|
||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
|
||||
.fallback(fallback());
|
||||
|
||||
|
|
@ -165,7 +164,15 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
|||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||
}
|
||||
|
||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||
fn api_router(args: &ServeArgs) -> Router {
|
||||
let index_server = match IndexServer::load() {
|
||||
Ok(index_server) => Some(Arc::new(index_server)),
|
||||
Err(err) => {
|
||||
debug!("Load index failed due to `{}`", err);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let completion_state = {
|
||||
let (
|
||||
engine,
|
||||
|
|
@ -174,7 +181,11 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
},
|
||||
) = create_engine(&args.model, args);
|
||||
let engine = Arc::new(engine);
|
||||
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
|
||||
let state = completions::CompletionState::new(
|
||||
engine.clone(),
|
||||
index_server.clone(),
|
||||
prompt_template,
|
||||
);
|
||||
Arc::new(state)
|
||||
};
|
||||
|
||||
|
|
@ -201,19 +212,20 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
routing::post(completions::completions).with_state(completion_state),
|
||||
);
|
||||
|
||||
let router = if args.enable_search {
|
||||
let router = if let Some(chat_state) = chat_state {
|
||||
router.route(
|
||||
"/v1beta/search",
|
||||
routing::get(search::search).with_state(Arc::new(IndexServer::new())),
|
||||
"/v1beta/chat/completions",
|
||||
routing::post(chat::completions).with_state(chat_state),
|
||||
)
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
let router = if let Some(chat_state) = chat_state {
|
||||
let router = if let Some(index_server) = index_server {
|
||||
info!("Index is ready, enabling /v1beta/search API route");
|
||||
router.route(
|
||||
"/v1beta/chat/completions",
|
||||
routing::post(chat::completions).with_state(chat_state),
|
||||
"/v1beta/search",
|
||||
routing::get(search::search).with_state(index_server),
|
||||
)
|
||||
} else {
|
||||
router
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ use tabby_common::{index::IndexExt, path};
|
|||
use tantivy::{
|
||||
collector::{Count, TopDocs},
|
||||
query::QueryParser,
|
||||
schema::{Field, FieldType, NamedFieldDocument, Schema},
|
||||
DocAddress, Document, Index, IndexReader, Score,
|
||||
schema::Field,
|
||||
DocAddress, Document, Index, IndexReader,
|
||||
};
|
||||
use tracing::instrument;
|
||||
use utoipa::IntoParams;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
#[derive(Deserialize, IntoParams)]
|
||||
pub struct SearchQuery {
|
||||
|
|
@ -29,18 +29,27 @@ pub struct SearchQuery {
|
|||
offset: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub struct SearchResponse {
|
||||
q: String,
|
||||
num_hits: usize,
|
||||
hits: Vec<Hit>,
|
||||
pub num_hits: usize,
|
||||
pub hits: Vec<Hit>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub struct Hit {
|
||||
score: Score,
|
||||
doc: NamedFieldDocument,
|
||||
id: u32,
|
||||
pub score: f32,
|
||||
pub doc: HitDocument,
|
||||
pub id: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub struct HitDocument {
|
||||
pub body: String,
|
||||
pub filepath: String,
|
||||
pub git_url: String,
|
||||
pub kind: String,
|
||||
pub language: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
|
|
@ -50,7 +59,7 @@ pub struct Hit {
|
|||
operation_id = "search",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success" , content_type = "application/json"),
|
||||
(status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
|
||||
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
|
||||
)
|
||||
)]
|
||||
|
|
@ -73,40 +82,41 @@ pub async fn search(
|
|||
pub struct IndexServer {
|
||||
reader: IndexReader,
|
||||
query_parser: QueryParser,
|
||||
schema: Schema,
|
||||
|
||||
field_body: Field,
|
||||
field_filepath: Field,
|
||||
field_git_url: Field,
|
||||
field_kind: Field,
|
||||
field_language: Field,
|
||||
field_name: Field,
|
||||
}
|
||||
|
||||
impl IndexServer {
|
||||
pub fn new() -> Self {
|
||||
Self::load().expect("Failed to load code state")
|
||||
}
|
||||
|
||||
fn load() -> Result<Self> {
|
||||
pub fn load() -> Result<Self> {
|
||||
let index = Index::open_in_dir(path::index_dir())?;
|
||||
index.register_tokenizer();
|
||||
|
||||
let schema = index.schema();
|
||||
let default_fields: Vec<Field> = schema
|
||||
.fields()
|
||||
.filter(|&(_, field_entry)| match field_entry.field_type() {
|
||||
FieldType::Str(ref text_field_options) => {
|
||||
text_field_options.get_indexing_options().is_some()
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
.map(|(field, _)| field)
|
||||
.collect();
|
||||
let field_body = schema.get_field("body").unwrap();
|
||||
let query_parser =
|
||||
QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
|
||||
let reader = index.reader()?;
|
||||
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
|
||||
let reader = index
|
||||
.reader_builder()
|
||||
.reload_policy(tantivy::ReloadPolicy::OnCommit)
|
||||
.try_into()?;
|
||||
Ok(Self {
|
||||
reader,
|
||||
query_parser,
|
||||
schema,
|
||||
field_body,
|
||||
field_filepath: schema.get_field("filepath").unwrap(),
|
||||
field_git_url: schema.get_field("git_url").unwrap(),
|
||||
field_kind: schema.get_field("kind").unwrap(),
|
||||
field_language: schema.get_field("language").unwrap(),
|
||||
field_name: schema.get_field("name").unwrap(),
|
||||
})
|
||||
}
|
||||
|
||||
fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
|
||||
pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
|
||||
let query = self
|
||||
.query_parser
|
||||
.parse_query(q)
|
||||
|
|
@ -127,18 +137,28 @@ impl IndexServer {
|
|||
})
|
||||
.collect()
|
||||
};
|
||||
Ok(SearchResponse {
|
||||
q: q.to_owned(),
|
||||
num_hits,
|
||||
hits,
|
||||
})
|
||||
Ok(SearchResponse { num_hits, hits })
|
||||
}
|
||||
|
||||
fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit {
|
||||
fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit {
|
||||
Hit {
|
||||
score,
|
||||
doc: self.schema.to_named_doc(&doc),
|
||||
doc: HitDocument {
|
||||
body: get_field(&doc, self.field_body),
|
||||
filepath: get_field(&doc, self.field_filepath),
|
||||
git_url: get_field(&doc, self.field_git_url),
|
||||
kind: get_field(&doc, self.field_kind),
|
||||
name: get_field(&doc, self.field_name),
|
||||
language: get_field(&doc, self.field_language),
|
||||
},
|
||||
id: doc_address.doc_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_field(doc: &Document, field: Field) -> String {
|
||||
doc.get_first(field)
|
||||
.and_then(|x| x.as_text())
|
||||
.unwrap()
|
||||
.to_owned()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue