refactor: extract BoxCodeSearch as interface to CodeSearch (#756)

extract-routes
Meng Zhang 2023-11-10 14:55:51 -08:00 committed by GitHub
parent aa61f0549f
commit 4068d6e81d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 25 additions and 18 deletions

View File

@ -54,3 +54,5 @@ pub trait CodeSearch: Send + Sync {
offset: usize,
) -> Result<SearchResponse, CodeSearchError>;
}
pub type BoxCodeSearch = Box<dyn CodeSearch>;

View File

@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration};
use anyhow::Result;
use async_trait::async_trait;
use tabby_common::{
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
api::code::{BoxCodeSearch, CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
index::{self, register_tokenizers, CodeSearchSchema},
path,
};
@ -118,7 +118,7 @@ fn get_field(doc: &Document, field: Field) -> String {
.to_owned()
}
pub struct CodeSearchService {
struct CodeSearchService {
search: Arc<Mutex<Option<CodeSearchImpl>>>,
}
@ -139,6 +139,10 @@ impl CodeSearchService {
}
}
pub fn create_code_search() -> BoxCodeSearch {
Box::new(CodeSearchService::new())
}
#[async_trait]
impl CodeSearch for CodeSearchService {
async fn search(

View File

@ -5,13 +5,11 @@ use std::sync::Arc;
use axum::{extract::State, Json};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::{events, languages::get_language};
use tabby_common::{api::code::BoxCodeSearch, events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;
use crate::search::CodeSearchService;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({
"language": "python",
@ -211,7 +209,7 @@ pub struct CompletionState {
impl CompletionState {
pub fn new(
engine: Arc<Box<dyn TextGeneration>>,
code: Arc<CodeSearchService>,
code: Arc<BoxCodeSearch>,
prompt_template: Option<String>,
) -> Self {
Self {

View File

@ -4,7 +4,7 @@ use lazy_static::lazy_static;
use regex::Regex;
use strfmt::strfmt;
use tabby_common::{
api::code::{CodeSearch, CodeSearchError},
api::code::{BoxCodeSearch, CodeSearchError},
index::CodeSearchSchema,
languages::get_language,
};
@ -13,7 +13,6 @@ use textdistance::Algorithm;
use tracing::warn;
use super::{Segments, Snippet};
use crate::search::CodeSearchService;
static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
@ -22,11 +21,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder {
schema: CodeSearchSchema,
prompt_template: Option<String>,
code: Option<Arc<CodeSearchService>>,
code: Option<Arc<BoxCodeSearch>>,
}
impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self {
pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
PromptBuilder {
schema: CodeSearchSchema::new(),
prompt_template,
@ -44,7 +43,13 @@ impl PromptBuilder {
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
if let Some(code) = &self.code {
collect_snippets(&self.schema, code, language, &segments.prefix).await
collect_snippets(
&self.schema,
code.as_ref(),
language,
&segments.prefix,
)
.await
} else {
vec![]
}
@ -113,7 +118,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
async fn collect_snippets(
schema: &CodeSearchSchema,
code: &CodeSearchService,
code: &BoxCodeSearch,
language: &str,
text: &str,
) -> Vec<Snippet> {

View File

@ -32,7 +32,7 @@ use self::{
engine::{create_engine, EngineInfo},
health::HealthState,
};
use crate::{chat::ChatService, fatal, search::CodeSearchService};
use crate::{chat::ChatService, fatal, search::create_code_search};
#[derive(OpenApi)]
#[openapi(
@ -173,7 +173,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}
async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let code = Arc::new(CodeSearchService::new());
let code = Arc::new(create_code_search());
let completion_state = {
let (
engine,

View File

@ -7,12 +7,10 @@ use axum::{
};
use hyper::StatusCode;
use serde::Deserialize;
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
use tabby_common::api::code::{BoxCodeSearch, CodeSearchError, SearchResponse};
use tracing::{instrument, warn};
use utoipa::IntoParams;
use crate::search::CodeSearchService;
#[derive(Deserialize, IntoParams)]
pub struct SearchQuery {
#[param(default = "get")]
@ -38,7 +36,7 @@ pub struct SearchQuery {
)]
#[instrument(skip(state, query))]
pub async fn search(
State(state): State<Arc<CodeSearchService>>,
State(state): State<Arc<BoxCodeSearch>>,
query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> {
match state