refactor: extract IndexServer into CodeSearchService (#728)
* refactor: extract IndexServer into CodeSearchService * refactor: make CodeSearchService interface to be asyncrefactor-extract-code
parent
8ab35b2639
commit
72d1d9f0bb
|
|
@ -1,4 +1,5 @@
|
||||||
mod download;
|
mod download;
|
||||||
|
mod search;
|
||||||
mod serve;
|
mod serve;
|
||||||
|
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,263 @@
|
||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use axum::async_trait;
|
||||||
|
use serde::Serialize;
|
||||||
|
use tabby_common::{index::IndexExt, path};
|
||||||
|
use tantivy::{
|
||||||
|
collector::{Count, TopDocs},
|
||||||
|
query::{QueryParser, TermQuery, TermSetQuery},
|
||||||
|
schema::{Field, IndexRecordOption},
|
||||||
|
DocAddress, Document, Index, IndexReader, Term,
|
||||||
|
};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::{sync::Mutex, time::sleep};
|
||||||
|
use tracing::{debug, log::info};
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct SearchResponse {
|
||||||
|
pub num_hits: usize,
|
||||||
|
pub hits: Vec<Hit>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct Hit {
|
||||||
|
pub score: f32,
|
||||||
|
pub doc: HitDocument,
|
||||||
|
pub id: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct HitDocument {
|
||||||
|
pub body: String,
|
||||||
|
pub filepath: String,
|
||||||
|
pub git_url: String,
|
||||||
|
pub kind: String,
|
||||||
|
pub language: String,
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum CodeSearchError {
|
||||||
|
#[error("index not ready")]
|
||||||
|
NotReady,
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
QueryParserError(#[from] tantivy::query::QueryParserError),
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
TantivyError(#[from] tantivy::TantivyError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait CodeSearch {
|
||||||
|
async fn search(
|
||||||
|
&self,
|
||||||
|
q: &str,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError>;
|
||||||
|
|
||||||
|
async fn search_with_query(
|
||||||
|
&self,
|
||||||
|
q: &dyn tantivy::query::Query,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CodeSearchImpl {
|
||||||
|
reader: IndexReader,
|
||||||
|
query_parser: QueryParser,
|
||||||
|
|
||||||
|
field_body: Field,
|
||||||
|
field_filepath: Field,
|
||||||
|
field_git_url: Field,
|
||||||
|
field_kind: Field,
|
||||||
|
field_language: Field,
|
||||||
|
field_name: Field,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CodeSearchImpl {
|
||||||
|
fn load() -> Result<Self> {
|
||||||
|
let index = Index::open_in_dir(path::index_dir())?;
|
||||||
|
index.register_tokenizer();
|
||||||
|
|
||||||
|
let schema = index.schema();
|
||||||
|
let field_body = schema.get_field("body").unwrap();
|
||||||
|
let query_parser =
|
||||||
|
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
|
||||||
|
let reader = index
|
||||||
|
.reader_builder()
|
||||||
|
.reload_policy(tantivy::ReloadPolicy::OnCommit)
|
||||||
|
.try_into()?;
|
||||||
|
Ok(Self {
|
||||||
|
reader,
|
||||||
|
query_parser,
|
||||||
|
field_body,
|
||||||
|
field_filepath: schema.get_field("filepath").unwrap(),
|
||||||
|
field_git_url: schema.get_field("git_url").unwrap(),
|
||||||
|
field_kind: schema.get_field("kind").unwrap(),
|
||||||
|
field_language: schema.get_field("language").unwrap(),
|
||||||
|
field_name: schema.get_field("name").unwrap(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_async() -> CodeSearchImpl {
|
||||||
|
loop {
|
||||||
|
match CodeSearchImpl::load() {
|
||||||
|
Ok(code) => {
|
||||||
|
info!("Index is ready, enabling server...");
|
||||||
|
return code;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
debug!("Source code index is not ready `{}`", err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
sleep(Duration::from_secs(60)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit {
|
||||||
|
Hit {
|
||||||
|
score,
|
||||||
|
doc: HitDocument {
|
||||||
|
body: get_field(&doc, self.field_body),
|
||||||
|
filepath: get_field(&doc, self.field_filepath),
|
||||||
|
git_url: get_field(&doc, self.field_git_url),
|
||||||
|
kind: get_field(&doc, self.field_kind),
|
||||||
|
name: get_field(&doc, self.field_name),
|
||||||
|
language: get_field(&doc, self.field_language),
|
||||||
|
},
|
||||||
|
id: doc_address.doc_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CodeSearch for CodeSearchImpl {
|
||||||
|
async fn search(
|
||||||
|
&self,
|
||||||
|
q: &str,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError> {
|
||||||
|
let query = self.query_parser.parse_query(q)?;
|
||||||
|
self.search_with_query(&query, limit, offset).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn search_with_query(
|
||||||
|
&self,
|
||||||
|
q: &dyn tantivy::query::Query,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError> {
|
||||||
|
let searcher = self.reader.searcher();
|
||||||
|
let (top_docs, num_hits) =
|
||||||
|
{ searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? };
|
||||||
|
let hits: Vec<Hit> = {
|
||||||
|
top_docs
|
||||||
|
.iter()
|
||||||
|
.map(|(score, doc_address)| {
|
||||||
|
let doc = searcher.doc(*doc_address).unwrap();
|
||||||
|
self.create_hit(*score, doc, *doc_address)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
Ok(SearchResponse { num_hits, hits })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_field(doc: &Document, field: Field) -> String {
|
||||||
|
doc.get_first(field)
|
||||||
|
.and_then(|x| x.as_text())
|
||||||
|
.unwrap()
|
||||||
|
.to_owned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CodeSearchService {
|
||||||
|
search: Arc<Mutex<Option<CodeSearchImpl>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CodeSearchService {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let search = Arc::new(Mutex::new(None));
|
||||||
|
|
||||||
|
let ret = Self {
|
||||||
|
search: search.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let code = CodeSearchImpl::load_async().await;
|
||||||
|
*search.lock().await = Some(code);
|
||||||
|
});
|
||||||
|
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn with_impl<T, F>(&self, op: F) -> Result<T, CodeSearchError>
|
||||||
|
where
|
||||||
|
F: FnOnce(&CodeSearchImpl) -> Result<T, CodeSearchError>,
|
||||||
|
{
|
||||||
|
if let Some(imp) = self.search.lock().await.as_ref() {
|
||||||
|
op(imp)
|
||||||
|
} else {
|
||||||
|
Err(CodeSearchError::NotReady)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn language_query(&self, language: &str) -> Result<Box<TermQuery>, CodeSearchError> {
|
||||||
|
self.with_impl(|imp| {
|
||||||
|
Ok(Box::new(TermQuery::new(
|
||||||
|
Term::from_field_text(imp.field_language, language),
|
||||||
|
IndexRecordOption::WithFreqsAndPositions,
|
||||||
|
)))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn body_query(
|
||||||
|
&self,
|
||||||
|
tokens: &[String],
|
||||||
|
) -> Result<Box<TermSetQuery>, CodeSearchError> {
|
||||||
|
self.with_impl(|imp| {
|
||||||
|
Ok(Box::new(TermSetQuery::new(
|
||||||
|
tokens
|
||||||
|
.iter()
|
||||||
|
.map(|x| Term::from_field_text(imp.field_body, x)),
|
||||||
|
)))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl CodeSearch for CodeSearchService {
|
||||||
|
async fn search(
|
||||||
|
&self,
|
||||||
|
q: &str,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError> {
|
||||||
|
if let Some(imp) = self.search.lock().await.as_ref() {
|
||||||
|
imp.search(q, limit, offset).await
|
||||||
|
} else {
|
||||||
|
Err(CodeSearchError::NotReady)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn search_with_query(
|
||||||
|
&self,
|
||||||
|
q: &dyn tantivy::query::Query,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError> {
|
||||||
|
if let Some(imp) = self.search.lock().await.as_ref() {
|
||||||
|
imp.search_with_query(q, limit, offset).await
|
||||||
|
} else {
|
||||||
|
Err(CodeSearchError::NotReady)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -10,7 +10,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
use super::search::IndexServer;
|
use crate::search::CodeSearchService;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
|
||||||
#[schema(example=json!({
|
#[schema(example=json!({
|
||||||
|
|
@ -137,7 +137,8 @@ pub async fn completions(
|
||||||
(prompt, None, vec![])
|
(prompt, None, vec![])
|
||||||
} else if let Some(segments) = request.segments {
|
} else if let Some(segments) = request.segments {
|
||||||
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
|
||||||
let (prompt, snippets) = build_prompt(&state, &request.debug_options, &language, &segments);
|
let (prompt, snippets) =
|
||||||
|
build_prompt(&state, &request.debug_options, &language, &segments).await;
|
||||||
(prompt, Some(segments), snippets)
|
(prompt, Some(segments), snippets)
|
||||||
} else {
|
} else {
|
||||||
return Err(StatusCode::BAD_REQUEST);
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
|
|
@ -180,7 +181,7 @@ pub async fn completions(
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_prompt(
|
async fn build_prompt(
|
||||||
state: &Arc<CompletionState>,
|
state: &Arc<CompletionState>,
|
||||||
debug_options: &Option<DebugOptions>,
|
debug_options: &Option<DebugOptions>,
|
||||||
language: &str,
|
language: &str,
|
||||||
|
|
@ -190,7 +191,7 @@ fn build_prompt(
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
|
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
|
||||||
{
|
{
|
||||||
state.prompt_builder.collect(language, segments)
|
state.prompt_builder.collect(language, segments).await
|
||||||
} else {
|
} else {
|
||||||
vec![]
|
vec![]
|
||||||
};
|
};
|
||||||
|
|
@ -210,12 +211,12 @@ pub struct CompletionState {
|
||||||
impl CompletionState {
|
impl CompletionState {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
engine: Arc<Box<dyn TextGeneration>>,
|
engine: Arc<Box<dyn TextGeneration>>,
|
||||||
index_server: Arc<IndexServer>,
|
code: Arc<CodeSearchService>,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
engine,
|
engine,
|
||||||
prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(index_server)),
|
prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(code)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ use textdistance::Algorithm;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
use super::{Segments, Snippet};
|
use super::{Segments, Snippet};
|
||||||
use crate::serve::search::{IndexServer, IndexServerError};
|
use crate::search::{CodeSearch, CodeSearchError, CodeSearchService};
|
||||||
|
|
||||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
||||||
|
|
@ -17,14 +17,14 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
|
||||||
|
|
||||||
pub struct PromptBuilder {
|
pub struct PromptBuilder {
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
index_server: Option<Arc<IndexServer>>,
|
code: Option<Arc<CodeSearchService>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PromptBuilder {
|
impl PromptBuilder {
|
||||||
pub fn new(prompt_template: Option<String>, index_server: Option<Arc<IndexServer>>) -> Self {
|
pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self {
|
||||||
PromptBuilder {
|
PromptBuilder {
|
||||||
prompt_template,
|
prompt_template,
|
||||||
index_server,
|
code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -36,9 +36,9 @@ impl PromptBuilder {
|
||||||
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
|
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
|
||||||
if let Some(index_server) = &self.index_server {
|
if let Some(code) = &self.code {
|
||||||
collect_snippets(index_server, language, &segments.prefix)
|
collect_snippets(code, language, &segments.prefix).await
|
||||||
} else {
|
} else {
|
||||||
vec![]
|
vec![]
|
||||||
}
|
}
|
||||||
|
|
@ -105,14 +105,14 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
|
||||||
format!("{}\n{}", comments, prefix)
|
format!("{}\n{}", comments, prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec<Snippet> {
|
async fn collect_snippets(code: &CodeSearchService, language: &str, text: &str) -> Vec<Snippet> {
|
||||||
let mut ret = Vec::new();
|
let mut ret = Vec::new();
|
||||||
let mut tokens = tokenize_text(text);
|
let mut tokens = tokenize_text(text);
|
||||||
|
|
||||||
let Ok(language_query) = index_server.language_query(language) else {
|
let Ok(language_query) = code.language_query(language).await else {
|
||||||
return vec![];
|
return vec![];
|
||||||
};
|
};
|
||||||
let Ok(body_query) = index_server.body_query(&tokens) else {
|
let Ok(body_query) = code.body_query(&tokens).await else {
|
||||||
return vec![];
|
return vec![];
|
||||||
};
|
};
|
||||||
let query = BooleanQuery::new(vec![
|
let query = BooleanQuery::new(vec![
|
||||||
|
|
@ -120,14 +120,21 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V
|
||||||
(Occur::Must, body_query),
|
(Occur::Must, body_query),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
let serp = match index_server.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) {
|
let serp = match code
|
||||||
|
.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(serp) => serp,
|
Ok(serp) => serp,
|
||||||
Err(IndexServerError::NotReady) => {
|
Err(CodeSearchError::NotReady) => {
|
||||||
// Ignore.
|
// Ignore.
|
||||||
return vec![];
|
return vec![];
|
||||||
}
|
}
|
||||||
Err(IndexServerError::TantivyError(err)) => {
|
Err(CodeSearchError::TantivyError(err)) => {
|
||||||
warn!("Failed to search query: {}", err);
|
warn!("Failed to search: {}", err);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
Err(CodeSearchError::QueryParserError(err)) => {
|
||||||
|
warn!("Failed to parse query: {}", err);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,8 @@ use utoipa_swagger_ui::SwaggerUi;
|
||||||
use self::{
|
use self::{
|
||||||
engine::{create_engine, EngineInfo},
|
engine::{create_engine, EngineInfo},
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
search::IndexServer,
|
|
||||||
};
|
};
|
||||||
use crate::fatal;
|
use crate::{fatal, search::CodeSearchService};
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
#[openapi(
|
#[openapi(
|
||||||
|
|
@ -63,9 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
||||||
chat::ChatCompletionChunk,
|
chat::ChatCompletionChunk,
|
||||||
health::HealthState,
|
health::HealthState,
|
||||||
health::Version,
|
health::Version,
|
||||||
search::SearchResponse,
|
crate::search::SearchResponse,
|
||||||
search::Hit,
|
crate::search::Hit,
|
||||||
search::HitDocument
|
crate::search::HitDocument
|
||||||
))
|
))
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
@ -170,7 +169,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
let index_server = Arc::new(IndexServer::new());
|
let code = Arc::new(CodeSearchService::new());
|
||||||
let completion_state = {
|
let completion_state = {
|
||||||
let (
|
let (
|
||||||
engine,
|
engine,
|
||||||
|
|
@ -179,11 +178,8 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
},
|
},
|
||||||
) = create_engine(&args.model, args).await;
|
) = create_engine(&args.model, args).await;
|
||||||
let engine = Arc::new(engine);
|
let engine = Arc::new(engine);
|
||||||
let state = completions::CompletionState::new(
|
let state =
|
||||||
engine.clone(),
|
completions::CompletionState::new(engine.clone(), code.clone(), prompt_template);
|
||||||
index_server.clone(),
|
|
||||||
prompt_template,
|
|
||||||
);
|
|
||||||
Arc::new(state)
|
Arc::new(state)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -238,7 +234,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
||||||
routers.push({
|
routers.push({
|
||||||
Router::new().route(
|
Router::new().route(
|
||||||
"/v1beta/search",
|
"/v1beta/search",
|
||||||
routing::get(search::search).with_state(index_server),
|
routing::get(search::search).with_state(code),
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|
@ -6,18 +6,11 @@ use axum::{
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::Deserialize;
|
||||||
use tabby_common::{index::IndexExt, path};
|
use tracing::{instrument, warn};
|
||||||
use tantivy::{
|
use utoipa::IntoParams;
|
||||||
collector::{Count, TopDocs},
|
|
||||||
query::{QueryParser, TermQuery, TermSetQuery},
|
use crate::search::{CodeSearch, CodeSearchError, CodeSearchService, SearchResponse};
|
||||||
schema::{Field, IndexRecordOption},
|
|
||||||
DocAddress, Document, Index, IndexReader, Term,
|
|
||||||
};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::{sync::OnceCell, task, time::sleep};
|
|
||||||
use tracing::{debug, instrument, log::info, warn};
|
|
||||||
use utoipa::{IntoParams, ToSchema};
|
|
||||||
|
|
||||||
#[derive(Deserialize, IntoParams)]
|
#[derive(Deserialize, IntoParams)]
|
||||||
pub struct SearchQuery {
|
pub struct SearchQuery {
|
||||||
|
|
@ -31,29 +24,6 @@ pub struct SearchQuery {
|
||||||
offset: Option<usize>,
|
offset: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub struct SearchResponse {
|
|
||||||
pub num_hits: usize,
|
|
||||||
pub hits: Vec<Hit>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub struct Hit {
|
|
||||||
pub score: f32,
|
|
||||||
pub doc: HitDocument,
|
|
||||||
pub id: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub struct HitDocument {
|
|
||||||
pub body: String,
|
|
||||||
pub filepath: String,
|
|
||||||
pub git_url: String,
|
|
||||||
pub kind: String,
|
|
||||||
pub language: String,
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
get,
|
get,
|
||||||
params(SearchQuery),
|
params(SearchQuery),
|
||||||
|
|
@ -67,193 +37,26 @@ pub struct HitDocument {
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(state, query))]
|
#[instrument(skip(state, query))]
|
||||||
pub async fn search(
|
pub async fn search(
|
||||||
State(state): State<Arc<IndexServer>>,
|
State(state): State<Arc<CodeSearchService>>,
|
||||||
query: Query<SearchQuery>,
|
query: Query<SearchQuery>,
|
||||||
) -> Result<Json<SearchResponse>, StatusCode> {
|
) -> Result<Json<SearchResponse>, StatusCode> {
|
||||||
match state.search(
|
match state
|
||||||
&query.q,
|
.search(
|
||||||
query.limit.unwrap_or(20),
|
&query.q,
|
||||||
query.offset.unwrap_or(0),
|
query.limit.unwrap_or(20),
|
||||||
) {
|
query.offset.unwrap_or(0),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(serp) => Ok(Json(serp)),
|
Ok(serp) => Ok(Json(serp)),
|
||||||
Err(IndexServerError::NotReady) => Err(StatusCode::NOT_IMPLEMENTED),
|
Err(CodeSearchError::NotReady) => Err(StatusCode::NOT_IMPLEMENTED),
|
||||||
Err(IndexServerError::TantivyError(err)) => {
|
Err(CodeSearchError::TantivyError(err)) => {
|
||||||
warn!("{}", err);
|
warn!("{}", err);
|
||||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
}
|
}
|
||||||
}
|
Err(CodeSearchError::QueryParserError(err)) => {
|
||||||
}
|
warn!("{}", err);
|
||||||
|
Err(StatusCode::BAD_REQUEST)
|
||||||
struct IndexServerImpl {
|
|
||||||
reader: IndexReader,
|
|
||||||
query_parser: QueryParser,
|
|
||||||
|
|
||||||
field_body: Field,
|
|
||||||
field_filepath: Field,
|
|
||||||
field_git_url: Field,
|
|
||||||
field_kind: Field,
|
|
||||||
field_language: Field,
|
|
||||||
field_name: Field,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IndexServerImpl {
|
|
||||||
pub fn load() -> Result<Self> {
|
|
||||||
let index = Index::open_in_dir(path::index_dir())?;
|
|
||||||
index.register_tokenizer();
|
|
||||||
|
|
||||||
let schema = index.schema();
|
|
||||||
let field_body = schema.get_field("body").unwrap();
|
|
||||||
let query_parser =
|
|
||||||
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
|
|
||||||
let reader = index
|
|
||||||
.reader_builder()
|
|
||||||
.reload_policy(tantivy::ReloadPolicy::OnCommit)
|
|
||||||
.try_into()?;
|
|
||||||
Ok(Self {
|
|
||||||
reader,
|
|
||||||
query_parser,
|
|
||||||
field_body,
|
|
||||||
field_filepath: schema.get_field("filepath").unwrap(),
|
|
||||||
field_git_url: schema.get_field("git_url").unwrap(),
|
|
||||||
field_kind: schema.get_field("kind").unwrap(),
|
|
||||||
field_language: schema.get_field("language").unwrap(),
|
|
||||||
field_name: schema.get_field("name").unwrap(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
|
|
||||||
let query = self.query_parser.parse_query(q)?;
|
|
||||||
self.search_with_query(&query, limit, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn search_with_query(
|
|
||||||
&self,
|
|
||||||
q: &dyn tantivy::query::Query,
|
|
||||||
limit: usize,
|
|
||||||
offset: usize,
|
|
||||||
) -> tantivy::Result<SearchResponse> {
|
|
||||||
let searcher = self.reader.searcher();
|
|
||||||
let (top_docs, num_hits) =
|
|
||||||
{ searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? };
|
|
||||||
let hits: Vec<Hit> = {
|
|
||||||
top_docs
|
|
||||||
.iter()
|
|
||||||
.map(|(score, doc_address)| {
|
|
||||||
let doc = searcher.doc(*doc_address).unwrap();
|
|
||||||
self.create_hit(*score, doc, *doc_address)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
};
|
|
||||||
Ok(SearchResponse { num_hits, hits })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit {
|
|
||||||
Hit {
|
|
||||||
score,
|
|
||||||
doc: HitDocument {
|
|
||||||
body: get_field(&doc, self.field_body),
|
|
||||||
filepath: get_field(&doc, self.field_filepath),
|
|
||||||
git_url: get_field(&doc, self.field_git_url),
|
|
||||||
kind: get_field(&doc, self.field_kind),
|
|
||||||
name: get_field(&doc, self.field_name),
|
|
||||||
language: get_field(&doc, self.field_language),
|
|
||||||
},
|
|
||||||
id: doc_address.doc_id,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_field(doc: &Document, field: Field) -> String {
|
|
||||||
doc.get_first(field)
|
|
||||||
.and_then(|x| x.as_text())
|
|
||||||
.unwrap()
|
|
||||||
.to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
static IMPL: OnceCell<IndexServerImpl> = OnceCell::const_new();
|
|
||||||
|
|
||||||
pub struct IndexServer {}
|
|
||||||
|
|
||||||
impl IndexServer {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
task::spawn(IMPL.get_or_init(|| async {
|
|
||||||
task::spawn(IndexServer::worker())
|
|
||||||
.await
|
|
||||||
.expect("Failed to create IndexServerImpl")
|
|
||||||
}));
|
|
||||||
Self {}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn with_impl<T, F>(&self, op: F) -> Result<T, IndexServerError>
|
|
||||||
where
|
|
||||||
F: FnOnce(&IndexServerImpl) -> Result<T, IndexServerError>,
|
|
||||||
{
|
|
||||||
if let Some(imp) = IMPL.get() {
|
|
||||||
op(imp)
|
|
||||||
} else {
|
|
||||||
Err(IndexServerError::NotReady)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn worker() -> IndexServerImpl {
|
|
||||||
loop {
|
|
||||||
match IndexServerImpl::load() {
|
|
||||||
Ok(index_server) => {
|
|
||||||
info!("Index is ready, enabling server...");
|
|
||||||
return index_server;
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
debug!("Source code index is not ready `{}`", err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
sleep(Duration::from_secs(60)).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn language_query(&self, language: &str) -> Result<Box<TermQuery>, IndexServerError> {
|
|
||||||
self.with_impl(|imp| {
|
|
||||||
Ok(Box::new(TermQuery::new(
|
|
||||||
Term::from_field_text(imp.field_language, language),
|
|
||||||
IndexRecordOption::WithFreqsAndPositions,
|
|
||||||
)))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn body_query(&self, tokens: &[String]) -> Result<Box<TermSetQuery>, IndexServerError> {
|
|
||||||
self.with_impl(|imp| {
|
|
||||||
Ok(Box::new(TermSetQuery::new(
|
|
||||||
tokens
|
|
||||||
.iter()
|
|
||||||
.map(|x| Term::from_field_text(imp.field_body, x)),
|
|
||||||
)))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn search(
|
|
||||||
&self,
|
|
||||||
q: &str,
|
|
||||||
limit: usize,
|
|
||||||
offset: usize,
|
|
||||||
) -> Result<SearchResponse, IndexServerError> {
|
|
||||||
self.with_impl(|imp| Ok(imp.search(q, limit, offset)?))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn search_with_query(
|
|
||||||
&self,
|
|
||||||
q: &dyn tantivy::query::Query,
|
|
||||||
limit: usize,
|
|
||||||
offset: usize,
|
|
||||||
) -> Result<SearchResponse, IndexServerError> {
|
|
||||||
self.with_impl(|imp| Ok(imp.search_with_query(q, limit, offset)?))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum IndexServerError {
|
|
||||||
#[error("index not ready")]
|
|
||||||
NotReady,
|
|
||||||
|
|
||||||
#[error("{0}")]
|
|
||||||
TantivyError(#[from] tantivy::TantivyError),
|
|
||||||
}
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue