refactor(code): extract `search_in_language` (#762)

* chore: init tabby-webserver

* add code search worker registry

* add webserver command

* add graphql

* extract schema

* refactor: extract registry.rs

* refactor

* update

* update

* update

* update

* update

* fix lint
extract-routes
Meng Zhang 2023-11-10 17:29:50 -08:00 committed by GitHub
parent bf2c1e6a79
commit 41f60d3204
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 42 deletions

View File

@@ -1,22 +1,22 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde::Serialize; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
use utoipa::ToSchema; use utoipa::ToSchema;
#[derive(Serialize, ToSchema)] #[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct SearchResponse { pub struct SearchResponse {
pub num_hits: usize, pub num_hits: usize,
pub hits: Vec<Hit>, pub hits: Vec<Hit>,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct Hit { pub struct Hit {
pub score: f32, pub score: f32,
pub doc: HitDocument, pub doc: HitDocument,
pub id: u32, pub id: u32,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, Deserialize, Debug, ToSchema)]
pub struct HitDocument { pub struct HitDocument {
pub body: String, pub body: String,
pub filepath: String, pub filepath: String,
@@ -47,9 +47,10 @@ pub trait CodeSearch: Send + Sync {
offset: usize, offset: usize,
) -> Result<SearchResponse, CodeSearchError>; ) -> Result<SearchResponse, CodeSearchError>;
async fn search_with_query( async fn search_in_language(
&self, &self,
q: &dyn tantivy::query::Query, language: &str,
tokens: &[String],
limit: usize, limit: usize,
offset: usize, offset: usize,
) -> Result<SearchResponse, CodeSearchError>; ) -> Result<SearchResponse, CodeSearchError>;

View File

@@ -9,7 +9,8 @@ use tabby_common::{
}; };
use tantivy::{ use tantivy::{
collector::{Count, TopDocs}, collector::{Count, TopDocs},
query::QueryParser, query::{BooleanQuery, QueryParser},
query_grammar::Occur,
schema::Field, schema::Field,
DocAddress, Document, Index, IndexReader, DocAddress, Document, Index, IndexReader,
}; };
@@ -75,19 +76,6 @@ impl CodeSearchImpl {
id: doc_address.doc_id, id: doc_address.doc_id,
} }
} }
}
#[async_trait]
impl CodeSearch for CodeSearchImpl {
async fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
let query = self.query_parser.parse_query(q)?;
self.search_with_query(&query, limit, offset).await
}
async fn search_with_query( async fn search_with_query(
&self, &self,
@@ -111,6 +99,35 @@ impl CodeSearch for CodeSearchImpl {
} }
} }
#[async_trait]
impl CodeSearch for CodeSearchImpl {
async fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
let query = self.query_parser.parse_query(q)?;
self.search_with_query(&query, limit, offset).await
}
async fn search_in_language(
&self,
language: &str,
tokens: &[String],
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError> {
let language_query = self.schema.language_query(language);
let body_query = self.schema.body_query(tokens);
let query = BooleanQuery::new(vec![
(Occur::Must, language_query),
(Occur::Must, body_query),
]);
self.search_with_query(&query, limit, offset).await
}
}
fn get_field(doc: &Document, field: Field) -> String { fn get_field(doc: &Document, field: Field) -> String {
doc.get_first(field) doc.get_first(field)
.and_then(|x| x.as_text()) .and_then(|x| x.as_text())
@@ -158,14 +175,16 @@ impl CodeSearch for CodeSearchService {
} }
} }
async fn search_with_query( async fn search_in_language(
&self, &self,
q: &dyn tantivy::query::Query, language: &str,
tokens: &[String],
limit: usize, limit: usize,
offset: usize, offset: usize,
) -> Result<SearchResponse, CodeSearchError> { ) -> Result<SearchResponse, CodeSearchError> {
if let Some(imp) = self.search.lock().await.as_ref() { if let Some(imp) = self.search.lock().await.as_ref() {
imp.search_with_query(q, limit, offset).await imp.search_in_language(language, tokens, limit, offset)
.await
} else { } else {
Err(CodeSearchError::NotReady) Err(CodeSearchError::NotReady)
} }

View File

@@ -5,10 +5,8 @@ use regex::Regex;
use strfmt::strfmt; use strfmt::strfmt;
use tabby_common::{ use tabby_common::{
api::code::{BoxCodeSearch, CodeSearchError}, api::code::{BoxCodeSearch, CodeSearchError},
index::CodeSearchSchema,
languages::get_language, languages::get_language,
}; };
use tantivy::{query::BooleanQuery, query_grammar::Occur};
use textdistance::Algorithm; use textdistance::Algorithm;
use tracing::warn; use tracing::warn;
@@ -19,7 +17,6 @@ static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder { pub struct PromptBuilder {
schema: CodeSearchSchema,
prompt_template: Option<String>, prompt_template: Option<String>,
code: Option<Arc<BoxCodeSearch>>, code: Option<Arc<BoxCodeSearch>>,
} }
@@ -27,7 +24,6 @@ pub struct PromptBuilder {
impl PromptBuilder { impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self { pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
PromptBuilder { PromptBuilder {
schema: CodeSearchSchema::new(),
prompt_template, prompt_template,
code, code,
} }
@@ -43,7 +39,7 @@ impl PromptBuilder {
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> { pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
if let Some(code) = &self.code { if let Some(code) = &self.code {
collect_snippets(&self.schema, code.as_ref(), language, &segments.prefix).await collect_snippets(code.as_ref(), language, &segments.prefix).await
} else { } else {
vec![] vec![]
} }
@@ -110,24 +106,12 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
format!("{}\n{}", comments, prefix) format!("{}\n{}", comments, prefix)
} }
async fn collect_snippets( async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec<Snippet> {
schema: &CodeSearchSchema,
code: &BoxCodeSearch,
language: &str,
text: &str,
) -> Vec<Snippet> {
let mut ret = Vec::new(); let mut ret = Vec::new();
let mut tokens = tokenize_text(text); let mut tokens = tokenize_text(text);
let language_query = schema.language_query(language);
let body_query = schema.body_query(&tokens);
let query = BooleanQuery::new(vec![
(Occur::Must, language_query),
(Occur::Must, body_query),
]);
let serp = match code let serp = match code
.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) .search_in_language(language, &tokens, MAX_SNIPPETS_TO_FETCH, 0)
.await .await
{ {
Ok(serp) => serp, Ok(serp) => serp,