feat: connect prompt rewriting part (#517)
* feat: enable /v1beta/search if index is available * make prompt rewriting work * update * fix test * fix api docwsxiaoys-patch-1
parent
8497fb1372
commit
d85a7892d1
|
|
@ -13,19 +13,10 @@ pub struct Config {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub repositories: Vec<Repository>,
|
pub repositories: Vec<Repository>,
|
||||||
|
|
||||||
#[serde(default)]
|
|
||||||
pub experimental: Experimental,
|
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub swagger: SwaggerConfig,
|
pub swagger: SwaggerConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default)]
|
|
||||||
pub struct Experimental {
|
|
||||||
#[serde(default = "default_as_false")]
|
|
||||||
pub enable_prompt_rewrite: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default)]
|
#[derive(Serialize, Deserialize, Default)]
|
||||||
pub struct SwaggerConfig {
|
pub struct SwaggerConfig {
|
||||||
pub server_url: Option<String>,
|
pub server_url: Option<String>,
|
||||||
|
|
@ -64,10 +55,6 @@ impl Repository {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_as_false() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::Config;
|
use super::Config;
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ mod tests {
|
||||||
use std::fs::create_dir_all;
|
use std::fs::create_dir_all;
|
||||||
|
|
||||||
use tabby_common::{
|
use tabby_common::{
|
||||||
config::{Config, Experimental, Repository, SwaggerConfig},
|
config::{Config, Repository, SwaggerConfig},
|
||||||
path::set_tabby_root,
|
path::set_tabby_root,
|
||||||
};
|
};
|
||||||
use temp_testdir::*;
|
use temp_testdir::*;
|
||||||
|
|
@ -21,7 +21,6 @@ mod tests {
|
||||||
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
|
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
|
||||||
}],
|
}],
|
||||||
swagger: SwaggerConfig { server_url: None },
|
swagger: SwaggerConfig { server_url: None },
|
||||||
experimental: Experimental::default(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
config.save();
|
config.save();
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,13 @@ use std::sync::Arc;
|
||||||
use axum::{extract::State, Json};
|
use axum::{extract::State, Json};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tabby_common::{config::Config, events};
|
use tabby_common::events;
|
||||||
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use self::languages::get_stop_words;
|
use self::languages::get_stop_words;
|
||||||
|
use super::search::IndexServer;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
#[schema(example=json!({
|
#[schema(example=json!({
|
||||||
|
|
@ -127,15 +128,12 @@ pub struct CompletionState {
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
engine: Arc<Box<dyn TextGeneration>>,
|
engine: Arc<Box<dyn TextGeneration>>,
|
||||||
|
index_server: Option<Arc<IndexServer>>,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
config: &Config,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: prompt::PromptBuilder::new(
|
prompt_builder: prompt::PromptBuilder::new(prompt_template, index_server),
|
||||||
prompt_template,
|
|
||||||
config.experimental.enable_prompt_rewrite,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,32 @@
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, env, sync::Arc};
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use strfmt::strfmt;
|
use strfmt::strfmt;
|
||||||
use tabby_common::path::index_dir;
|
|
||||||
use tantivy::{
|
|
||||||
collector::TopDocs, query::QueryParser, schema::Field, Index, ReloadPolicy, Searcher,
|
|
||||||
};
|
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use super::Segments;
|
use super::Segments;
|
||||||
|
use crate::serve::search::IndexServer;
|
||||||
|
|
||||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||||
static MAX_SNIPPET_PER_NAME: u32 = 1;
|
|
||||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 512;
|
||||||
|
|
||||||
pub struct PromptBuilder {
|
pub struct PromptBuilder {
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
index: Option<IndexState>,
|
index_server: Option<Arc<IndexServer>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PromptBuilder {
|
impl PromptBuilder {
|
||||||
pub fn new(prompt_template: Option<String>, enable_prompt_rewrite: bool) -> Self {
|
pub fn new(prompt_template: Option<String>, index_server: Option<Arc<IndexServer>>) -> Self {
|
||||||
let index = if enable_prompt_rewrite {
|
let index_server = if env::var("TABBY_ENABLE_PROMPT_REWRITE").is_ok() {
|
||||||
info!("Experimental feature `enable_prompt_rewrite` is enabled, loading index ...");
|
info!("Prompt rewriting is enabled...");
|
||||||
let index = IndexState::new();
|
index_server
|
||||||
if let Err(err) = &index {
|
|
||||||
warn!("Failed to open index in {:?}: {:?}", index_dir(), err);
|
|
||||||
}
|
|
||||||
index.ok()
|
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
PromptBuilder {
|
PromptBuilder {
|
||||||
prompt_template,
|
prompt_template,
|
||||||
index,
|
index_server,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -53,8 +44,8 @@ impl PromptBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
|
fn rewrite(&self, language: &str, segments: Segments) -> Segments {
|
||||||
if let Some(index) = &self.index {
|
if let Some(index_server) = &self.index_server {
|
||||||
rewrite_with_index(index, language, segments)
|
rewrite_with_index(index_server, language, segments)
|
||||||
} else {
|
} else {
|
||||||
segments
|
segments
|
||||||
}
|
}
|
||||||
|
|
@ -74,8 +65,12 @@ fn get_default_suffix(suffix: Option<String>) -> String {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) -> Segments {
|
fn rewrite_with_index(
|
||||||
let snippets = collect_snippets(index, language, &segments.prefix);
|
index_server: &Arc<IndexServer>,
|
||||||
|
language: &str,
|
||||||
|
segments: Segments,
|
||||||
|
) -> Segments {
|
||||||
|
let snippets = collect_snippets(index_server, language, &segments.prefix);
|
||||||
if snippets.is_empty() {
|
if snippets.is_empty() {
|
||||||
segments
|
segments
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -85,11 +80,18 @@ fn rewrite_with_index(index: &IndexState, language: &str, segments: Segments) ->
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
||||||
|
if snippets.is_empty() {
|
||||||
|
return prefix.to_owned();
|
||||||
|
}
|
||||||
|
|
||||||
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
let comment_char = LANGUAGE_LINE_COMMENT_CHAR.get(language).unwrap();
|
||||||
let mut lines: Vec<String> = vec![format!(
|
let mut lines: Vec<String> = vec![
|
||||||
"Below are some relevant {} snippets found in the repository:",
|
format!(
|
||||||
language
|
"Below are some relevant {} snippets found in the repository:",
|
||||||
)];
|
language
|
||||||
|
),
|
||||||
|
"".to_owned(),
|
||||||
|
];
|
||||||
|
|
||||||
let mut count_characters = 0;
|
let mut count_characters = 0;
|
||||||
for (i, snippet) in snippets.iter().enumerate() {
|
for (i, snippet) in snippets.iter().enumerate() {
|
||||||
|
|
@ -102,60 +104,51 @@ fn build_prefix(language: &str, prefix: &str, snippets: Vec<String>) -> String {
|
||||||
lines.push(line.to_owned());
|
lines.push(line.to_owned());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if i < snippets.len() - 1 {
|
||||||
|
lines.push("".to_owned());
|
||||||
|
}
|
||||||
count_characters += snippet.len();
|
count_characters += snippet.len();
|
||||||
}
|
}
|
||||||
|
|
||||||
let commented_lines: Vec<String> = lines
|
let commented_lines: Vec<String> = lines
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| format!("{} {}", comment_char, x))
|
.map(|x| {
|
||||||
|
if x.is_empty() {
|
||||||
|
comment_char.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{} {}", comment_char, x)
|
||||||
|
}
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let comments = commented_lines.join("\n");
|
let comments = commented_lines.join("\n");
|
||||||
format!("{}\n{}", comments, prefix)
|
format!("{}\n{}", comments, prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn collect_snippets(index: &IndexState, language: &str, text: &str) -> Vec<String> {
|
fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<String> {
|
||||||
let mut ret = Vec::new();
|
let mut ret = Vec::new();
|
||||||
let sanitized_text = sanitize_text(text);
|
let sanitized_text = sanitize_text(text);
|
||||||
if sanitized_text.is_empty() {
|
if sanitized_text.is_empty() {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
let query_text = format!(
|
let query_text = format!("language:{} AND ({})", language, sanitized_text);
|
||||||
"language:{} AND kind:call AND ({})",
|
|
||||||
language, sanitized_text
|
let serp = match index_server.search(&query_text, MAX_SNIPPETS_TO_FETCH, 0) {
|
||||||
);
|
Ok(serp) => serp,
|
||||||
let query = match index.query_parser.parse_query(&query_text) {
|
|
||||||
Ok(query) => query,
|
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Failed to parse query: {}", err);
|
warn!("Failed to search query: {}", err);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let top_docs = index
|
for hit in serp.hits {
|
||||||
.searcher
|
let body = hit.doc.body;
|
||||||
.search(&query, &TopDocs::with_limit(MAX_SNIPPETS_TO_FETCH))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let mut names: HashMap<String, u32> = HashMap::new();
|
if text.contains(&body) {
|
||||||
for (_score, doc_address) in top_docs {
|
// Exclude snippets already in the context window.
|
||||||
let doc = index.searcher.doc(doc_address).unwrap();
|
|
||||||
let name = doc
|
|
||||||
.get_first(index.field_name)
|
|
||||||
.and_then(|x| x.as_text())
|
|
||||||
.unwrap();
|
|
||||||
let count = *names.get(name).unwrap_or(&0);
|
|
||||||
|
|
||||||
// Max 1 snippet per identifier.
|
|
||||||
if count >= MAX_SNIPPET_PER_NAME {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let body = doc
|
|
||||||
.get_first(index.field_body)
|
|
||||||
.and_then(|x| x.as_text())
|
|
||||||
.unwrap();
|
|
||||||
names.insert(name.to_owned(), count + 1);
|
|
||||||
ret.push(body.to_owned());
|
ret.push(body.to_owned());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -172,41 +165,9 @@ fn sanitize_text(text: &str) -> String {
|
||||||
tokens.join(" ")
|
tokens.join(" ")
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IndexState {
|
|
||||||
searcher: Searcher,
|
|
||||||
query_parser: QueryParser,
|
|
||||||
field_name: Field,
|
|
||||||
field_body: Field,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexState {
|
|
||||||
fn new() -> Result<IndexState> {
|
|
||||||
let index = Index::open_in_dir(index_dir())?;
|
|
||||||
let reader = index
|
|
||||||
.reader_builder()
|
|
||||||
.reload_policy(ReloadPolicy::OnCommit)
|
|
||||||
.try_into()?;
|
|
||||||
let field_name = index.schema().get_field("name")?;
|
|
||||||
let field_body = index.schema().get_field("body")?;
|
|
||||||
let query_parser = QueryParser::for_index(&index, vec![field_body]);
|
|
||||||
Ok(Self {
|
|
||||||
searcher: reader.searcher(),
|
|
||||||
query_parser,
|
|
||||||
field_name,
|
|
||||||
field_body,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lazy_static! {
|
lazy_static! {
|
||||||
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> = HashMap::from([
|
static ref LANGUAGE_LINE_COMMENT_CHAR: HashMap<&'static str, &'static str> =
|
||||||
("python", "#"),
|
HashMap::from([("python", "#"), ("rust", "//"),]);
|
||||||
("rust", "//"),
|
|
||||||
("javascript-typescript", "//"),
|
|
||||||
("go", "//"),
|
|
||||||
("java", "//"),
|
|
||||||
("lua", "--"),
|
|
||||||
]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -222,7 +183,7 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Init prompt builder with prompt rewrite disabled.
|
// Init prompt builder with prompt rewrite disabled.
|
||||||
PromptBuilder::new(prompt_template, false)
|
PromptBuilder::new(prompt_template, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -379,14 +340,19 @@ def this_is_prefix():\n";
|
||||||
|
|
||||||
let expected_built_prefix = "\
|
let expected_built_prefix = "\
|
||||||
# Below are some relevant python snippets found in the repository:
|
# Below are some relevant python snippets found in the repository:
|
||||||
|
#
|
||||||
# == Snippet 1 ==
|
# == Snippet 1 ==
|
||||||
# res_1 = invoke_function_1(n)
|
# res_1 = invoke_function_1(n)
|
||||||
|
#
|
||||||
# == Snippet 2 ==
|
# == Snippet 2 ==
|
||||||
# res_2 = invoke_function_2(n)
|
# res_2 = invoke_function_2(n)
|
||||||
|
#
|
||||||
# == Snippet 3 ==
|
# == Snippet 3 ==
|
||||||
# res_3 = invoke_function_3(n)
|
# res_3 = invoke_function_3(n)
|
||||||
|
#
|
||||||
# == Snippet 4 ==
|
# == Snippet 4 ==
|
||||||
# res_4 = invoke_function_4(n)
|
# res_4 = invoke_function_4(n)
|
||||||
|
#
|
||||||
# == Snippet 5 ==
|
# == Snippet 5 ==
|
||||||
# res_5 = invoke_function_5(n)
|
# res_5 = invoke_function_5(n)
|
||||||
'''
|
'''
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ use tabby_common::{
|
||||||
use tabby_download::Downloader;
|
use tabby_download::Downloader;
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
use tracing::{info, warn};
|
use tracing::{debug, info, warn};
|
||||||
use utoipa::{openapi::ServerBuilder, OpenApi};
|
use utoipa::{openapi::ServerBuilder, OpenApi};
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
|
@ -62,6 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
chat::ChatCompletionChunk,
|
chat::ChatCompletionChunk,
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
health::Version,
|
health::Version,
|
||||||
|
search::SearchResponse,
|
||||||
|
search::Hit,
|
||||||
|
search::HitDocument
|
||||||
))
|
))
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
@ -92,10 +95,6 @@ pub struct ServeArgs {
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
chat_model: Option<String>,
|
chat_model: Option<String>,
|
||||||
|
|
||||||
/// When set to `true`, the search API route will be enabled.
|
|
||||||
#[clap(long, default_value_t = false)]
|
|
||||||
enable_search: bool,
|
|
||||||
|
|
||||||
#[clap(long, default_value_t = 8080)]
|
#[clap(long, default_value_t = 8080)]
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
|
|
@ -144,7 +143,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
doc.override_doc(args, &config.swagger);
|
doc.override_doc(args, &config.swagger);
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(api_router(args, config))
|
.merge(api_router(args))
|
||||||
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
|
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
|
||||||
.fallback(fallback());
|
.fallback(fallback());
|
||||||
|
|
||||||
|
|
@ -165,7 +164,15 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
fn api_router(args: &ServeArgs) -> Router {
|
||||||
|
let index_server = match IndexServer::load() {
|
||||||
|
Ok(index_server) => Some(Arc::new(index_server)),
|
||||||
|
Err(err) => {
|
||||||
|
debug!("Load index failed due to `{}`", err);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let completion_state = {
|
let completion_state = {
|
||||||
let (
|
let (
|
||||||
engine,
|
engine,
|
||||||
|
|
@ -174,7 +181,11 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
},
|
},
|
||||||
) = create_engine(&args.model, args);
|
) = create_engine(&args.model, args);
|
||||||
let engine = Arc::new(engine);
|
let engine = Arc::new(engine);
|
||||||
let state = completions::CompletionState::new(engine.clone(), prompt_template, config);
|
let state = completions::CompletionState::new(
|
||||||
|
engine.clone(),
|
||||||
|
index_server.clone(),
|
||||||
|
prompt_template,
|
||||||
|
);
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -201,19 +212,20 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
routing::post(completions::completions).with_state(completion_state),
|
routing::post(completions::completions).with_state(completion_state),
|
||||||
);
|
);
|
||||||
|
|
||||||
let router = if args.enable_search {
|
let router = if let Some(chat_state) = chat_state {
|
||||||
router.route(
|
router.route(
|
||||||
"/v1beta/search",
|
"/v1beta/chat/completions",
|
||||||
routing::get(search::search).with_state(Arc::new(IndexServer::new())),
|
routing::post(chat::completions).with_state(chat_state),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
router
|
router
|
||||||
};
|
};
|
||||||
|
|
||||||
let router = if let Some(chat_state) = chat_state {
|
let router = if let Some(index_server) = index_server {
|
||||||
|
info!("Index is ready, enabling /v1beta/search API route");
|
||||||
router.route(
|
router.route(
|
||||||
"/v1beta/chat/completions",
|
"/v1beta/search",
|
||||||
routing::post(chat::completions).with_state(chat_state),
|
routing::get(search::search).with_state(index_server),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
router
|
router
|
||||||
|
|
|
||||||
|
|
@ -11,11 +11,11 @@ use tabby_common::{index::IndexExt, path};
|
||||||
use tantivy::{
|
use tantivy::{
|
||||||
collector::{Count, TopDocs},
|
collector::{Count, TopDocs},
|
||||||
query::QueryParser,
|
query::QueryParser,
|
||||||
schema::{Field, FieldType, NamedFieldDocument, Schema},
|
schema::Field,
|
||||||
DocAddress, Document, Index, IndexReader, Score,
|
DocAddress, Document, Index, IndexReader,
|
||||||
};
|
};
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
use utoipa::IntoParams;
|
use utoipa::{IntoParams, ToSchema};
|
||||||
|
|
||||||
#[derive(Deserialize, IntoParams)]
|
#[derive(Deserialize, IntoParams)]
|
||||||
pub struct SearchQuery {
|
pub struct SearchQuery {
|
||||||
|
|
@ -29,18 +29,27 @@ pub struct SearchQuery {
|
||||||
offset: Option<usize>,
|
offset: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub struct SearchResponse {
|
pub struct SearchResponse {
|
||||||
q: String,
|
pub num_hits: usize,
|
||||||
num_hits: usize,
|
pub hits: Vec<Hit>,
|
||||||
hits: Vec<Hit>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize, ToSchema)]
|
||||||
pub struct Hit {
|
pub struct Hit {
|
||||||
score: Score,
|
pub score: f32,
|
||||||
doc: NamedFieldDocument,
|
pub doc: HitDocument,
|
||||||
id: u32,
|
pub id: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct HitDocument {
|
||||||
|
pub body: String,
|
||||||
|
pub filepath: String,
|
||||||
|
pub git_url: String,
|
||||||
|
pub kind: String,
|
||||||
|
pub language: String,
|
||||||
|
pub name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
|
|
@ -50,7 +59,7 @@ pub struct Hit {
|
||||||
operation_id = "search",
|
operation_id = "search",
|
||||||
tag = "v1beta",
|
tag = "v1beta",
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Success" , content_type = "application/json"),
|
(status = 200, description = "Success" , body = SearchResponse, content_type = "application/json"),
|
||||||
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
|
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
|
|
@ -73,40 +82,41 @@ pub async fn search(
|
||||||
pub struct IndexServer {
|
pub struct IndexServer {
|
||||||
reader: IndexReader,
|
reader: IndexReader,
|
||||||
query_parser: QueryParser,
|
query_parser: QueryParser,
|
||||||
schema: Schema,
|
|
||||||
|
field_body: Field,
|
||||||
|
field_filepath: Field,
|
||||||
|
field_git_url: Field,
|
||||||
|
field_kind: Field,
|
||||||
|
field_language: Field,
|
||||||
|
field_name: Field,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IndexServer {
|
impl IndexServer {
|
||||||
pub fn new() -> Self {
|
pub fn load() -> Result<Self> {
|
||||||
Self::load().expect("Failed to load code state")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load() -> Result<Self> {
|
|
||||||
let index = Index::open_in_dir(path::index_dir())?;
|
let index = Index::open_in_dir(path::index_dir())?;
|
||||||
index.register_tokenizer();
|
index.register_tokenizer();
|
||||||
|
|
||||||
let schema = index.schema();
|
let schema = index.schema();
|
||||||
let default_fields: Vec<Field> = schema
|
let field_body = schema.get_field("body").unwrap();
|
||||||
.fields()
|
|
||||||
.filter(|&(_, field_entry)| match field_entry.field_type() {
|
|
||||||
FieldType::Str(ref text_field_options) => {
|
|
||||||
text_field_options.get_indexing_options().is_some()
|
|
||||||
}
|
|
||||||
_ => false,
|
|
||||||
})
|
|
||||||
.map(|(field, _)| field)
|
|
||||||
.collect();
|
|
||||||
let query_parser =
|
let query_parser =
|
||||||
QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
|
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
|
||||||
let reader = index.reader()?;
|
let reader = index
|
||||||
|
.reader_builder()
|
||||||
|
.reload_policy(tantivy::ReloadPolicy::OnCommit)
|
||||||
|
.try_into()?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
reader,
|
reader,
|
||||||
query_parser,
|
query_parser,
|
||||||
schema,
|
field_body,
|
||||||
|
field_filepath: schema.get_field("filepath").unwrap(),
|
||||||
|
field_git_url: schema.get_field("git_url").unwrap(),
|
||||||
|
field_kind: schema.get_field("kind").unwrap(),
|
||||||
|
field_language: schema.get_field("language").unwrap(),
|
||||||
|
field_name: schema.get_field("name").unwrap(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
|
pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
|
||||||
let query = self
|
let query = self
|
||||||
.query_parser
|
.query_parser
|
||||||
.parse_query(q)
|
.parse_query(q)
|
||||||
|
|
@ -127,18 +137,28 @@ impl IndexServer {
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
};
|
};
|
||||||
Ok(SearchResponse {
|
Ok(SearchResponse { num_hits, hits })
|
||||||
q: q.to_owned(),
|
|
||||||
num_hits,
|
|
||||||
hits,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit {
|
fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit {
|
||||||
Hit {
|
Hit {
|
||||||
score,
|
score,
|
||||||
doc: self.schema.to_named_doc(&doc),
|
doc: HitDocument {
|
||||||
|
body: get_field(&doc, self.field_body),
|
||||||
|
filepath: get_field(&doc, self.field_filepath),
|
||||||
|
git_url: get_field(&doc, self.field_git_url),
|
||||||
|
kind: get_field(&doc, self.field_kind),
|
||||||
|
name: get_field(&doc, self.field_name),
|
||||||
|
language: get_field(&doc, self.field_language),
|
||||||
|
},
|
||||||
id: doc_address.doc_id,
|
id: doc_address.doc_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_field(doc: &Document, field: Field) -> String {
|
||||||
|
doc.get_first(field)
|
||||||
|
.and_then(|x| x.as_text())
|
||||||
|
.unwrap()
|
||||||
|
.to_owned()
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue