refactor(code): extract `search_in_language` (#762)

* chore: init tabby-webserver

* add code search worker registry

* add webserver command

* add graphql

* extract schema

* refactor: extract registry.rs

* refactor

* update

* update

* update

* update

* update

* fix lint
extract-routes
Meng Zhang 2023-11-10 17:29:50 -08:00 committed by GitHub
parent bf2c1e6a79
commit 41f60d3204
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 42 deletions

View File

@ -1,22 +1,22 @@
use async_trait::async_trait;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use utoipa::ToSchema;
// Payload returned by the `CodeSearch` methods: a hit count plus the hits themselves.
// (Plain `//` comments are used so the derived API schema text is not affected.)
#[derive(Serialize, ToSchema)]
#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct SearchResponse {
// Number of hits. NOTE(review): presumably the total match count, which may
// differ from `hits.len()` when `limit`/`offset` are applied — confirm.
pub num_hits: usize,
// The hits carried in this response.
pub hits: Vec<Hit>,
}
// A single search hit: its score, the matched document, and the document id.
// (Plain `//` comments are used so the derived API schema text is not affected.)
#[derive(Serialize, ToSchema)]
#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct Hit {
// Score assigned to this hit. NOTE(review): presumably higher means more
// relevant — confirm against the collector used by the implementation.
pub score: f32,
// The document that matched.
pub doc: HitDocument,
// Filled from the index `DocAddress`'s `doc_id` by the search implementation.
pub id: u32,
}
#[derive(Serialize, ToSchema)]
#[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct HitDocument {
pub body: String,
pub filepath: String,
@ -47,9 +47,10 @@ pub trait CodeSearch: Send + Sync {
offset: usize,
) -> Result<SearchResponse, CodeSearchError>;
async fn search_with_query(
async fn search_in_language(
&self,
q: &dyn tantivy::query::Query,
language: &str,
tokens: &[String],
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError>;

View File

@ -9,7 +9,8 @@ use tabby_common::{
};
use tantivy::{
collector::{Count, TopDocs},
query::QueryParser,
query::{BooleanQuery, QueryParser},
query_grammar::Occur,
schema::Field,
DocAddress, Document, Index, IndexReader,
};
@ -75,19 +76,6 @@ impl CodeSearchImpl {
id: doc_address.doc_id,
}
}
}
#[async_trait]
impl CodeSearch for CodeSearchImpl {
async fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
let query = self.query_parser.parse_query(q)?;
self.search_with_query(&query, limit, offset).await
}
async fn search_with_query(
&self,
@ -111,6 +99,35 @@ impl CodeSearch for CodeSearchImpl {
}
}
#[async_trait]
impl CodeSearch for CodeSearchImpl {
    /// Parses the textual query `q` with this instance's query parser and runs it.
    ///
    /// A parse failure is propagated to the caller through `?`; otherwise the
    /// call is forwarded to `search_with_query` with the same `limit`/`offset`.
    async fn search(
        &self,
        q: &str,
        limit: usize,
        offset: usize,
    ) -> Result<SearchResponse, CodeSearchError> {
        let parsed = self.query_parser.parse_query(q)?;
        self.search_with_query(&parsed, limit, offset).await
    }

    /// Searches for `tokens` among documents of the given `language`.
    ///
    /// Builds a boolean query in which both the language clause and the body
    /// clause are mandatory (`Occur::Must`), then forwards it to
    /// `search_with_query` with the same `limit`/`offset`.
    async fn search_in_language(
        &self,
        language: &str,
        tokens: &[String],
        limit: usize,
        offset: usize,
    ) -> Result<SearchResponse, CodeSearchError> {
        // Both clauses must match; the language clause is built first, as before.
        let clauses = vec![
            (Occur::Must, self.schema.language_query(language)),
            (Occur::Must, self.schema.body_query(tokens)),
        ];
        let combined = BooleanQuery::new(clauses);
        self.search_with_query(&combined, limit, offset).await
    }
}
fn get_field(doc: &Document, field: Field) -> String {
doc.get_first(field)
.and_then(|x| x.as_text())
@ -158,14 +175,16 @@ impl CodeSearch for CodeSearchService {
}
}
async fn search_with_query(
async fn search_in_language(
&self,
q: &dyn tantivy::query::Query,
language: &str,
tokens: &[String],
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
if let Some(imp) = self.search.lock().await.as_ref() {
imp.search_with_query(q, limit, offset).await
imp.search_in_language(language, tokens, limit, offset)
.await
} else {
Err(CodeSearchError::NotReady)
}

View File

@ -5,10 +5,8 @@ use regex::Regex;
use strfmt::strfmt;
use tabby_common::{
api::code::{BoxCodeSearch, CodeSearchError},
index::CodeSearchSchema,
languages::get_language,
};
use tantivy::{query::BooleanQuery, query_grammar::Occur};
use textdistance::Algorithm;
use tracing::warn;
@ -19,7 +17,6 @@ static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
// Holds the optional prompt template and the optional code-search backend
// that `collect` uses to gather snippets.
pub struct PromptBuilder {
// Code-search schema; created with `CodeSearchSchema::new()` in `new`.
schema: CodeSearchSchema,
// Optional prompt template passed to `new`.
prompt_template: Option<String>,
// Optional code-search handle; when it is `None`, `collect` returns an empty list.
code: Option<Arc<BoxCodeSearch>>,
}
@ -27,7 +24,6 @@ pub struct PromptBuilder {
impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
PromptBuilder {
schema: CodeSearchSchema::new(),
prompt_template,
code,
}
@ -43,7 +39,7 @@ impl PromptBuilder {
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
if let Some(code) = &self.code {
collect_snippets(&self.schema, code.as_ref(), language, &segments.prefix).await
collect_snippets(code.as_ref(), language, &segments.prefix).await
} else {
vec![]
}
@ -110,24 +106,12 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
format!("{}\n{}", comments, prefix)
}
async fn collect_snippets(
schema: &CodeSearchSchema,
code: &BoxCodeSearch,
language: &str,
text: &str,
) -> Vec<Snippet> {
async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec<Snippet> {
let mut ret = Vec::new();
let mut tokens = tokenize_text(text);
let language_query = schema.language_query(language);
let body_query = schema.body_query(&tokens);
let query = BooleanQuery::new(vec![
(Occur::Must, language_query),
(Occur::Must, body_query),
]);
let serp = match code
.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0)
.search_in_language(language, &tokens, MAX_SNIPPETS_TO_FETCH, 0)
.await
{
Ok(serp) => serp,