refactor: extract tabby_common::api::code / tabby_common::index::CodeSearchSchema (#743)
* refactor: extract tabby_common::api::code mark CodeSearch being Send + Sync * extract CodeSearchSchemarefactor-extract-code
parent
ff03e2a34e
commit
b510f61aca
|
|
@ -334,9 +334,9 @@ checksum = "b4eb2cdb97421e01129ccb49169d8279ed21e829929144f4a22a6e54ac549ca1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-trait"
|
name = "async-trait"
|
||||||
version = "0.1.72"
|
version = "0.1.74"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09"
|
checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
|
@ -4058,6 +4058,7 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"assert-json-diff",
|
"assert-json-diff",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"axum-streams",
|
"axum-streams",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
|
|
@ -4087,7 +4088,6 @@ dependencies = [
|
||||||
"tabby-scheduler",
|
"tabby-scheduler",
|
||||||
"tantivy",
|
"tantivy",
|
||||||
"textdistance",
|
"textdistance",
|
||||||
"thiserror",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower-http 0.4.0",
|
"tower-http 0.4.0",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
@ -4104,6 +4104,7 @@ name = "tabby-common"
|
||||||
version = "0.6.0-dev"
|
version = "0.6.0-dev"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"chrono",
|
"chrono",
|
||||||
"filenamify",
|
"filenamify",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
|
|
@ -4112,7 +4113,9 @@ dependencies = [
|
||||||
"serde-jsonlines",
|
"serde-jsonlines",
|
||||||
"serdeconv",
|
"serdeconv",
|
||||||
"tantivy",
|
"tantivy",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"utoipa",
|
||||||
"uuid 1.4.1",
|
"uuid 1.4.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,3 +34,4 @@ futures = "0.3.28"
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3.5"
|
||||||
regex = "1.10.0"
|
regex = "1.10.0"
|
||||||
thiserror = "1.0.49"
|
thiserror = "1.0.49"
|
||||||
|
utoipa = "3.3"
|
||||||
|
|
@ -15,6 +15,9 @@ tokio = { workspace = true, features = ["rt", "macros"] }
|
||||||
uuid = { version = "1.4.1", features = ["v4"] }
|
uuid = { version = "1.4.1", features = ["v4"] }
|
||||||
tantivy.workspace = true
|
tantivy.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
utoipa = { workspace = true, features = ["axum_extras", "preserve_order"] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
testutils = []
|
testutils = []
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::Serialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct SearchResponse {
|
||||||
|
pub num_hits: usize,
|
||||||
|
pub hits: Vec<Hit>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct Hit {
|
||||||
|
pub score: f32,
|
||||||
|
pub doc: HitDocument,
|
||||||
|
pub id: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
pub struct HitDocument {
|
||||||
|
pub body: String,
|
||||||
|
pub filepath: String,
|
||||||
|
pub git_url: String,
|
||||||
|
pub kind: String,
|
||||||
|
pub language: String,
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum CodeSearchError {
|
||||||
|
#[error("index not ready")]
|
||||||
|
NotReady,
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
QueryParserError(#[from] tantivy::query::QueryParserError),
|
||||||
|
|
||||||
|
#[error("{0}")]
|
||||||
|
TantivyError(#[from] tantivy::TantivyError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait CodeSearch: Send + Sync {
|
||||||
|
async fn search(
|
||||||
|
&self,
|
||||||
|
q: &str,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError>;
|
||||||
|
|
||||||
|
async fn search_with_query(
|
||||||
|
&self,
|
||||||
|
q: &dyn tantivy::query::Query,
|
||||||
|
limit: usize,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<SearchResponse, CodeSearchError>;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
pub mod code;
|
||||||
|
|
@ -1,27 +1,89 @@
|
||||||
use tantivy::{
|
use tantivy::{
|
||||||
|
query::{TermQuery, TermSetQuery},
|
||||||
|
schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING},
|
||||||
tokenizer::{NgramTokenizer, RegexTokenizer, RemoveLongFilter, TextAnalyzer},
|
tokenizer::{NgramTokenizer, RegexTokenizer, RemoveLongFilter, TextAnalyzer},
|
||||||
Index,
|
Index, Term,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub trait IndexExt {
|
|
||||||
fn register_tokenizer(&self);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub static CODE_TOKENIZER: &str = "code";
|
pub static CODE_TOKENIZER: &str = "code";
|
||||||
pub static IDENTIFIER_TOKENIZER: &str = "identifier";
|
pub static IDENTIFIER_TOKENIZER: &str = "identifier";
|
||||||
|
|
||||||
impl IndexExt for Index {
|
pub fn register_tokenizers(index: &Index) {
|
||||||
fn register_tokenizer(&self) {
|
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
|
||||||
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
|
.filter(RemoveLongFilter::limit(128))
|
||||||
.filter(RemoveLongFilter::limit(128))
|
.build();
|
||||||
.build();
|
|
||||||
|
|
||||||
self.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
|
index.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
|
||||||
|
|
||||||
let identifier_tokenzier =
|
let identifier_tokenzier =
|
||||||
TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build();
|
TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build();
|
||||||
|
|
||||||
self.tokenizers()
|
index
|
||||||
.register(IDENTIFIER_TOKENIZER, identifier_tokenzier);
|
.tokenizers()
|
||||||
|
.register(IDENTIFIER_TOKENIZER, identifier_tokenzier);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CodeSearchSchema {
|
||||||
|
pub schema: Schema,
|
||||||
|
pub field_git_url: Field,
|
||||||
|
pub field_filepath: Field,
|
||||||
|
pub field_language: Field,
|
||||||
|
pub field_name: Field,
|
||||||
|
pub field_kind: Field,
|
||||||
|
pub field_body: Field,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CodeSearchSchema {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let mut builder = Schema::builder();
|
||||||
|
|
||||||
|
let code_indexing_options = TextFieldIndexing::default()
|
||||||
|
.set_tokenizer(CODE_TOKENIZER)
|
||||||
|
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
|
||||||
|
let code_options = TextOptions::default()
|
||||||
|
.set_indexing_options(code_indexing_options)
|
||||||
|
.set_stored();
|
||||||
|
|
||||||
|
let name_indexing_options = TextFieldIndexing::default()
|
||||||
|
.set_tokenizer(IDENTIFIER_TOKENIZER)
|
||||||
|
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
|
||||||
|
let name_options = TextOptions::default()
|
||||||
|
.set_indexing_options(name_indexing_options)
|
||||||
|
.set_stored();
|
||||||
|
|
||||||
|
let field_git_url = builder.add_text_field("git_url", STRING | STORED);
|
||||||
|
let field_filepath = builder.add_text_field("filepath", STRING | STORED);
|
||||||
|
let field_language = builder.add_text_field("language", STRING | STORED);
|
||||||
|
let field_name = builder.add_text_field("name", name_options);
|
||||||
|
let field_kind = builder.add_text_field("kind", STRING | STORED);
|
||||||
|
let field_body = builder.add_text_field("body", code_options);
|
||||||
|
let schema = builder.build();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
schema,
|
||||||
|
field_git_url,
|
||||||
|
field_filepath,
|
||||||
|
field_language,
|
||||||
|
field_name,
|
||||||
|
field_kind,
|
||||||
|
field_body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CodeSearchSchema {
|
||||||
|
pub fn language_query(&self, language: &str) -> Box<TermQuery> {
|
||||||
|
Box::new(TermQuery::new(
|
||||||
|
Term::from_field_text(self.field_language, language),
|
||||||
|
IndexRecordOption::WithFreqsAndPositions,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn body_query(&self, tokens: &[String]) -> Box<TermSetQuery> {
|
||||||
|
Box::new(TermSetQuery::new(
|
||||||
|
tokens
|
||||||
|
.iter()
|
||||||
|
.map(|x| Term::from_field_text(self.field_body, x)),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod api;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod events;
|
pub mod events;
|
||||||
pub mod index;
|
pub mod index;
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,11 @@ use std::fs;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tabby_common::{
|
use tabby_common::{
|
||||||
config::Config,
|
config::Config,
|
||||||
index::{IndexExt, CODE_TOKENIZER, IDENTIFIER_TOKENIZER},
|
index::{register_tokenizers, CodeSearchSchema},
|
||||||
path::index_dir,
|
path::index_dir,
|
||||||
SourceFile,
|
SourceFile,
|
||||||
};
|
};
|
||||||
use tantivy::{
|
use tantivy::{directory::MmapDirectory, doc, Index};
|
||||||
directory::MmapDirectory,
|
|
||||||
doc,
|
|
||||||
schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
|
|
||||||
Index,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Magic numbers
|
// Magic numbers
|
||||||
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
|
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
|
||||||
|
|
@ -20,35 +15,12 @@ static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
|
||||||
static MAX_BODY_LINES_THRESHOLD: usize = 15;
|
static MAX_BODY_LINES_THRESHOLD: usize = 15;
|
||||||
|
|
||||||
pub fn index_repositories(_config: &Config) -> Result<()> {
|
pub fn index_repositories(_config: &Config) -> Result<()> {
|
||||||
let mut builder = Schema::builder();
|
let code = CodeSearchSchema::new();
|
||||||
|
|
||||||
let code_indexing_options = TextFieldIndexing::default()
|
|
||||||
.set_tokenizer(CODE_TOKENIZER)
|
|
||||||
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
|
|
||||||
let code_options = TextOptions::default()
|
|
||||||
.set_indexing_options(code_indexing_options)
|
|
||||||
.set_stored();
|
|
||||||
|
|
||||||
let name_indexing_options = TextFieldIndexing::default()
|
|
||||||
.set_tokenizer(IDENTIFIER_TOKENIZER)
|
|
||||||
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
|
|
||||||
let name_options = TextOptions::default()
|
|
||||||
.set_indexing_options(name_indexing_options)
|
|
||||||
.set_stored();
|
|
||||||
|
|
||||||
let field_git_url = builder.add_text_field("git_url", STRING | STORED);
|
|
||||||
let field_filepath = builder.add_text_field("filepath", STRING | STORED);
|
|
||||||
let field_language = builder.add_text_field("language", STRING | STORED);
|
|
||||||
let field_name = builder.add_text_field("name", name_options);
|
|
||||||
let field_kind = builder.add_text_field("kind", STRING | STORED);
|
|
||||||
let field_body = builder.add_text_field("body", code_options);
|
|
||||||
|
|
||||||
let schema = builder.build();
|
|
||||||
|
|
||||||
fs::create_dir_all(index_dir())?;
|
fs::create_dir_all(index_dir())?;
|
||||||
let directory = MmapDirectory::open(index_dir())?;
|
let directory = MmapDirectory::open(index_dir())?;
|
||||||
let index = Index::open_or_create(directory, schema)?;
|
let index = Index::open_or_create(directory, code.schema)?;
|
||||||
index.register_tokenizer();
|
register_tokenizers(&index);
|
||||||
|
|
||||||
let mut writer = index.writer(10_000_000)?;
|
let mut writer = index.writer(10_000_000)?;
|
||||||
writer.delete_all_documents()?;
|
writer.delete_all_documents()?;
|
||||||
|
|
@ -64,12 +36,12 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
|
||||||
|
|
||||||
for doc in from_source_file(file) {
|
for doc in from_source_file(file) {
|
||||||
writer.add_document(doc!(
|
writer.add_document(doc!(
|
||||||
field_git_url => doc.git_url,
|
code.field_git_url => doc.git_url,
|
||||||
field_filepath => doc.filepath,
|
code.field_filepath => doc.filepath,
|
||||||
field_language => doc.language,
|
code.field_language => doc.language,
|
||||||
field_name => doc.name,
|
code.field_name => doc.name,
|
||||||
field_body => doc.body,
|
code.field_body => doc.body,
|
||||||
field_kind => doc.kind,
|
code.field_kind => doc.kind,
|
||||||
))?;
|
))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ tabby-inference = { path = "../tabby-inference" }
|
||||||
axum = "0.6"
|
axum = "0.6"
|
||||||
hyper = { version = "0.14", features = ["full"] }
|
hyper = { version = "0.14", features = ["full"] }
|
||||||
tokio = { workspace = true, features = ["full"] }
|
tokio = { workspace = true, features = ["full"] }
|
||||||
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
utoipa = { workspace= true, features = ["axum_extras", "preserve_order"] }
|
||||||
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serdeconv = { workspace = true }
|
serdeconv = { workspace = true }
|
||||||
|
|
@ -42,9 +42,9 @@ axum-streams = { version = "0.9.1", features = ["json"] }
|
||||||
minijinja = { version = "1.0.8", features = ["loader"] }
|
minijinja = { version = "1.0.8", features = ["loader"] }
|
||||||
textdistance = "1.0.2"
|
textdistance = "1.0.2"
|
||||||
regex.workspace = true
|
regex.workspace = true
|
||||||
thiserror.workspace = true
|
|
||||||
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
version = "1.3.3"
|
version = "1.3.3"
|
||||||
|
|
|
||||||
|
|
@ -1,93 +1,39 @@
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::Serialize;
|
use tabby_common::{
|
||||||
use tabby_common::{index::IndexExt, path};
|
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
|
||||||
|
index::{self, register_tokenizers, CodeSearchSchema},
|
||||||
|
path,
|
||||||
|
};
|
||||||
use tantivy::{
|
use tantivy::{
|
||||||
collector::{Count, TopDocs},
|
collector::{Count, TopDocs},
|
||||||
query::{QueryParser, TermQuery, TermSetQuery},
|
query::QueryParser,
|
||||||
schema::{Field, IndexRecordOption},
|
schema::Field,
|
||||||
DocAddress, Document, Index, IndexReader, Term,
|
DocAddress, Document, Index, IndexReader,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
|
||||||
use tokio::{sync::Mutex, time::sleep};
|
use tokio::{sync::Mutex, time::sleep};
|
||||||
use tracing::{debug, log::info};
|
use tracing::{debug, log::info};
|
||||||
use utoipa::ToSchema;
|
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub struct SearchResponse {
|
|
||||||
pub num_hits: usize,
|
|
||||||
pub hits: Vec<Hit>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub struct Hit {
|
|
||||||
pub score: f32,
|
|
||||||
pub doc: HitDocument,
|
|
||||||
pub id: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, ToSchema)]
|
|
||||||
pub struct HitDocument {
|
|
||||||
pub body: String,
|
|
||||||
pub filepath: String,
|
|
||||||
pub git_url: String,
|
|
||||||
pub kind: String,
|
|
||||||
pub language: String,
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
|
||||||
pub enum CodeSearchError {
|
|
||||||
#[error("index not ready")]
|
|
||||||
NotReady,
|
|
||||||
|
|
||||||
#[error("{0}")]
|
|
||||||
QueryParserError(#[from] tantivy::query::QueryParserError),
|
|
||||||
|
|
||||||
#[error("{0}")]
|
|
||||||
TantivyError(#[from] tantivy::TantivyError),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait CodeSearch {
|
|
||||||
async fn search(
|
|
||||||
&self,
|
|
||||||
q: &str,
|
|
||||||
limit: usize,
|
|
||||||
offset: usize,
|
|
||||||
) -> Result<SearchResponse, CodeSearchError>;
|
|
||||||
|
|
||||||
async fn search_with_query(
|
|
||||||
&self,
|
|
||||||
q: &dyn tantivy::query::Query,
|
|
||||||
limit: usize,
|
|
||||||
offset: usize,
|
|
||||||
) -> Result<SearchResponse, CodeSearchError>;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CodeSearchImpl {
|
struct CodeSearchImpl {
|
||||||
reader: IndexReader,
|
reader: IndexReader,
|
||||||
query_parser: QueryParser,
|
query_parser: QueryParser,
|
||||||
|
|
||||||
field_body: Field,
|
schema: CodeSearchSchema,
|
||||||
field_filepath: Field,
|
|
||||||
field_git_url: Field,
|
|
||||||
field_kind: Field,
|
|
||||||
field_language: Field,
|
|
||||||
field_name: Field,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CodeSearchImpl {
|
impl CodeSearchImpl {
|
||||||
fn load() -> Result<Self> {
|
fn load() -> Result<Self> {
|
||||||
|
let code_schema = index::CodeSearchSchema::new();
|
||||||
let index = Index::open_in_dir(path::index_dir())?;
|
let index = Index::open_in_dir(path::index_dir())?;
|
||||||
index.register_tokenizer();
|
register_tokenizers(&index);
|
||||||
|
|
||||||
let schema = index.schema();
|
let query_parser = QueryParser::new(
|
||||||
let field_body = schema.get_field("body").unwrap();
|
code_schema.schema.clone(),
|
||||||
let query_parser =
|
vec![code_schema.field_body],
|
||||||
QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone());
|
index.tokenizers().clone(),
|
||||||
|
);
|
||||||
let reader = index
|
let reader = index
|
||||||
.reader_builder()
|
.reader_builder()
|
||||||
.reload_policy(tantivy::ReloadPolicy::OnCommit)
|
.reload_policy(tantivy::ReloadPolicy::OnCommit)
|
||||||
|
|
@ -95,12 +41,7 @@ impl CodeSearchImpl {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
reader,
|
reader,
|
||||||
query_parser,
|
query_parser,
|
||||||
field_body,
|
schema: code_schema,
|
||||||
field_filepath: schema.get_field("filepath").unwrap(),
|
|
||||||
field_git_url: schema.get_field("git_url").unwrap(),
|
|
||||||
field_kind: schema.get_field("kind").unwrap(),
|
|
||||||
field_language: schema.get_field("language").unwrap(),
|
|
||||||
field_name: schema.get_field("name").unwrap(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -124,12 +65,12 @@ impl CodeSearchImpl {
|
||||||
Hit {
|
Hit {
|
||||||
score,
|
score,
|
||||||
doc: HitDocument {
|
doc: HitDocument {
|
||||||
body: get_field(&doc, self.field_body),
|
body: get_field(&doc, self.schema.field_body),
|
||||||
filepath: get_field(&doc, self.field_filepath),
|
filepath: get_field(&doc, self.schema.field_filepath),
|
||||||
git_url: get_field(&doc, self.field_git_url),
|
git_url: get_field(&doc, self.schema.field_git_url),
|
||||||
kind: get_field(&doc, self.field_kind),
|
kind: get_field(&doc, self.schema.field_kind),
|
||||||
name: get_field(&doc, self.field_name),
|
name: get_field(&doc, self.schema.field_name),
|
||||||
language: get_field(&doc, self.field_language),
|
language: get_field(&doc, self.schema.field_language),
|
||||||
},
|
},
|
||||||
id: doc_address.doc_id,
|
id: doc_address.doc_id,
|
||||||
}
|
}
|
||||||
|
|
@ -196,41 +137,6 @@ impl CodeSearchService {
|
||||||
|
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn with_impl<T, F>(&self, op: F) -> Result<T, CodeSearchError>
|
|
||||||
where
|
|
||||||
F: FnOnce(&CodeSearchImpl) -> Result<T, CodeSearchError>,
|
|
||||||
{
|
|
||||||
if let Some(imp) = self.search.lock().await.as_ref() {
|
|
||||||
op(imp)
|
|
||||||
} else {
|
|
||||||
Err(CodeSearchError::NotReady)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn language_query(&self, language: &str) -> Result<Box<TermQuery>, CodeSearchError> {
|
|
||||||
self.with_impl(|imp| {
|
|
||||||
Ok(Box::new(TermQuery::new(
|
|
||||||
Term::from_field_text(imp.field_language, language),
|
|
||||||
IndexRecordOption::WithFreqsAndPositions,
|
|
||||||
)))
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn body_query(
|
|
||||||
&self,
|
|
||||||
tokens: &[String],
|
|
||||||
) -> Result<Box<TermSetQuery>, CodeSearchError> {
|
|
||||||
self.with_impl(|imp| {
|
|
||||||
Ok(Box::new(TermSetQuery::new(
|
|
||||||
tokens
|
|
||||||
.iter()
|
|
||||||
.map(|x| Term::from_field_text(imp.field_body, x)),
|
|
||||||
)))
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
|
||||||
|
|
@ -3,19 +3,24 @@ use std::sync::Arc;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use strfmt::strfmt;
|
use strfmt::strfmt;
|
||||||
use tabby_common::languages::get_language;
|
use tabby_common::{
|
||||||
|
api::code::{CodeSearch, CodeSearchError},
|
||||||
|
index::CodeSearchSchema,
|
||||||
|
languages::get_language,
|
||||||
|
};
|
||||||
use tantivy::{query::BooleanQuery, query_grammar::Occur};
|
use tantivy::{query::BooleanQuery, query_grammar::Occur};
|
||||||
use textdistance::Algorithm;
|
use textdistance::Algorithm;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
use super::{Segments, Snippet};
|
use super::{Segments, Snippet};
|
||||||
use crate::search::{CodeSearch, CodeSearchError, CodeSearchService};
|
use crate::search::CodeSearchService;
|
||||||
|
|
||||||
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
static MAX_SNIPPETS_TO_FETCH: usize = 20;
|
||||||
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;
|
||||||
static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
|
static MAX_SIMILARITY_THRESHOLD: f32 = 0.9;
|
||||||
|
|
||||||
pub struct PromptBuilder {
|
pub struct PromptBuilder {
|
||||||
|
schema: CodeSearchSchema,
|
||||||
prompt_template: Option<String>,
|
prompt_template: Option<String>,
|
||||||
code: Option<Arc<CodeSearchService>>,
|
code: Option<Arc<CodeSearchService>>,
|
||||||
}
|
}
|
||||||
|
|
@ -23,6 +28,7 @@ pub struct PromptBuilder {
|
||||||
impl PromptBuilder {
|
impl PromptBuilder {
|
||||||
pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self {
|
pub fn new(prompt_template: Option<String>, code: Option<Arc<CodeSearchService>>) -> Self {
|
||||||
PromptBuilder {
|
PromptBuilder {
|
||||||
|
schema: CodeSearchSchema::new(),
|
||||||
prompt_template,
|
prompt_template,
|
||||||
code,
|
code,
|
||||||
}
|
}
|
||||||
|
|
@ -38,7 +44,7 @@ impl PromptBuilder {
|
||||||
|
|
||||||
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
|
pub async fn collect(&self, language: &str, segments: &Segments) -> Vec<Snippet> {
|
||||||
if let Some(code) = &self.code {
|
if let Some(code) = &self.code {
|
||||||
collect_snippets(code, language, &segments.prefix).await
|
collect_snippets(&self.schema, code, language, &segments.prefix).await
|
||||||
} else {
|
} else {
|
||||||
vec![]
|
vec![]
|
||||||
}
|
}
|
||||||
|
|
@ -105,16 +111,17 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String {
|
||||||
format!("{}\n{}", comments, prefix)
|
format!("{}\n{}", comments, prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn collect_snippets(code: &CodeSearchService, language: &str, text: &str) -> Vec<Snippet> {
|
async fn collect_snippets(
|
||||||
|
schema: &CodeSearchSchema,
|
||||||
|
code: &CodeSearchService,
|
||||||
|
language: &str,
|
||||||
|
text: &str,
|
||||||
|
) -> Vec<Snippet> {
|
||||||
let mut ret = Vec::new();
|
let mut ret = Vec::new();
|
||||||
let mut tokens = tokenize_text(text);
|
let mut tokens = tokenize_text(text);
|
||||||
|
|
||||||
let Ok(language_query) = code.language_query(language).await else {
|
let language_query = schema.language_query(language);
|
||||||
return vec![];
|
let body_query = schema.body_query(&tokens);
|
||||||
};
|
|
||||||
let Ok(body_query) = code.body_query(&tokens).await else {
|
|
||||||
return vec![];
|
|
||||||
};
|
|
||||||
let query = BooleanQuery::new(vec)
|
))
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,11 @@ use axum::{
|
||||||
};
|
};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
|
||||||
use tracing::{instrument, warn};
|
use tracing::{instrument, warn};
|
||||||
use utoipa::IntoParams;
|
use utoipa::IntoParams;
|
||||||
|
|
||||||
use crate::search::{CodeSearch, CodeSearchError, CodeSearchService, SearchResponse};
|
use crate::search::CodeSearchService;
|
||||||
|
|
||||||
#[derive(Deserialize, IntoParams)]
|
#[derive(Deserialize, IntoParams)]
|
||||||
pub struct SearchQuery {
|
pub struct SearchQuery {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue