diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs index 13464ed..47cbe8e 100644 --- a/crates/tabby-common/src/api/code.rs +++ b/crates/tabby-common/src/api/code.rs @@ -1,22 +1,22 @@ use async_trait::async_trait; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use thiserror::Error; use utoipa::ToSchema; -#[derive(Serialize, ToSchema)] +#[derive(Serialize, Deserialize, Debug, ToSchema)] pub struct SearchResponse { pub num_hits: usize, pub hits: Vec, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, Deserialize, Debug, ToSchema)] pub struct Hit { pub score: f32, pub doc: HitDocument, pub id: u32, } -#[derive(Serialize, ToSchema)] +#[derive(Serialize, Deserialize, Debug, ToSchema)] pub struct HitDocument { pub body: String, pub filepath: String, @@ -47,9 +47,10 @@ pub trait CodeSearch: Send + Sync { offset: usize, ) -> Result; - async fn search_with_query( + async fn search_in_language( &self, - q: &dyn tantivy::query::Query, + language: &str, + tokens: &[String], limit: usize, offset: usize, ) -> Result; diff --git a/crates/tabby/src/search.rs b/crates/tabby/src/search.rs index fd85195..b7de2da 100644 --- a/crates/tabby/src/search.rs +++ b/crates/tabby/src/search.rs @@ -9,7 +9,8 @@ use tabby_common::{ }; use tantivy::{ collector::{Count, TopDocs}, - query::QueryParser, + query::{BooleanQuery, QueryParser}, + query_grammar::Occur, schema::Field, DocAddress, Document, Index, IndexReader, }; @@ -75,19 +76,6 @@ impl CodeSearchImpl { id: doc_address.doc_id, } } -} - -#[async_trait] -impl CodeSearch for CodeSearchImpl { - async fn search( - &self, - q: &str, - limit: usize, - offset: usize, - ) -> Result { - let query = self.query_parser.parse_query(q)?; - self.search_with_query(&query, limit, offset).await - } async fn search_with_query( &self, @@ -111,6 +99,35 @@ impl CodeSearch for CodeSearchImpl { } } +#[async_trait] +impl CodeSearch for CodeSearchImpl { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result { + let query = self.query_parser.parse_query(q)?; + self.search_with_query(&query, limit, offset).await + } + + async fn search_in_language( + &self, + language: &str, + tokens: &[String], + limit: usize, + offset: usize, + ) -> Result { + let language_query = self.schema.language_query(language); + let body_query = self.schema.body_query(tokens); + let query = BooleanQuery::new(vec![ + (Occur::Must, language_query), + (Occur::Must, body_query), + ]); + self.search_with_query(&query, limit, offset).await + } +} + fn get_field(doc: &Document, field: Field) -> String { doc.get_first(field) .and_then(|x| x.as_text()) @@ -158,14 +175,16 @@ impl CodeSearch for CodeSearchService { } } - async fn search_with_query( + async fn search_in_language( &self, - q: &dyn tantivy::query::Query, + language: &str, + tokens: &[String], limit: usize, offset: usize, ) -> Result { if let Some(imp) = self.search.lock().await.as_ref() { - imp.search_with_query(q, limit, offset).await + imp.search_in_language(language, tokens, limit, offset) + .await } else { Err(CodeSearchError::NotReady) } diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index b16cbbd..a863d73 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -5,10 +5,8 @@ use regex::Regex; use strfmt::strfmt; use tabby_common::{ api::code::{BoxCodeSearch, CodeSearchError}, - index::CodeSearchSchema, languages::get_language, }; -use tantivy::{query::BooleanQuery, query_grammar::Occur}; use textdistance::Algorithm; use tracing::warn; @@ -19,7 +17,6 @@ static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; pub struct PromptBuilder { - schema: CodeSearchSchema, prompt_template: Option, code: Option>, } @@ -27,7 +24,6 @@ pub struct PromptBuilder { impl PromptBuilder { pub fn new(prompt_template: Option, code: Option>) -> Self { PromptBuilder { - schema: CodeSearchSchema::new(), prompt_template, code, } @@ -43,7 +39,7 @@ impl PromptBuilder { pub async fn collect(&self, language: &str, segments: &Segments) -> Vec { if let Some(code) = &self.code { - collect_snippets(&self.schema, code.as_ref(), language, &segments.prefix).await + collect_snippets(code.as_ref(), language, &segments.prefix).await } else { vec![] } @@ -110,24 +106,12 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { format!("{}\n{}", comments, prefix) } -async fn collect_snippets( - schema: &CodeSearchSchema, - code: &BoxCodeSearch, - language: &str, - text: &str, -) -> Vec { +async fn collect_snippets(code: &BoxCodeSearch, language: &str, text: &str) -> Vec { let mut ret = Vec::new(); let mut tokens = tokenize_text(text); - let language_query = schema.language_query(language); - let body_query = schema.body_query(&tokens); - let query = BooleanQuery::new(vec![ - (Occur::Must, language_query), - (Occur::Must, body_query), - ]); - let serp = match code - .search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) + .search_in_language(language, &tokens, MAX_SNIPPETS_TO_FETCH, 0) .await { Ok(serp) => serp,