diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs index 2491a99..13464ed 100644 --- a/crates/tabby-common/src/api/code.rs +++ b/crates/tabby-common/src/api/code.rs @@ -54,3 +54,5 @@ pub trait CodeSearch: Send + Sync { offset: usize, ) -> Result; } + +pub type BoxCodeSearch = Box; diff --git a/crates/tabby/src/search.rs b/crates/tabby/src/search.rs index ed123e5..fd85195 100644 --- a/crates/tabby/src/search.rs +++ b/crates/tabby/src/search.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; use anyhow::Result; use async_trait::async_trait; use tabby_common::{ - api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, + api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, index::{self, register_tokenizers, CodeSearchSchema}, path, }; @@ -118,7 +118,7 @@ fn get_field(doc: &Document, field: Field) -> String { .to_owned() } -pub struct CodeSearchService { +struct CodeSearchService { search: Arc>>, } @@ -139,6 +139,10 @@ impl CodeSearchService { } } +pub fn create_code_search() -> BoxCodeSearch { + Box::new(CodeSearchService::new()) +} + #[async_trait] impl CodeSearch for CodeSearchService { async fn search( diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 1032ee4..4a66220 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -5,13 +5,11 @@ use std::sync::Arc; use axum::{extract::State, Json}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use tabby_common::{events, languages::get_language}; +use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; -use crate::search::CodeSearchService; - #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[schema(example=json!({ "language": "python", @@ -211,7 +209,7 @@ pub struct CompletionState { impl CompletionState { pub fn new( engine: Arc>, - code: Arc, + code: Arc, prompt_template: Option, ) -> Self { Self { diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index a9a2f5a..b2eeda8 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -4,7 +4,7 @@ use lazy_static::lazy_static; use regex::Regex; use strfmt::strfmt; use tabby_common::{ - api::code::{CodeSearch, CodeSearchError}, + api::code::{BoxCodeSearch, CodeSearchError}, index::CodeSearchSchema, languages::get_language, }; @@ -13,7 +13,6 @@ use textdistance::Algorithm; use tracing::warn; use super::{Segments, Snippet}; -use crate::search::CodeSearchService; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; @@ -22,11 +21,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; pub struct PromptBuilder { schema: CodeSearchSchema, prompt_template: Option, - code: Option>, + code: Option>, } impl PromptBuilder { - pub fn new(prompt_template: Option, code: Option>) -> Self { + pub fn new(prompt_template: Option, code: Option>) -> Self { PromptBuilder { schema: CodeSearchSchema::new(), prompt_template, @@ -44,7 +43,13 @@ impl PromptBuilder { pub async fn collect(&self, language: &str, segments: &Segments) -> Vec { if let Some(code) = &self.code { - collect_snippets(&self.schema, code, language, &segments.prefix).await + collect_snippets( + &self.schema, + code.as_ref(), + language, + &segments.prefix, + ) + .await } else { vec![] } @@ -113,7 +118,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { async fn collect_snippets( schema: &CodeSearchSchema, - code: &CodeSearchService, + code: &BoxCodeSearch, language: &str, text: &str, ) -> Vec { diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 62c5ea3..1d3bcc7 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -32,7 +32,7 @@ use self::{ engine::{create_engine, EngineInfo}, health::HealthState, }; -use crate::{chat::ChatService, fatal, search::CodeSearchService}; +use crate::{chat::ChatService, fatal, search::create_code_search}; #[derive(OpenApi)] #[openapi( @@ -173,7 +173,7 @@ pub async fn main(config: &Config, args: &ServeArgs) { } async fn api_router(args: &ServeArgs, config: &Config) -> Router { - let code = Arc::new(CodeSearchService::new()); + let code = Arc::new(create_code_search()); let completion_state = { let ( engine, diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs index bdc742f..28a2a67 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/serve/search.rs @@ -7,12 +7,10 @@ use axum::{ }; use hyper::StatusCode; use serde::Deserialize; -use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse}; +use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse}; use tracing::{instrument, warn}; use utoipa::IntoParams; -use crate::search::CodeSearchService; - #[derive(Deserialize, IntoParams)] pub struct SearchQuery { #[param(default = "get")] @@ -38,7 +36,7 @@ pub struct SearchQuery { )] #[instrument(skip(state, query))] pub async fn search( - State(state): State>, + State(state): State>, query: Query, ) -> Result, StatusCode> { match state