refactor: extract tabby-common/src/code

refactor-extract-code
Meng Zhang 2023-11-10 13:23:01 -08:00
parent aa61f0549f
commit f5508554e2
12 changed files with 39 additions and 27 deletions

1
Cargo.lock generated
View File

@ -4125,6 +4125,7 @@ dependencies = [
"tantivy", "tantivy",
"thiserror", "thiserror",
"tokio", "tokio",
"tracing",
"utoipa", "utoipa",
"uuid 1.4.1", "uuid 1.4.1",
] ]

View File

@ -18,6 +18,7 @@ anyhow.workspace = true
async-trait.workspace = true async-trait.workspace = true
thiserror.workspace = true thiserror.workspace = true
utoipa = { workspace = true, features = ["axum_extras", "preserve_order"] } utoipa = { workspace = true, features = ["axum_extras", "preserve_order"] }
tracing.workspace = true
[features] [features]
testutils = [] testutils = []

View File

@ -1 +0,0 @@
pub mod code;

View File

@ -54,3 +54,5 @@ pub trait CodeSearch: Send + Sync {
offset: usize, offset: usize,
) -> Result<SearchResponse, CodeSearchError>; ) -> Result<SearchResponse, CodeSearchError>;
} }
pub type BoxCodeSearch = Box<dyn CodeSearch>;

View File

