From b510f61acab86c135a9938d82428c90184ec92c1 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 10 Nov 2023 10:11:13 -0800 Subject: [PATCH] refactor: extract tabby_common::api::code / tabby_common::index::CodeSearchSchema (#743) * refactor: extract tabby_common::api::code mark CodeSearch being Send + Sync * extract CodeSearchSchema --- Cargo.lock | 9 +- Cargo.toml | 1 + crates/tabby-common/Cargo.toml | 3 + crates/tabby-common/src/api/code.rs | 56 ++++++++ crates/tabby-common/src/api/mod.rs | 1 + crates/tabby-common/src/index.rs | 92 ++++++++++-- crates/tabby-common/src/lib.rs | 1 + crates/tabby-scheduler/src/index.rs | 50 ++----- crates/tabby/Cargo.toml | 4 +- crates/tabby/src/search.rs | 142 ++++--------------- crates/tabby/src/serve/completions/prompt.rs | 27 ++-- crates/tabby/src/serve/mod.rs | 12 +- crates/tabby/src/serve/search.rs | 3 +- 13 files changed, 209 insertions(+), 192 deletions(-) create mode 100644 crates/tabby-common/src/api/code.rs create mode 100644 crates/tabby-common/src/api/mod.rs diff --git a/Cargo.lock b/Cargo.lock index c2f31e1..609ed6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -334,9 +334,9 @@ checksum = "b4eb2cdb97421e01129ccb49169d8279ed21e829929144f4a22a6e54ac549ca1" [[package]] name = "async-trait" -version = "0.1.72" +version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc6dde6e4ed435a4c1ee4e73592f5ba9da2151af10076cc04858746af9352d09" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", @@ -4058,6 +4058,7 @@ dependencies = [ "anyhow", "assert-json-diff", "async-stream", + "async-trait", "axum", "axum-streams", "axum-tracing-opentelemetry", @@ -4087,7 +4088,6 @@ dependencies = [ "tabby-scheduler", "tantivy", "textdistance", - "thiserror", "tokio", "tower-http 0.4.0", "tracing", @@ -4104,6 +4104,7 @@ name = "tabby-common" version = "0.6.0-dev" dependencies = [ "anyhow", + "async-trait", "chrono", "filenamify", "lazy_static", @@ -4112,7 +4113,9 @@ dependencies = [ "serde-jsonlines", "serdeconv", "tantivy", + "thiserror", "tokio", + "utoipa", "uuid 1.4.1", ] diff --git a/Cargo.toml b/Cargo.toml index b701d96..dae1755 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,4 @@ futures = "0.3.28" async-stream = "0.3.5" regex = "1.10.0" thiserror = "1.0.49" +utoipa = "3.3" \ No newline at end of file diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml index ccaac95..238ef9d 100644 --- a/crates/tabby-common/Cargo.toml +++ b/crates/tabby-common/Cargo.toml @@ -15,6 +15,9 @@ tokio = { workspace = true, features = ["rt", "macros"] } uuid = { version = "1.4.1", features = ["v4"] } tantivy.workspace = true anyhow.workspace = true +async-trait.workspace = true +thiserror.workspace = true +utoipa = { workspace = true, features = ["axum_extras", "preserve_order"] } [features] testutils = [] diff --git a/crates/tabby-common/src/api/code.rs b/crates/tabby-common/src/api/code.rs new file mode 100644 index 0000000..2491a99 --- /dev/null +++ b/crates/tabby-common/src/api/code.rs @@ -0,0 +1,56 @@ +use async_trait::async_trait; +use serde::Serialize; +use thiserror::Error; +use utoipa::ToSchema; + +#[derive(Serialize, ToSchema)] +pub struct SearchResponse { + pub num_hits: usize, + pub hits: Vec, +} + +#[derive(Serialize, ToSchema)] +pub struct Hit { + pub score: f32, + pub doc: HitDocument, + pub id: u32, +} + +#[derive(Serialize, ToSchema)] +pub struct HitDocument { + pub body: String, + pub filepath: String, + pub git_url: String, + pub kind: String, + pub language: String, + pub name: String, +} + +#[derive(Error, Debug)] +pub enum CodeSearchError { + #[error("index not ready")] + NotReady, + + #[error("{0}")] + QueryParserError(#[from] tantivy::query::QueryParserError), + + #[error("{0}")] + TantivyError(#[from] tantivy::TantivyError), +} + +#[async_trait] +pub trait CodeSearch: Send + Sync { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result; + + async fn search_with_query( + &self, + q: &dyn tantivy::query::Query, + limit: usize, + offset: usize, + ) -> Result; +} diff --git a/crates/tabby-common/src/api/mod.rs b/crates/tabby-common/src/api/mod.rs new file mode 100644 index 0000000..9de50d4 --- /dev/null +++ b/crates/tabby-common/src/api/mod.rs @@ -0,0 +1 @@ +pub mod code; diff --git a/crates/tabby-common/src/index.rs b/crates/tabby-common/src/index.rs index 6bf7eec..eec2549 100644 --- a/crates/tabby-common/src/index.rs +++ b/crates/tabby-common/src/index.rs @@ -1,27 +1,89 @@ use tantivy::{ + query::{TermQuery, TermSetQuery}, + schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, STORED, STRING}, tokenizer::{NgramTokenizer, RegexTokenizer, RemoveLongFilter, TextAnalyzer}, - Index, + Index, Term, }; -pub trait IndexExt { - fn register_tokenizer(&self); -} - pub static CODE_TOKENIZER: &str = "code"; pub static IDENTIFIER_TOKENIZER: &str = "identifier"; -impl IndexExt for Index { - fn register_tokenizer(&self) { - let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap()) - .filter(RemoveLongFilter::limit(128)) - .build(); +pub fn register_tokenizers(index: &Index) { + let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap()) + .filter(RemoveLongFilter::limit(128)) + .build(); - self.tokenizers().register(CODE_TOKENIZER, code_tokenizer); + index.tokenizers().register(CODE_TOKENIZER, code_tokenizer); - let identifier_tokenzier = - TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build(); + let identifier_tokenzier = + TextAnalyzer::builder(NgramTokenizer::prefix_only(2, 5).unwrap()).build(); - self.tokenizers() - .register(IDENTIFIER_TOKENIZER, identifier_tokenzier); + index + .tokenizers() + .register(IDENTIFIER_TOKENIZER, identifier_tokenzier); +} + +pub struct CodeSearchSchema { + pub schema: Schema, + pub field_git_url: Field, + pub field_filepath: Field, + pub field_language: Field, + pub field_name: Field, + pub field_kind: Field, + pub field_body: Field, +} + +impl CodeSearchSchema { + pub fn new() -> Self { + let mut builder = Schema::builder(); + + let code_indexing_options = TextFieldIndexing::default() + .set_tokenizer(CODE_TOKENIZER) + .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions); + let code_options = TextOptions::default() + .set_indexing_options(code_indexing_options) + .set_stored(); + + let name_indexing_options = TextFieldIndexing::default() + .set_tokenizer(IDENTIFIER_TOKENIZER) + .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions); + let name_options = TextOptions::default() + .set_indexing_options(name_indexing_options) + .set_stored(); + + let field_git_url = builder.add_text_field("git_url", STRING | STORED); + let field_filepath = builder.add_text_field("filepath", STRING | STORED); + let field_language = builder.add_text_field("language", STRING | STORED); + let field_name = builder.add_text_field("name", name_options); + let field_kind = builder.add_text_field("kind", STRING | STORED); + let field_body = builder.add_text_field("body", code_options); + let schema = builder.build(); + + Self { + schema, + field_git_url, + field_filepath, + field_language, + field_name, + field_kind, + field_body, + } + } +} + +impl CodeSearchSchema { + pub fn language_query(&self, language: &str) -> Box { + Box::new(TermQuery::new( + Term::from_field_text(self.field_language, language), + IndexRecordOption::WithFreqsAndPositions, + )) + } + + pub fn body_query(&self, tokens: &[String]) -> Box { + Box::new(TermSetQuery::new( + tokens + .iter() + .map(|x| Term::from_field_text(self.field_body, x)), + )) } } diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index 2458fef..cc548cc 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -1,3 +1,4 @@ +pub mod api; pub mod config; pub mod events; pub mod index; diff --git a/crates/tabby-scheduler/src/index.rs b/crates/tabby-scheduler/src/index.rs index ba052f0..cf2403d 100644 --- a/crates/tabby-scheduler/src/index.rs +++ b/crates/tabby-scheduler/src/index.rs @@ -3,16 +3,11 @@ use std::fs; use anyhow::Result; use tabby_common::{ config::Config, - index::{IndexExt, CODE_TOKENIZER, IDENTIFIER_TOKENIZER}, + index::{register_tokenizers, CodeSearchSchema}, path::index_dir, SourceFile, }; -use tantivy::{ - directory::MmapDirectory, - doc, - schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING}, - Index, -}; +use tantivy::{directory::MmapDirectory, doc, Index}; // Magic numbers static MAX_LINE_LENGTH_THRESHOLD: usize = 300; @@ -20,35 +15,12 @@ static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32; static MAX_BODY_LINES_THRESHOLD: usize = 15; pub fn index_repositories(_config: &Config) -> Result<()> { - let mut builder = Schema::builder(); - - let code_indexing_options = TextFieldIndexing::default() - .set_tokenizer(CODE_TOKENIZER) - .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions); - let code_options = TextOptions::default() - .set_indexing_options(code_indexing_options) - .set_stored(); - - let name_indexing_options = TextFieldIndexing::default() - .set_tokenizer(IDENTIFIER_TOKENIZER) - .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions); - let name_options = TextOptions::default() - .set_indexing_options(name_indexing_options) - .set_stored(); - - let field_git_url = builder.add_text_field("git_url", STRING | STORED); - let field_filepath = builder.add_text_field("filepath", STRING | STORED); - let field_language = builder.add_text_field("language", STRING | STORED); - let field_name = builder.add_text_field("name", name_options); - let field_kind = builder.add_text_field("kind", STRING | STORED); - let field_body = builder.add_text_field("body", code_options); - - let schema = builder.build(); + let code = CodeSearchSchema::new(); fs::create_dir_all(index_dir())?; let directory = MmapDirectory::open(index_dir())?; - let index = Index::open_or_create(directory, schema)?; - index.register_tokenizer(); + let index = Index::open_or_create(directory, code.schema)?; + register_tokenizers(&index); let mut writer = index.writer(10_000_000)?; writer.delete_all_documents()?; @@ -64,12 +36,12 @@ pub fn index_repositories(_config: &Config) -> Result<()> { for doc in from_source_file(file) { writer.add_document(doc!( - field_git_url => doc.git_url, - field_filepath => doc.filepath, - field_language => doc.language, - field_name => doc.name, - field_body => doc.body, - field_kind => doc.kind, + code.field_git_url => doc.git_url, + code.field_filepath => doc.filepath, + code.field_language => doc.language, + code.field_name => doc.name, + code.field_body => doc.body, + code.field_kind => doc.kind, ))?; } } diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index df54be6..0f670dc 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -14,7 +14,7 @@ tabby-inference = { path = "../tabby-inference" } axum = "0.6" hyper = { version = "0.14", features = ["full"] } tokio = { workspace = true, features = ["full"] } -utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] } +utoipa = { workspace= true, features = ["axum_extras", "preserve_order"] } utoipa-swagger-ui = { version = "3.1", features = ["axum"] } serde = { workspace = true } serdeconv = { workspace = true } @@ -42,9 +42,9 @@ axum-streams = { version = "0.9.1", features = ["json"] } minijinja = { version = "1.0.8", features = ["loader"] } textdistance = "1.0.2" regex.workspace = true -thiserror.workspace = true llama-cpp-bindings = { path = "../llama-cpp-bindings" } futures.workspace = true +async-trait.workspace = true [dependencies.uuid] version = "1.3.3" diff --git a/crates/tabby/src/search.rs b/crates/tabby/src/search.rs index 3f6cfa5..ed123e5 100644 --- a/crates/tabby/src/search.rs +++ b/crates/tabby/src/search.rs @@ -1,93 +1,39 @@ use std::{sync::Arc, time::Duration}; use anyhow::Result; -use axum::async_trait; -use serde::Serialize; -use tabby_common::{index::IndexExt, path}; +use async_trait::async_trait; +use tabby_common::{ + api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, + index::{self, register_tokenizers, CodeSearchSchema}, + path, +}; use tantivy::{ collector::{Count, TopDocs}, - query::{QueryParser, TermQuery, TermSetQuery}, - schema::{Field, IndexRecordOption}, - DocAddress, Document, Index, IndexReader, Term, + query::QueryParser, + schema::Field, + DocAddress, Document, Index, IndexReader, }; -use thiserror::Error; use tokio::{sync::Mutex, time::sleep}; use tracing::{debug, log::info}; -use utoipa::ToSchema; - -#[derive(Serialize, ToSchema)] -pub struct SearchResponse { - pub num_hits: usize, - pub hits: Vec, -} - -#[derive(Serialize, ToSchema)] -pub struct Hit { - pub score: f32, - pub doc: HitDocument, - pub id: u32, -} - -#[derive(Serialize, ToSchema)] -pub struct HitDocument { - pub body: String, - pub filepath: String, - pub git_url: String, - pub kind: String, - pub language: String, - pub name: String, -} - -#[derive(Error, Debug)] -pub enum CodeSearchError { - #[error("index not ready")] - NotReady, - - #[error("{0}")] - QueryParserError(#[from] tantivy::query::QueryParserError), - - #[error("{0}")] - TantivyError(#[from] tantivy::TantivyError), -} - -#[async_trait] -pub trait CodeSearch { - async fn search( - &self, - q: &str, - limit: usize, - offset: usize, - ) -> Result; - - async fn search_with_query( - &self, - q: &dyn tantivy::query::Query, - limit: usize, - offset: usize, - ) -> Result; -} struct CodeSearchImpl { reader: IndexReader, query_parser: QueryParser, - field_body: Field, - field_filepath: Field, - field_git_url: Field, - field_kind: Field, - field_language: Field, - field_name: Field, + schema: CodeSearchSchema, } impl CodeSearchImpl { fn load() -> Result { + let code_schema = index::CodeSearchSchema::new(); let index = Index::open_in_dir(path::index_dir())?; - index.register_tokenizer(); + register_tokenizers(&index); - let schema = index.schema(); - let field_body = schema.get_field("body").unwrap(); - let query_parser = - QueryParser::new(schema.clone(), vec![field_body], index.tokenizers().clone()); + let query_parser = QueryParser::new( + code_schema.schema.clone(), + vec![code_schema.field_body], + index.tokenizers().clone(), + ); let reader = index .reader_builder() .reload_policy(tantivy::ReloadPolicy::OnCommit) @@ -95,12 +41,7 @@ impl CodeSearchImpl { Ok(Self { reader, query_parser, - field_body, - field_filepath: schema.get_field("filepath").unwrap(), - field_git_url: schema.get_field("git_url").unwrap(), - field_kind: schema.get_field("kind").unwrap(), - field_language: schema.get_field("language").unwrap(), - field_name: schema.get_field("name").unwrap(), + schema: code_schema, }) } @@ -124,12 +65,12 @@ impl CodeSearchImpl { Hit { score, doc: HitDocument { - body: get_field(&doc, self.field_body), - filepath: get_field(&doc, self.field_filepath), - git_url: get_field(&doc, self.field_git_url), - kind: get_field(&doc, self.field_kind), - name: get_field(&doc, self.field_name), - language: get_field(&doc, self.field_language), + body: get_field(&doc, self.schema.field_body), + filepath: get_field(&doc, self.schema.field_filepath), + git_url: get_field(&doc, self.schema.field_git_url), + kind: get_field(&doc, self.schema.field_kind), + name: get_field(&doc, self.schema.field_name), + language: get_field(&doc, self.schema.field_language), }, id: doc_address.doc_id, } @@ -196,41 +137,6 @@ impl CodeSearchService { ret } - - async fn with_impl(&self, op: F) -> Result - where - F: FnOnce(&CodeSearchImpl) -> Result, - { - if let Some(imp) = self.search.lock().await.as_ref() { - op(imp) - } else { - Err(CodeSearchError::NotReady) - } - } - - pub async fn language_query(&self, language: &str) -> Result, CodeSearchError> { - self.with_impl(|imp| { - Ok(Box::new(TermQuery::new( - Term::from_field_text(imp.field_language, language), - IndexRecordOption::WithFreqsAndPositions, - ))) - }) - .await - } - - pub async fn body_query( - &self, - tokens: &[String], - ) -> Result, CodeSearchError> { - self.with_impl(|imp| { - Ok(Box::new(TermSetQuery::new( - tokens - .iter() - .map(|x| Term::from_field_text(imp.field_body, x)), - ))) - }) - .await - } } #[async_trait] diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index 8c4c2d3..a9a2f5a 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -3,19 +3,24 @@ use std::sync::Arc; use lazy_static::lazy_static; use regex::Regex; use strfmt::strfmt; -use tabby_common::languages::get_language; +use tabby_common::{ + api::code::{CodeSearch, CodeSearchError}, + index::CodeSearchSchema, + languages::get_language, +}; use tantivy::{query::BooleanQuery, query_grammar::Occur}; use textdistance::Algorithm; use tracing::warn; use super::{Segments, Snippet}; -use crate::search::{CodeSearch, CodeSearchError, CodeSearchService}; +use crate::search::CodeSearchService; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; static MAX_SIMILARITY_THRESHOLD: f32 = 0.9; pub struct PromptBuilder { + schema: CodeSearchSchema, prompt_template: Option, code: Option>, } @@ -23,6 +28,7 @@ pub struct PromptBuilder { impl PromptBuilder { pub fn new(prompt_template: Option, code: Option>) -> Self { PromptBuilder { + schema: CodeSearchSchema::new(), prompt_template, code, } @@ -38,7 +44,7 @@ impl PromptBuilder { pub async fn collect(&self, language: &str, segments: &Segments) -> Vec { if let Some(code) = &self.code { - collect_snippets(code, language, &segments.prefix).await + collect_snippets(&self.schema, code, language, &segments.prefix).await } else { vec![] } @@ -105,16 +111,17 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { format!("{}\n{}", comments, prefix) } -async fn collect_snippets(code: &CodeSearchService, language: &str, text: &str) -> Vec { +async fn collect_snippets( + schema: &CodeSearchSchema, + code: &CodeSearchService, + language: &str, + text: &str, +) -> Vec { let mut ret = Vec::new(); let mut tokens = tokenize_text(text); - let Ok(language_query) = code.language_query(language).await else { - return vec![]; - }; - let Ok(body_query) = code.body_query(&tokens).await else { - return vec![]; - }; + let language_query = schema.language_query(language); + let body_query = schema.body_query(&tokens); let query = BooleanQuery::new(vec![ (Occur::Must, language_query), (Occur::Must, body_query), diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index bfafd0c..62c5ea3 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -16,7 +16,11 @@ use std::{ use axum::{routing, Router, Server}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use clap::Args; -use tabby_common::{config::Config, usage}; +use tabby_common::{ + api::code::{Hit, HitDocument, SearchResponse}, + config::Config, + usage, +}; use tabby_download::download_model; use tokio::time::sleep; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; @@ -62,9 +66,9 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi crate::chat::ChatCompletionChunk, health::HealthState, health::Version, - crate::search::SearchResponse, - crate::search::Hit, - crate::search::HitDocument + SearchResponse, + Hit, + HitDocument )) )] struct ApiDoc; diff --git a/crates/tabby/src/serve/search.rs b/crates/tabby/src/serve/search.rs index 0644443..bdc742f 100644 --- a/crates/tabby/src/serve/search.rs +++ b/crates/tabby/src/serve/search.rs @@ -7,10 +7,11 @@ use axum::{ }; use hyper::StatusCode; use serde::Deserialize; +use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse}; use tracing::{instrument, warn}; use utoipa::IntoParams; -use crate::search::{CodeSearch, CodeSearchError, CodeSearchService, SearchResponse}; +use crate::search::CodeSearchService; #[derive(Deserialize, IntoParams)] pub struct SearchQuery {