refactor: extract tabby_common::api::code / tabby_common::index::CodeSearchSchema (#743)

* refactor: extract tabby_common::api::code

mark CodeSearch being Send + Sync

* extract CodeSearchSchema
refactor-extract-code
Meng Zhang 2023-11-10 10:11:13 -08:00 committed by GitHub
parent ff03e2a34e
commit b510f61aca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 209 additions and 192 deletions

9
Cargo.lock generated
View File

@ -334,9 +334,9 @@ checksum = "b4eb2cdb97421e01129ccb49169d8279ed21e829929144f4a22a6e54ac549ca1"
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.72" version = "0.1.74"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09" checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -4058,6 +4058,7 @@ dependencies = [
"anyhow", "anyhow",
"assert-json-diff", "assert-json-diff",
"async-stream", "async-stream",
"async-trait",
"axum", "axum",
"axum-streams", "axum-streams",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
@ -4087,7 +4088,6 @@ dependencies = [
"tabby-scheduler", "tabby-scheduler",
"tantivy", "tantivy",
"textdistance", "textdistance",
"thiserror",
"tokio", "tokio",
"tower-http 0.4.0", "tower-http 0.4.0",
"tracing", "tracing",
@ -4104,6 +4104,7 @@ name = "tabby-common"
version = "0.6.0-dev" version = "0.6.0-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"chrono", "chrono",
"filenamify", "filenamify",
"lazy_static", "lazy_static",
@ -4112,7 +4113,9 @@ dependencies = [
"serde-jsonlines", "serde-jsonlines",
"serdeconv", "serdeconv",
"tantivy", "tantivy",
"thiserror",
"tokio", "tokio",
"utoipa",
"uuid 1.4.1", "uuid 1.4.1",
] ]

View File

@ -34,3 +34,4 @@ futures = "0.3.28"
async-stream = "0.3.5" async-stream = "0.3.5"
regex = "1.10.0" regex = "1.10.0"
thiserror = "1.0.49" thiserror = "1.0.49"
utoipa = "3.3"

View File

@ -15,6 +15,9 @@ tokio = { workspace = true, features = ["rt", "macros"] }
uuid = { version = "1.4.1", features = ["v4"] } uuid = { version = "1.4.1", features = ["v4"] }
tantivy.workspace = true tantivy.workspace = true
anyhow.workspace = true anyhow.workspace = true
async-trait.workspace = true
thiserror.workspace = true
utoipa = { workspace = true, features = ["axum_extras", "preserve_order"] }
[features] [features]
testutils = [] testutils = []

View File

@ -0,0 +1,56 @@
use async_trait::async_trait;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
/// Response payload for a code search request.
#[derive(Serialize, ToSchema)]
pub struct SearchResponse {
    // Number of hits for the query — presumably the total match count rather
    // than the page size; TODO confirm against the search implementation.
    pub num_hits: usize,
    // The hits returned for this request (bounded by the caller's limit/offset).
    pub hits: Vec<Hit>,
}
/// A single search hit: a relevance score plus the matched document.
#[derive(Serialize, ToSchema)]
pub struct Hit {
    // Relevance score assigned by the search backend for this hit.
    pub score: f32,
    // Stored fields of the matched document.
    pub doc: HitDocument,
    // Identifier of the hit within the index segment (tantivy doc id).
    pub id: u32,
}
/// Stored document fields returned with a search hit.
///
/// Field names mirror the fields of the code-search index schema
/// (git_url / filepath / language / name / kind / body).
// Debug and Clone added: this is a plain all-String data type, and public
// data types should be debuggable and copyable by API consumers.
#[derive(Serialize, ToSchema, Debug, Clone)]
pub struct HitDocument {
    /// Source text of the indexed snippet.
    pub body: String,
    /// Path of the file within its repository.
    pub filepath: String,
    /// URL of the git repository the file belongs to.
    pub git_url: String,
    /// Kind of the indexed item, as recorded by the indexer.
    pub kind: String,
    /// Programming language of the snippet.
    pub language: String,
    // Name of the indexed item — presumably a symbol identifier; confirm
    // against the indexer that populates the `name` field.
    pub name: String,
}
/// Errors that can occur while serving a code search request.
#[derive(Error, Debug)]
pub enum CodeSearchError {
    /// The search index has not been loaded yet (e.g. still being built).
    #[error("index not ready")]
    NotReady,

    /// The query string failed to parse; propagated from tantivy.
    #[error("{0}")]
    QueryParserError(#[from] tantivy::query::QueryParserError),

    /// Any other error raised by the tantivy search engine.
    #[error("{0}")]
    TantivyError(#[from] tantivy::TantivyError),
}
/// Abstraction over a code search backend.
///
/// The `Send + Sync` bound lets implementations be shared across async
/// tasks and threads (e.g. behind an `Arc`).
#[async_trait]
pub trait CodeSearch: Send + Sync {
    /// Search with a free-text query string.
    ///
    /// `limit` and `offset` select the page of hits to return.
    async fn search(
        &self,
        q: &str,
        limit: usize,
        offset: usize,
    ) -> Result<SearchResponse, CodeSearchError>;

    /// Search with an already-constructed tantivy query, paged the same way
    /// as [`CodeSearch::search`].
    async fn search_with_query(
        &self,
        q: &dyn tantivy::query::Query,
        limit: usize,
        offset: usize,
    ) -> Result<SearchResponse, CodeSearchError>;
}

View File

@ -0,0 +1 @@
pub mod code;

View File

@ -1,27 +1,89 @@
use tantivy::{ use tantivy::{
query::{TermQuery, TermSetQuery},
schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING},
tokenizer::{NgramTokenizer, RegexTokenizer, RemoveLongFilter, TextAnalyzer}, tokenizer::{NgramTokenizer, RegexTokenizer, RemoveLongFilter, TextAnalyzer},
Index, Index, Term,
}; };
pub trait IndexExt {
fn register_tokenizer(&self);
}
pub static CODE_TOKENIZER: &str = "code"; pub static CODE_TOKENIZER: &str = "code";
pub static IDENTIFIER_TOKENIZER: &str = "identifier"; pub static IDENTIFIER_TOKENIZER: &str = "identifier";
impl IndexExt for Index { pub fn register_tokenizers(index: &Index) {
fn register_tokenizer(&self) { let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap()) .filter(RemoveLongFilter::limit(128))
.filter(RemoveLongFilter::limit(128)) .build();
.build();
self.tokenizers().register(CODE_TOKENIZER, code_tokenizer); index.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
let identifier_tokenzier = let identifier_tokenzier =
TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build(); TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build();
self.tokenizers() index
.register(IDENTIFIER_TOKENIZER, identifier_tokenzier); .tokenizers()
.register(IDENTIFIER_TOKENIZER, identifier_tokenzier);
}
pub struct CodeSearchSchema {
pub schema: Schema,
pub field_git_url: Field,
pub field_filepath: Field,
pub field_language: Field,
pub field_name: Field,
pub field_kind: Field,
pub field_body: Field,
}
impl CodeSearchSchema {
pub fn new() -> Self {
let mut builder = Schema::builder();
let code_indexing_options = TextFieldIndexing::default()
.set_tokenizer(CODE_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let code_options = TextOptions::default()
.set_indexing_options(code_indexing_options)
.set_stored();
let name_indexing_options = TextFieldIndexing::default()
.set_tokenizer(IDENTIFIER_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let name_options = TextOptions::default()
.set_indexing_options(name_indexing_options)
.set_stored();
let field_git_url = builder.add_text_field("git_url", STRING | STORED);
let field_filepath = builder.add_text_field("filepath", STRING | STORED);
let field_language = builder.add_text_field("language", STRING | STORED);
let field_name = builder.add_text_field("name", name_options);
let field_kind = builder.add_text_field("kind", STRING | STORED);
let field_body = builder.add_text_field("body", code_options);
let schema = builder.build();
Self {
schema,
field_git_url,
field_filepath,
field_language,
field_name,
field_kind,
field_body,
}
}
}
impl CodeSearchSchema {
pub fn language_query(&self, language: &str) -> Box<TermQuery> {
Box::new(TermQuery::new(
Term::from_field_text(self.field_language, language),
IndexRecordOption::WithFreqsAndPositions,
))
}
pub fn body_query(&self, tokens: &[String]) -> Box<TermSetQuery> {
Box::new(TermSetQuery::new(
tokens
.iter()
.map(|x| Term::from_field_text(self.field_body, x)),
))
} }
} }

View File

@ -1,3 +1,4 @@
pub mod api;
pub mod config; pub mod config;
pub mod events; pub mod events;
pub mod index; pub mod index;

View File

@ -3,16 +3,11 @@ use std::fs;
use anyhow::Result; use anyhow::Result;
use tabby_common::{ use tabby_common::{
config::Config, config::Config,
index::{IndexExt, CODE_TOKENIZER, IDENTIFIER_TOKENIZER}, index::{register_tokenizers, CodeSearchSchema},
path::index_dir, path::index_dir,
SourceFile, SourceFile,
}; };
use tantivy::{ use tantivy::{directory::MmapDirectory, doc, Index};
directory::MmapDirectory,
doc,
schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
Index,
};
// Magic numbers // Magic numbers
static MAX_LINE_LENGTH_THRESHOLD: usize = 300; static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
@ -20,35 +15,12 @@ static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
static MAX_BODY_LINES_THRESHOLD: usize = 15; static MAX_BODY_LINES_THRESHOLD: usize = 15;
pub fn index_repositories(_config: &Config) -> Result<()> { pub fn index_repositories(_config: &Config) -> Result<()> {
let mut builder = Schema::builder(); let code = CodeSearchSchema::new();
let code_indexing_options = TextFieldIndexing::default()
.set_tokenizer(CODE_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let code_options = TextOptions::default()
.set_indexing_options(code_indexing_options)
.set_stored();
let name_indexing_options = TextFieldIndexing::default()
.set_tokenizer(IDENTIFIER_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let name_options = TextOptions::default()
.set_indexing_options(name_indexing_options)
.set_stored();
let field_git_url = builder.add_text_field("git_url", STRING | STORED);
let field_filepath = builder.add_text_field("filepath", STRING | STORED);
let field_language = builder.add_text_field("language", STRING | STORED);
let field_name = builder.add_text_field("name", name_options);
let field_kind = builder.add_text_field("kind", STRING | STORED);
let field_body = builder.add_text_field("body", code_options);
let schema = builder.build();
fs::create_dir_all(index_dir())?; fs::create_dir_all(index_dir())?;
let directory = MmapDirectory::open(index_dir())?; let directory = MmapDirectory::open(index_dir())?;
let index = Index::open_or_create(directory, schema)?; let index = Index::open_or_create(directory, code.schema)?;
index.register_tokenizer(); register_tokenizers(&index);
let mut writer = index.writer(10_000_000)?; let mut writer = index.writer(10_000_000)?;
writer.delete_all_documents()?; writer.delete_all_documents()?;
@ -64,12 +36,12 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
for doc in from_source_file(file) { for doc in from_source_file(file) {
writer.add_document(doc!( writer.add_document(doc!(
field_git_url => doc.git_url, code.field_git_url => doc.git_url,
field_filepath => doc.filepath, code.field_filepath => doc.filepath,
field_language => doc.language, code.field_language => doc.language,
field_name => doc.name, code.field_name => doc.name,
field_body => doc.body, code.field_body => doc.body,
field_kind => doc.kind, code.field_kind => doc.kind,
))?; ))?;
} }
} }

View File

@ -14,7 +14,7 @@ tabby-inference = { path = "../tabby-inference" }
axum = "0.6" axum = "0.6"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
tokio = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] }
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] } utoipa = { workspace= true, features = ["axum_extras", "preserve_order"] }
utoipa-swagger-ui = { version = "3.1", features = ["axum"] } utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
serde = { workspace = true } serde = { workspace = true }
serdeconv = { workspace = true } serdeconv = { workspace = true }
@ -42,9 +42,9 @@ axum-streams = { version = "0.9.1", features = ["json"] }
minijinja = { version = "1.0.8", features = ["loader"] } minijinja = { version = "1.0.8", features = ["loader"] }
textdistance = "1.0.2" textdistance = "1.0.2"
regex.workspace = true regex.workspace = true
thiserror.workspace = true
llama-cpp-bindings = { path = "../llama-cpp-bindings" } llama-cpp-bindings = { path = "../llama-cpp-bindings" }
futures.workspace = true futures.workspace = true
async-trait.workspace = true
[dependencies.uuid] [dependencies.uuid]
version = "1.3.3" version = "1.3.3"

View File

@ -1,93 +1,39 @@
use std::{sync::Arc, time::Duration}; use std::{sync::Arc, time::Duration};
use anyhow::Result; use anyhow::Result;
use axum::async_trait; use async_trait::async_trait;
use serde::Serialize; use tabby_common::{
use tabby_common::{index::IndexExt, path}; api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
index::{self, register_tokenizers, CodeSearchSchema},
path,
};
use tantivy::{ use tantivy::{
collector::{Count, TopDocs}, collector::{Count, TopDocs},
query::{QueryParser, TermQuery, TermSetQuery}, query::QueryParser,
schema::{Field, IndexRecordOption}, schema::Field,
DocAddress, Document, Index, IndexReader, Term, DocAddress, Document, Index, IndexReader,
}; };
use thiserror::Error;
use tokio::{sync::Mutex, time::sleep}; use tokio::{sync::Mutex, time::sleep};
use tracing::{debug, log::info}; use tracing::{debug, log::info};
use utoipa::ToSchema;
#[derive(Serialize, ToSchema)]
pub struct SearchResponse {
pub num_hits: usize,
pub hits: Vec<Hit>,
}
#[derive(Serialize, ToSchema)]
pub struct Hit {
pub score: f32,
pub doc: HitDocument,
pub id: u32,
}
#[derive(Serialize, ToSchema)]
pub struct HitDocument {
pub body: String,
pub filepath: String,
pub git_url: String,
pub kind: String,
pub language: String,
pub name: String,
}
#[derive(Error, Debug)]
pub enum CodeSearchError {
#[error("index not ready")]
NotReady,
#[error("{0}")]
QueryParserError(#[from] tantivy::query::QueryParserError),
#[error("{0}")]
TantivyError(#[from] tantivy::TantivyError),
}
#[async_trait]
pub trait CodeSearch {
async fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError>;
async fn search_with_query(
&self,
q: &dyn tantivy::query::Query,
limit: usize,
offset: usize,
) -> Result<SearchResponse, CodeSearchError>;
}
struct CodeSearchImpl { struct CodeSearchImpl {
reader: IndexReader, reader: IndexReader,
query_parser: QueryParser, query_parser: QueryParser,
field_body: Field, schema: CodeSearchSchema,
field_filepath: Field,
field_git_url: Field,
field_kind: Field,
field_language: Field,
field_name: Field,
} }
impl CodeSearchImpl { impl CodeSearchImpl {
fn load() -> Result<Self> { fn load() -> Result<Self> {
let code_schema = index::CodeSearchSchema::new();
let index = Index::open_in_dir(path::index_dir())?; let index = Index::open_in_dir(path::index_dir())?;
index.register_tokenizer(); register_tokenizers(&index);
let schema = index.schema(); let query_parser = QueryParser::new(
let field_body = schema.get_field("body").unwrap(); code_schema.schema.clone(),
let query_parser = vec![code_schema.field_body],
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone()); index.tokenizers().clone(),
);
let reader = index let reader = index
.reader_builder() .reader_builder()
.reload_policy(tantivy::ReloadPolicy::OnCommit) .reload_policy(tantivy::ReloadPolicy::OnCommit)
@ -95,12 +41,7 @@ impl CodeSearchImpl {
Ok(Self { Ok(Self {
reader, reader,
query_parser, query_parser,
field_body, schema: code_schema,
field_filepath: schema.get_field("filepath").unwrap(),
field_git_url: schema.get_field("git_url").unwrap(),
field_kind: schema.get_field("kind").unwrap(),
field_language: schema.get_field("language").unwrap(),
field_name: schema.get_field("name").unwrap(),
}) })
} }
@ -124,12 +65,12 @@ impl CodeSearchImpl {
Hit { Hit {
score, score,
doc: HitDocument { doc: HitDocument {
body: get_field(&doc, self.field_body), body: get_field(&doc, self.schema.field_body),
filepath: get_field(&doc, self.field_filepath), filepath: get_field(&doc, self.schema.field_filepath),
git_url: get_field(&doc, self.field_git_url), git_url: get_field(&doc, self.schema.field_git_url),
kind: get_field(&doc, self.field_kind), kind: get_field(&doc, self.schema.field_kind),
name: get_field(&doc, self.field_name), name: get_field(&doc, self.schema.field_name),
language: get_field(&doc, self.field_language), language: get_field(&doc, self.schema.field_language),
}, },
id: doc_address.doc_id, id: doc_address.doc_id,
} }
@ -196,41 +137,6 @@ impl CodeSearchService {
ret ret
} }
async fn with_impl<T, F>(&self, op: F) -> Result<T, CodeSearchError>
where
F: FnOnce(&CodeSearchImpl) -> Result<T, CodeSearchError>,
{
if let Some(imp) = self.search.lock().await.as_ref() {
op(imp)
} else {
Err(CodeSearchError::NotReady)
}
}
pub async fn language_query(&self, language: &str) -> Result<Box<TermQuery>, CodeSearchError> {
self.with_impl(|imp| {
Ok(Box::new(TermQuery::new(
Term::from_field_text(imp.field_language, language),
IndexRecordOption::WithFreqsAndPositions,
)))
})
.await
}
pub async fn body_query(
&self,
tokens: &[String],
) -> Result<Box<TermSetQuery>, CodeSearchError> {
self.with_impl(|imp| {
Ok(Box::new(TermSetQuery::new(
tokens
.iter()
.map(|x| Term::from_field_text(imp.field_body, x)),
)))
})
.await
}
} }
#[async_trait] #[async_trait]

View File

@ -3,19 +3,24 @@ use std::sync::Arc;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use strfmt::strfmt; use strfmt::strfmt;
use tabby_common::languages::get_language; use tabby_common::{
api::code::{CodeSearch, CodeSearchError},
index::CodeSearchSchema,
languages::get_language,
};
use tantivy::{query::BooleanQuery, query_grammar::Occur}; use tantivy::{query::BooleanQuery, query_grammar::Occur};
use textdistance::Algorithm; use textdistance::Algorithm;
use tracing::warn; use tracing::warn;
use super::{Segments, Snippet}; use super::{Segments, Snippet};
use crate::search::{CodeSearch, CodeSearchError, CodeSearchService}; use crate::search::CodeSearchService;
static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
pub struct PromptBuilder { pub struct PromptBuilder {
schema: CodeSearchSchema,
prompt_template: Option<String>, prompt_template: Option<String>,
code: Option<Arc<CodeSearchService>>, code: Option<Arc<CodeSearchService>>,
} }
@ -23,6 +28,7 @@ pub struct PromptBuilder {
impl PromptBuilder { impl PromptBuilder {
pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self { pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self {
PromptBuilder { PromptBuilder {
schema: CodeSearchSchema::new(),
prompt_template, prompt_template,
code, code,
} }
@ -38,7 +44,7 @@ impl PromptBuilder {
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> { pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
if let Some(code) = &self.code { if let Some(code) = &self.code {
collect_snippets(code, language, &segments.prefix).await collect_snippets(&self.schema, code, language, &segments.prefix).await
} else { } else {
vec![] vec![]
} }
@ -105,16 +111,17 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
format!("{}\n{}", comments, prefix) format!("{}\n{}", comments, prefix)
} }
async fn collect_snippets(code: &CodeSearchService, language: &str, text: &str) -> Vec<Snippet> { async fn collect_snippets(
schema: &CodeSearchSchema,
code: &CodeSearchService,
language: &str,
text: &str,
) -> Vec<Snippet> {
let mut ret = Vec::new(); let mut ret = Vec::new();
let mut tokens = tokenize_text(text); let mut tokens = tokenize_text(text);
let Ok(language_query) = code.language_query(language).await else { let language_query = schema.language_query(language);
return vec![]; let body_query = schema.body_query(&tokens);
};
let Ok(body_query) = code.body_query(&tokens).await else {
return vec![];
};
let query = BooleanQuery::new(vec![ let query = BooleanQuery::new(vec![
(Occur::Must, language_query), (Occur::Must, language_query),
(Occur::Must, body_query), (Occur::Must, body_query),

View File

@ -16,7 +16,11 @@ use std::{
use axum::{routing, Router, Server}; use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args; use clap::Args;
use tabby_common::{config::Config, usage}; use tabby_common::{
api::code::{Hit, HitDocument, SearchResponse},
config::Config,
usage,
};
use tabby_download::download_model; use tabby_download::download_model;
use tokio::time::sleep; use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
@ -62,9 +66,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
crate::chat::ChatCompletionChunk, crate::chat::ChatCompletionChunk,
health::HealthState, health::HealthState,
health::Version, health::Version,
crate::search::SearchResponse, SearchResponse,
crate::search::Hit, Hit,
crate::search::HitDocument HitDocument
)) ))
)] )]
struct ApiDoc; struct ApiDoc;

View File

@ -7,10 +7,11 @@ use axum::{
}; };
use hyper::StatusCode; use hyper::StatusCode;
use serde::Deserialize; use serde::Deserialize;
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
use tracing::{instrument, warn}; use tracing::{instrument, warn};
use utoipa::IntoParams; use utoipa::IntoParams;
use crate::search::{CodeSearch, CodeSearchError, CodeSearchService, SearchResponse}; use crate::search::CodeSearchService;
#[derive(Deserialize, IntoParams)] #[derive(Deserialize, IntoParams)]
pub struct SearchQuery { pub struct SearchQuery {