diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index d5be128..194574a 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,4 +1,5 @@ mod download; +mod search; mod serve; use clap::{Parser, Subcommand}; diff --git a/crates/tabby/src/search.rs b/crates/tabby/src/search.rs new file mode 100644 index 0000000..3f6cfa5 --- /dev/null +++ b/crates/tabby/src/search.rs @@ -0,0 +1,263 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use axum::async_trait; +use serde::Serialize; +use tabby_common::{index::IndexExt, path}; +use tantivy::{ + collector::{Count, TopDocs}, + query::{QueryParser, TermQuery, TermSetQuery}, + schema::{Field, IndexRecordOption}, + DocAddress, Document, Index, IndexReader, Term, +}; +use thiserror::Error; +use tokio::{sync::Mutex, time::sleep}; +use tracing::{debug, log::info}; +use utoipa::ToSchema; + +#[derive(Serialize, ToSchema)] +pub struct SearchResponse { + pub num_hits: usize, + pub hits: Vec, +} + +#[derive(Serialize, ToSchema)] +pub struct Hit { + pub score: f32, + pub doc: HitDocument, + pub id: u32, +} + +#[derive(Serialize, ToSchema)] +pub struct HitDocument { + pub body: String, + pub filepath: String, + pub git_url: String, + pub kind: String, + pub language: String, + pub name: String, +} + +#[derive(Error, Debug)] +pub enum CodeSearchError { + #[error("index not ready")] + NotReady, + + #[error("{0}")] + QueryParserError(#[from] tantivy::query::QueryParserError), + + #[error("{0}")] + TantivyError(#[from] tantivy::TantivyError), +} + +#[async_trait] +pub trait CodeSearch { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result; + + async fn search_with_query( + &self, + q: &dyn tantivy::query::Query, + limit: usize, + offset: usize, + ) -> Result; +} + +struct CodeSearchImpl { + reader: IndexReader, + query_parser: QueryParser, + + field_body: Field, + field_filepath: Field, + field_git_url: Field, + field_kind: Field, + field_language: Field, + field_name: Field, +} + +impl CodeSearchImpl { + fn load() -> Result { + let index = Index::open_in_dir(path::index_dir())?; + index.register_tokenizer(); + + let schema = index.schema(); + let field_body = schema.get_field("body").unwrap(); + let query_parser = + QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone()); + let reader = index + .reader_builder() + .reload_policy(tantivy::ReloadPolicy::OnCommit) + .try_into()?; + Ok(Self { + reader, + query_parser, + field_body, + field_filepath: schema.get_field("filepath").unwrap(), + field_git_url: schema.get_field("git_url").unwrap(), + field_kind: schema.get_field("kind").unwrap(), + field_language: schema.get_field("language").unwrap(), + field_name: schema.get_field("name").unwrap(), + }) + } + + async fn load_async() -> CodeSearchImpl { + loop { + match CodeSearchImpl::load() { + Ok(code) => { + info!("Index is ready, enabling server..."); + return code; + } + Err(err) => { + debug!("Source code index is not ready `{}`", err); + } + }; + + sleep(Duration::from_secs(60)).await; + } + } + + fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit { + Hit { + score, + doc: HitDocument { + body: get_field(&doc, self.field_body), + filepath: get_field(&doc, self.field_filepath), + git_url: get_field(&doc, self.field_git_url), + kind: get_field(&doc, self.field_kind), + name: get_field(&doc, self.field_name), + language: get_field(&doc, self.field_language), + }, + id: doc_address.doc_id, + } + } +} + +#[async_trait] +impl CodeSearch for CodeSearchImpl { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result { + let query = self.query_parser.parse_query(q)?; + self.search_with_query(&query, limit, offset).await + } + + async fn search_with_query( + &self, + q: &dyn tantivy::query::Query, + limit: usize, + offset: usize, + ) -> Result { + let searcher = self.reader.searcher(); + let (top_docs, num_hits) = + { searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? }; + let hits: Vec = { + top_docs + .iter() + .map(|(score, doc_address)| { + let doc = searcher.doc(*doc_address).unwrap(); + self.create_hit(*score, doc, *doc_address) + }) + .collect() + }; + Ok(SearchResponse { num_hits, hits }) + } +} + +fn get_field(doc: &Document, field: Field) -> String { + doc.get_first(field) + .and_then(|x| x.as_text()) + .unwrap() + .to_owned() +} + +pub struct CodeSearchService { + search: Arc>>, +} + +impl CodeSearchService { + pub fn new() -> Self { + let search = Arc::new(Mutex::new(None)); + + let ret = Self { + search: search.clone(), + }; + + tokio::spawn(async move { + let code = CodeSearchImpl::load_async().await; + *search.lock().await = Some(code); + }); + + ret + } + + async fn with_impl(&self, op: F) -> Result + where + F: FnOnce(&CodeSearchImpl) -> Result, + { + if let Some(imp) = self.search.lock().await.as_ref() { + op(imp) + } else { + Err(CodeSearchError::NotReady) + } + } + + pub async fn language_query(&self, language: &str) -> Result, CodeSearchError> { + self.with_impl(|imp| { + Ok(Box::new(TermQuery::new( + Term::from_field_text(imp.field_language, language), + IndexRecordOption::WithFreqsAndPositions, + ))) + }) + .await + } + + pub async fn body_query( + &self, + tokens: &[String], + ) -> Result, CodeSearchError> { + self.with_impl(|imp| { + Ok(Box::new(TermSetQuery::new( + tokens + .iter() + .map(|x| Term::from_field_text(imp.field_body, x)), + ))) + }) + .await + } +} + +#[async_trait] +impl CodeSearch for CodeSearchService { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result { + if let Some(imp) = self.search.lock().await.as_ref() { + imp.search(q, limit, offset).await + } else { + Err(CodeSearchError::NotReady) + } + } + + async fn search_with_query( + &self, + q: &dyn tantivy::query::Query, + limit: usize, + offset: usize, + ) -> Result { + if let Some(imp) = self.search.lock().await.as_ref() { + imp.search_with_query(q, limit, offset).await + } else { + Err(CodeSearchError::NotReady) + } + } +} diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 8dd6e3e..1032ee4 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -10,7 +10,7 @@ use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; -use super::search::IndexServer; +use crate::search::CodeSearchService; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ @@ -137,7 +137,8 @@ pub async fn completions( (prompt, None, vec![]) } else if let Some(segments) = request.segments { debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); - let (prompt, snippets) = build_prompt(&state, &request.debug_options, &language, &segments); + let (prompt, snippets) = + build_prompt(&state, &request.debug_options, &language, &segments).await; (prompt, Some(segments), snippets) } else { return Err(StatusCode::BAD_REQUEST); @@ -180,7 +181,7 @@ pub async fn completions( })) } -fn build_prompt( +async fn build_prompt( state: &Arc, debug_options: &Option, language: &str, @@ -190,7 +191,7 @@ fn build_prompt( .as_ref() .is_some_and(|x| x.disable_retrieval_augmented_code_completion) { - state.prompt_builder.collect(language, segments) + state.prompt_builder.collect(language, segments).await } else { vec![] }; @@ -210,12 +211,12 @@ pub struct CompletionState { impl CompletionState { pub fn new( engine: Arc>, - index_server: Arc, + code: Arc, prompt_template: Option, ) -> Self { Self { engine, - prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(index_server)), + prompt_builder: prompt::PromptBuilder::new(prompt_template, Some(code)), } } } diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index b538030..8c4c2d3 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -9,7 +9,7 @@ use textdistance::Algorithm; use tracing::warn; use super::{Segments, Snippet}; -use crate::serve::search::{IndexServer, IndexServerError}; +use crate::search::{CodeSearch, CodeSearchError, CodeSearchService}; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; @@ -17,14 +17,14 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; pub struct PromptBuilder { prompt_template: Option, - index_server: Option>, + code: Option>, } impl PromptBuilder { - pub fn new(prompt_template: Option, index_server: Option>) -> Self { + pub fn new(prompt_template: Option, code: Option>) -> Self { PromptBuilder { prompt_template, - index_server, + code, } } @@ -36,9 +36,9 @@ impl PromptBuilder { strfmt!(prompt_template, prefix => prefix, suffix => suffix).unwrap() } - pub fn collect(&self, language: &str, segments: &Segments) -> Vec { - if let Some(index_server) = &self.index_server { - collect_snippets(index_server, language, &segments.prefix) + pub async fn collect(&self, language: &str, segments: &Segments) -> Vec { + if let Some(code) = &self.code { + collect_snippets(code, language, &segments.prefix).await } else { vec![] } @@ -105,14 +105,14 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { format!("{}\n{}", comments, prefix) } -fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> Vec { +async fn collect_snippets(code: &CodeSearchService, language: &str, text: &str) -> Vec { let mut ret = Vec::new(); let mut tokens = tokenize_text(text); - let Ok(language_query) = index_server.language_query(language) else { + let Ok(language_query) = code.language_query(language).await else { return vec![]; }; - let Ok(body_query) = index_server.body_query(&tokens) else { + let Ok(body_query) = code.body_query(&tokens).await else { return vec![]; }; let query = BooleanQuery::new(vec![ @@ -120,14 +120,21 @@ fn collect_snippets(index_server: &IndexServer, language: &str, text: &str) -> V (Occur::Must, body_query), ]); - let serp = match index_server.search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) { + let serp = match code + .search_with_query(&query, MAX_SNIPPETS_TO_FETCH, 0) + .await + { Ok(serp) => serp, - Err(IndexServerError::NotReady) => { + Err(CodeSearchError::NotReady) => { // Ignore. return vec![]; } - Err(IndexServerError::TantivyError(err)) => { - warn!("Failed to search query: {}", err); + Err(CodeSearchError::TantivyError(err)) => { + warn!("Failed to search: {}", err); + return ret; + } + Err(CodeSearchError::QueryParserError(err)) => { + warn!("Failed to parse query: {}", err); return ret; } }; diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index d3c7942..cde9a94 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -27,9 +27,8 @@ use utoipa_swagger_ui::SwaggerUi; use self::{ engine::{create_engine, EngineInfo}, health::HealthState, - search::IndexServer, }; -use crate::fatal; +use crate::{fatal, search::CodeSearchService}; #[derive(OpenApi)] #[openapi( @@ -63,9 +62,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi chat::ChatCompletionChunk, health::HealthState, health::Version, - search::SearchResponse, - search::Hit, - search::HitDocument + crate::search::SearchResponse, + crate::search::Hit, + crate::search::HitDocument )) )] struct ApiDoc; @@ -170,7 +169,7 @@ pub async fn main(config: &Config, args: &ServeArgs) { } async fn api_router(args: &ServeArgs, config: &Config) -> Router { - let index_server = Arc::new(IndexServer::new()); + let code = Arc::new(CodeSearchService::new()); let completion_state = { let ( engine, @@ -179,11 +178,8 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { }, ) = create_engine(&args.model, args).await; let engine = Arc::new(engine); - let state = completions::CompletionState::new( - engine.clone(), - index_server.clone(), - prompt_template, - ); + let state = + completions::CompletionState::new(engine.clone(), code.clone(), prompt_template); Arc::new(state) }; @@ -238,7 +234,7 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router { routers.push({ Router::new().route( "/v1beta/search", - routing::get(search::search).with_state(index_server), + routing::get(search::search).with_state(code), ) }); diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs index 2be9e15..0644443 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/serve/search.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use anyhow::Result; use axum::{ @@ -6,18 +6,11 @@ use axum::{ Json, }; use hyper::StatusCode; -use serde::{Deserialize, Serialize}; -use tabby_common::{index::IndexExt, path}; -use tantivy::{ - collector::{Count, TopDocs}, - query::{QueryParser, TermQuery, TermSetQuery}, - schema::{Field, IndexRecordOption}, - DocAddress, Document, Index, IndexReader, Term, -}; -use thiserror::Error; -use tokio::{sync::OnceCell, task, time::sleep}; -use tracing::{debug, instrument, log::info, warn}; -use utoipa::{IntoParams, ToSchema}; +use serde::Deserialize; +use tracing::{instrument, warn}; +use utoipa::IntoParams; + +use crate::search::{CodeSearch, CodeSearchError, CodeSearchService, SearchResponse}; #[derive(Deserialize, IntoParams)] pub struct SearchQuery { @@ -31,29 +24,6 @@ pub struct SearchQuery { offset: Option, } -#[derive(Serialize, ToSchema)] -pub struct SearchResponse { - pub num_hits: usize, - pub hits: Vec, -} - -#[derive(Serialize, ToSchema)] -pub struct Hit { - pub score: f32, - pub doc: HitDocument, - pub id: u32, -} - -#[derive(Serialize, ToSchema)] -pub struct HitDocument { - pub body: String, - pub filepath: String, - pub git_url: String, - pub kind: String, - pub language: String, - pub name: String, -} - #[utoipa::path( get, params(SearchQuery), @@ -67,193 +37,26 @@ pub struct HitDocument { )] #[instrument(skip(state, query))] pub async fn search( - State(state): State>, + State(state): State>, query: Query, ) -> Result, StatusCode> { - match state.search( - &query.q, - query.limit.unwrap_or(20), - query.offset.unwrap_or(0), - ) { + match state + .search( + &query.q, + query.limit.unwrap_or(20), + query.offset.unwrap_or(0), + ) + .await + { Ok(serp) => Ok(Json(serp)), - Err(IndexServerError::NotReady) => Err(StatusCode::NOT_IMPLEMENTED), - Err(IndexServerError::TantivyError(err)) => { + Err(CodeSearchError::NotReady) => Err(StatusCode::NOT_IMPLEMENTED), + Err(CodeSearchError::TantivyError(err)) => { warn!("{}", err); Err(StatusCode::INTERNAL_SERVER_ERROR) } - } -} - -struct IndexServerImpl { - reader: IndexReader, - query_parser: QueryParser, - - field_body: Field, - field_filepath: Field, - field_git_url: Field, - field_kind: Field, - field_language: Field, - field_name: Field, -} - -impl IndexServerImpl { - pub fn load() -> Result { - let index = Index::open_in_dir(path::index_dir())?; - index.register_tokenizer(); - - let schema = index.schema(); - let field_body = schema.get_field("body").unwrap(); - let query_parser = - QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone()); - let reader = index - .reader_builder() - .reload_policy(tantivy::ReloadPolicy::OnCommit) - .try_into()?; - Ok(Self { - reader, - query_parser, - field_body, - field_filepath: schema.get_field("filepath").unwrap(), - field_git_url: schema.get_field("git_url").unwrap(), - field_kind: schema.get_field("kind").unwrap(), - field_language: schema.get_field("language").unwrap(), - field_name: schema.get_field("name").unwrap(), - }) - } - - pub fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result { - let query = self.query_parser.parse_query(q)?; - self.search_with_query(&query, limit, offset) - } - - pub fn search_with_query( - &self, - q: &dyn tantivy::query::Query, - limit: usize, - offset: usize, - ) -> tantivy::Result { - let searcher = self.reader.searcher(); - let (top_docs, num_hits) = - { searcher.search(q, &(TopDocs::with_limit(limit).and_offset(offset), Count))? }; - let hits: Vec = { - top_docs - .iter() - .map(|(score, doc_address)| { - let doc = searcher.doc(*doc_address).unwrap(); - self.create_hit(*score, doc, *doc_address) - }) - .collect() - }; - Ok(SearchResponse { num_hits, hits }) - } - - fn create_hit(&self, score: f32, doc: Document, doc_address: DocAddress) -> Hit { - Hit { - score, - doc: HitDocument { - body: get_field(&doc, self.field_body), - filepath: get_field(&doc, self.field_filepath), - git_url: get_field(&doc, self.field_git_url), - kind: get_field(&doc, self.field_kind), - name: get_field(&doc, self.field_name), - language: get_field(&doc, self.field_language), - }, - id: doc_address.doc_id, + Err(CodeSearchError::QueryParserError(err)) => { + warn!("{}", err); + Err(StatusCode::BAD_REQUEST) } } } - -fn get_field(doc: &Document, field: Field) -> String { - doc.get_first(field) - .and_then(|x| x.as_text()) - .unwrap() - .to_owned() -} - -static IMPL: OnceCell = OnceCell::const_new(); - -pub struct IndexServer {} - -impl IndexServer { - pub fn new() -> Self { - task::spawn(IMPL.get_or_init(|| async { - task::spawn(IndexServer::worker()) - .await - .expect("Failed to create IndexServerImpl") - })); - Self {} - } - - fn with_impl(&self, op: F) -> Result - where - F: FnOnce(&IndexServerImpl) -> Result, - { - if let Some(imp) = IMPL.get() { - op(imp) - } else { - Err(IndexServerError::NotReady) - } - } - - async fn worker() -> IndexServerImpl { - loop { - match IndexServerImpl::load() { - Ok(index_server) => { - info!("Index is ready, enabling server..."); - return index_server; - } - Err(err) => { - debug!("Source code index is not ready `{}`", err); - } - }; - - sleep(Duration::from_secs(60)).await; - } - } - - pub fn language_query(&self, language: &str) -> Result, IndexServerError> { - self.with_impl(|imp| { - Ok(Box::new(TermQuery::new( - Term::from_field_text(imp.field_language, language), - IndexRecordOption::WithFreqsAndPositions, - ))) - }) - } - - pub fn body_query(&self, tokens: &[String]) -> Result, IndexServerError> { - self.with_impl(|imp| { - Ok(Box::new(TermSetQuery::new( - tokens - .iter() - .map(|x| Term::from_field_text(imp.field_body, x)), - ))) - }) - } - - pub fn search( - &self, - q: &str, - limit: usize, - offset: usize, - ) -> Result { - self.with_impl(|imp| Ok(imp.search(q, limit, offset)?)) - } - - pub fn search_with_query( - &self, - q: &dyn tantivy::query::Query, - limit: usize, - offset: usize, - ) -> Result { - self.with_impl(|imp| Ok(imp.search_with_query(q, limit, offset)?)) - } -} - -#[derive(Error, Debug)] -pub enum IndexServerError { - #[error("index not ready")] - NotReady, - - #[error("{0}")] - TantivyError(#[from] tantivy::TantivyError), -}