@ -2,11 +2,6 @@ use std::{sync::Arc, time::Duration};
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use tabby_common::{
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
index::{self, register_tokenizers, CodeSearchSchema},
path,
};
use tantivy::{ use tantivy::{
collector::{Count, TopDocs}, collector::{Count, TopDocs},
query::QueryParser, query::QueryParser,
@ -16,6 +11,12 @@ use tantivy::{
use tokio::{sync::Mutex, time::sleep}; use tokio::{sync::Mutex, time::sleep};
use tracing::{debug, log::info}; use tracing::{debug, log::info};
use super::api::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse};
use crate::{
index::{self, register_tokenizers, CodeSearchSchema},
path,
};
struct CodeSearchImpl { struct CodeSearchImpl {
reader: IndexReader, reader: IndexReader,
query_parser: QueryParser, query_parser: QueryParser,
@ -118,12 +119,12 @@ fn get_field(doc: &Document, field: Field) -> String {
.to_owned() .to_owned()
} }
pub struct CodeSearchService { pub(crate) struct CodeSearchService {
search: Arc<Mutex<Option<CodeSearchImpl>>>, search: Arc<Mutex<Option<CodeSearchImpl>>>,
} }
impl CodeSearchService { impl CodeSearchService {
pub fn new() -> Self { pub(crate) fn new() -> Self {
let search = Arc::new(Mutex::new(None)); let search = Arc::new(Mutex::new(None));
let ret = Self { let ret = Self {

View File

@ -0,0 +1,8 @@
mod api;
mod imp;
pub use api::*;
pub fn create_local() -> BoxCodeSearch {
Box::new(imp::CodeSearchService::new())
}

View File

@ -1,4 +1,4 @@
pub mod api; pub mod code;
pub mod config; pub mod config;
pub mod events; pub mod events;
pub mod index; pub mod index;

View File

@ -1,6 +1,5 @@
mod chat; mod chat;
mod download; mod download;
mod search;
mod serve; mod serve;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};

View File

@ -5,13 +5,11 @@ use std::sync::Arc;
use axum::{extract::State, Json}; use axum::{extract::State, Json};
use hyper::StatusCode; use hyper::StatusCode;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_common::{events, languages::get_language}; use tabby_common::{code::BoxCodeSearch, events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::search::CodeSearchService;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
#[schema(example=json!({ #[schema(example=json!({
"language": "python", "language": "python",
@ -211,7 +209,7 @@ pub struct CompletionState {
impl CompletionState { impl CompletionState {
pub fn new( pub fn new(
engine: Arc<Box<dyn TextGeneration>>, engine: Arc<Box<dyn TextGeneration>>,
code: Arc<CodeSearchService>, code: Arc<BoxCodeSearch>,
prompt_template: Option<String>, prompt_template: Option<String>,
) -> Self { ) -> Self {
Self { Self {

View File

@ -4,7 +4,7 @@ use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use strfmt::strfmt; use strfmt::strfmt;
use tabby_common::{ use tabby_common::{
api::code::{CodeSearch, CodeSearchError}, code::{BoxCodeSearch, CodeSearch, CodeSearchError},
index::CodeSearchSchema, index::CodeSearchSchema,
languages::get_language, languages::get_language,
}; };
@ -13,7 +13,6 @@ use textdistance::Algorithm;
use tracing::warn; use tracing::warn;
use super::{Segments, Snippet}; use super::{Segments, Snippet};
use crate::search::CodeSearchService;
static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
@ -22,11 +21,11 @@ static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder { pub struct PromptBuilder {
schema: CodeSearchSchema, schema: CodeSearchSchema,
prompt_template: Option<String>, prompt_template: Option<String>,
code: Option<Arc<CodeSearchService>>, code: Option<Arc<BoxCodeSearch>>,
} }
impl PromptBuilder { impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self { pub fn new(prompt_template: Option<String>, code: Option<Arc<BoxCodeSearch>>) -> Self {
PromptBuilder { PromptBuilder {
schema: CodeSearchSchema::new(), schema: CodeSearchSchema::new(),
prompt_template, prompt_template,
@ -44,7 +43,13 @@ impl PromptBuilder {
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> { pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
if let Some(code) = &self.code { if let Some(code) = &self.code {
collect_snippets(&self.schema, code, language, &segments.prefix).await collect_snippets(
&self.schema,
code.as_ref().as_ref(),
language,
&segments.prefix,
)
.await
} else { } else {
vec![] vec![]
} }
@ -113,7 +118,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
async fn collect_snippets( async fn collect_snippets(
schema: &CodeSearchSchema, schema: &CodeSearchSchema,
code: &CodeSearchService, code: &dyn CodeSearch,
language: &str, language: &str,
text: &str, text: &str,
) -> Vec<Snippet> { ) -> Vec<Snippet> {

View File

@ -17,7 +17,7 @@ use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args; use clap::Args;
use tabby_common::{ use tabby_common::{
api::code::{Hit, HitDocument, SearchResponse}, code::{create_local, Hit, HitDocument, SearchResponse},
config::Config, config::Config,
usage, usage,
}; };
@ -32,7 +32,7 @@ use self::{
engine::{create_engine, EngineInfo}, engine::{create_engine, EngineInfo},
health::HealthState, health::HealthState,
}; };
use crate::{chat::ChatService, fatal, search::CodeSearchService}; use crate::{chat::ChatService, fatal};
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
@ -173,7 +173,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
} }
async fn api_router(args: &ServeArgs, config: &Config) -> Router { async fn api_router(args: &ServeArgs, config: &Config) -> Router {
let code = Arc::new(CodeSearchService::new()); let code = Arc::new(create_local());
let completion_state = { let completion_state = {
let ( let (
engine, engine,

View File

@ -7,12 +7,10 @@ use axum::{
}; };
use hyper::StatusCode; use hyper::StatusCode;
use serde::Deserialize; use serde::Deserialize;
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse}; use tabby_common::code::{BoxCodeSearch, CodeSearchError, SearchResponse};
use tracing::{instrument, warn}; use tracing::{instrument, warn};
use utoipa::IntoParams; use utoipa::IntoParams;
use crate::search::CodeSearchService;
#[derive(Deserialize, IntoParams)] #[derive(Deserialize, IntoParams)]
pub struct SearchQuery { pub struct SearchQuery {
#[param(default = "get")] #[param(default = "get")]
@ -38,7 +36,7 @@ pub struct SearchQuery {
)] )]
#[instrument(skip(state, query))] #[instrument(skip(state, query))]
pub async fn search( pub async fn search(
State(state): State<Arc<CodeSearchService>>, State(state): State<Arc<BoxCodeSearch>>,
query: Query<SearchQuery>, query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> { ) -> Result<Json<SearchResponse>, StatusCode> {
match state match state