feat: implement /v1beta/search interface (#516)

* feat: implement /v1beta/search interface

* update

* update

* improve debugger
wsxiaoys-patch-1
Meng Zhang 2023-10-06 11:54:12 -07:00 committed by GitHub
parent fd2a1ab865
commit 8497fb1372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 232 additions and 31 deletions

1
Cargo.lock generated
View File

@ -3142,6 +3142,7 @@ dependencies = [
"serde", "serde",
"serde-jsonlines", "serde-jsonlines",
"serdeconv", "serdeconv",
"tantivy",
"tokio", "tokio",
"uuid 1.4.1", "uuid 1.4.1",
] ]

View File

@ -13,6 +13,7 @@ serde-jsonlines = { workspace = true }
reqwest = { workspace = true, features = [ "json" ] } reqwest = { workspace = true, features = [ "json" ] }
tokio = { workspace = true, features = ["rt", "macros"] } tokio = { workspace = true, features = ["rt", "macros"] }
uuid = { version = "1.4.1", features = ["v4"] } uuid = { version = "1.4.1", features = ["v4"] }
tantivy.workspace = true
[features] [features]
testutils = [] testutils = []

View File

@ -0,0 +1,20 @@
use tantivy::{
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
Index,
};
pub trait IndexExt {
fn register_tokenizer(&self);
}
pub static CODE_TOKENIZER: &str = "code";
impl IndexExt for Index {
fn register_tokenizer(&self) {
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
.filter(RemoveLongFilter::limit(128))
.build();
self.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
}
}

View File

@ -1,5 +1,6 @@
pub mod config; pub mod config;
pub mod events; pub mod events;
pub mod index;
pub mod path; pub mod path;
pub mod usage; pub mod usage;

View File

@ -1,12 +1,16 @@
use std::fs; use std::fs;
use anyhow::Result; use anyhow::Result;
use tabby_common::{config::Config, path::index_dir, SourceFile}; use tabby_common::{
config::Config,
index::{IndexExt, CODE_TOKENIZER},
path::index_dir,
SourceFile,
};
use tantivy::{ use tantivy::{
directory::MmapDirectory, directory::MmapDirectory,
doc, doc,
schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING}, schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
Index, Index,
}; };
@ -18,7 +22,7 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
let mut builder = Schema::builder(); let mut builder = Schema::builder();
let code_indexing_options = TextFieldIndexing::default() let code_indexing_options = TextFieldIndexing::default()
.set_tokenizer("code") .set_tokenizer(CODE_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions); .set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let code_options = TextOptions::default() let code_options = TextOptions::default()
.set_indexing_options(code_indexing_options) .set_indexing_options(code_indexing_options)
@ -36,11 +40,8 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
fs::create_dir_all(index_dir())?; fs::create_dir_all(index_dir())?;
let directory = MmapDirectory::open(index_dir())?; let directory = MmapDirectory::open(index_dir())?;
let index = Index::open_or_create(directory, schema)?; let index = Index::open_or_create(directory, schema)?;
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w*)").unwrap()) index.register_tokenizer();
.filter(RemoveLongFilter::limit(40))
.build();
index.tokenizers().register("code", code_tokenizer);
let mut writer = index.writer(10_000_000)?; let mut writer = index.writer(10_000_000)?;
writer.delete_all_documents()?; writer.delete_all_documents()?;

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
ctranslate2-bindings = { path = "../ctranslate2-bindings" } ctranslate2-bindings = { path = "../ctranslate2-bindings" }
tabby-common = { path = "../tabby-common" } tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler", optional = true } tabby-scheduler = { path = "../tabby-scheduler" }
tabby-download = { path = "../tabby-download" } tabby-download = { path = "../tabby-download" }
tabby-inference = { path = "../tabby-inference" } tabby-inference = { path = "../tabby-inference" }
axum = "0.6" axum = "0.6"
@ -53,9 +53,7 @@ features = [
] ]
[features] [features]
default = ["scheduler"]
link_shared = ["ctranslate2-bindings/link_shared"] link_shared = ["ctranslate2-bindings/link_shared"]
scheduler = ["tabby-scheduler"]
[build-dependencies] [build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

View File

@ -32,7 +32,6 @@ pub enum Commands {
Download(download::DownloadArgs), Download(download::DownloadArgs),
/// Run scheduler progress for cron jobs integrating external code repositories. /// Run scheduler progress for cron jobs integrating external code repositories.
#[cfg(feature = "scheduler")]
Scheduler(SchedulerArgs), Scheduler(SchedulerArgs),
} }
@ -53,7 +52,6 @@ async fn main() {
match &cli.command { match &cli.command {
Commands::Serve(args) => serve::main(&config, args).await, Commands::Serve(args) => serve::main(&config, args).await,
Commands::Download(args) => download::main(args).await, Commands::Download(args) => download::main(args).await,
#[cfg(feature = "scheduler")]
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now) Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
.await .await
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)), .unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),

View File

@ -4,6 +4,7 @@ mod engine;
mod events; mod events;
mod health; mod health;
mod playground; mod playground;
mod search;
use std::{ use std::{
net::{Ipv4Addr, SocketAddr}, net::{Ipv4Addr, SocketAddr},
@ -28,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
use self::{ use self::{
engine::{create_engine, EngineInfo}, engine::{create_engine, EngineInfo},
health::HealthState, health::HealthState,
search::IndexServer,
}; };
use crate::fatal; use crate::fatal;
@ -48,7 +50,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers( servers(
(url = "/", description = "Server"), (url = "/", description = "Server"),
), ),
paths(events::log_event, completions::completions, chat::completions, health::health), paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
components(schemas( components(schemas(
events::LogEventRequest, events::LogEventRequest,
completions::CompletionRequest, completions::CompletionRequest,
@ -90,6 +92,10 @@ pub struct ServeArgs {
#[clap(long)] #[clap(long)]
chat_model: Option<String>, chat_model: Option<String>,
/// When set to `true`, the search API route will be enabled.
#[clap(long, default_value_t = false)]
enable_search: bool,
#[clap(long, default_value_t = 8080)] #[clap(long, default_value_t = 8080)]
port: u16, port: u16,
@ -195,6 +201,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
routing::post(completions::completions).with_state(completion_state), routing::post(completions::completions).with_state(completion_state),
); );
let router = if args.enable_search {
router.route(
"/v1beta/search",
routing::get(search::search).with_state(Arc::new(IndexServer::new())),
)
} else {
router
};
let router = if let Some(chat_state) = chat_state { let router = if let Some(chat_state) = chat_state {
router.route( router.route(
"/v1beta/chat/completions", "/v1beta/chat/completions",

View File

@ -0,0 +1,144 @@
use std::sync::Arc;
use anyhow::Result;
use axum::{
extract::{Query, State},
Json,
};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::{index::IndexExt, path};
use tantivy::{
collector::{Count, TopDocs},
query::QueryParser,
schema::{Field, FieldType, NamedFieldDocument, Schema},
DocAddress, Document, Index, IndexReader, Score,
};
use tracing::instrument;
use utoipa::IntoParams;
#[derive(Deserialize, IntoParams)]
pub struct SearchQuery {
#[param(default = "get")]
q: String,
#[param(default = 20)]
limit: Option<usize>,
#[param(default = 0)]
offset: Option<usize>,
}
#[derive(Serialize)]
pub struct SearchResponse {
q: String,
num_hits: usize,
hits: Vec<Hit>,
}
#[derive(Serialize)]
pub struct Hit {
score: Score,
doc: NamedFieldDocument,
id: u32,
}
#[utoipa::path(
get,
params(SearchQuery),
path = "/v1beta/search",
operation_id = "search",
tag = "v1beta",
responses(
(status = 200, description = "Success" , content_type = "application/json"),
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
)
)]
#[instrument(skip(state, query))]
pub async fn search(
State(state): State<Arc<IndexServer>>,
query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> {
let Ok(serp) = state.search(
&query.q,
query.limit.unwrap_or(20),
query.offset.unwrap_or(0),
) else {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
};
Ok(Json(serp))
}
pub struct IndexServer {
reader: IndexReader,
query_parser: QueryParser,
schema: Schema,
}
impl IndexServer {
pub fn new() -> Self {
Self::load().expect("Failed to load code state")
}
fn load() -> Result<Self> {
let index = Index::open_in_dir(path::index_dir())?;
index.register_tokenizer();
let schema = index.schema();
let default_fields: Vec<Field> = schema
.fields()
.filter(|&(_, field_entry)| match field_entry.field_type() {
FieldType::Str(ref text_field_options) => {
text_field_options.get_indexing_options().is_some()
}
_ => false,
})
.map(|(field, _)| field)
.collect();
let query_parser =
QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
let reader = index.reader()?;
Ok(Self {
reader,
query_parser,
schema,
})
}
fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
let query = self
.query_parser
.parse_query(q)
.expect("Parsing the query failed");
let searcher = self.reader.searcher();
let (top_docs, num_hits) = {
searcher.search(
&query,
&(TopDocs::with_limit(limit).and_offset(offset), Count),
)?
};
let hits: Vec<Hit> = {
top_docs
.iter()
.map(|(score, doc_address)| {
let doc = searcher.doc(*doc_address).unwrap();
self.create_hit(*score, doc, *doc_address)
})
.collect()
};
Ok(SearchResponse {
q: q.to_owned(),
num_hits,
hits,
})
}
fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit {
Hit {
score,
doc: self.schema.to_named_doc(&doc),
id: doc_address.doc_id,
}
}
}

View File

@ -0,0 +1,38 @@
import re
import requests
import streamlit as st
from typing import NamedTuple
class Doc(NamedTuple):
name: str
body: str
score: float
filepath: str
@staticmethod
def from_json(json: dict):
doc = json["doc"]
return Doc(
name=doc["name"][0],
body=doc["body"][0],
score=json["score"],
filepath=doc["filepath"][0],
)
# force wide mode
st.set_page_config(layout="wide")
language = st.text_input("Language", "rust")
query = st.text_area("Query", "get")
tokens = re.findall(r"\w+", query)
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
if query:
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
hits = r.json()["hits"]
for x in hits:
doc = Doc.from_json(x)
st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
st.code(doc.body)

View File

@ -2,29 +2,13 @@ import requests
import streamlit as st import streamlit as st
from typing import NamedTuple from typing import NamedTuple
class Doc(NamedTuple):
name: str
body: str
score: float
@staticmethod
def from_json(json: dict):
doc = json["doc"]
return Doc(
name=doc["name"][0],
body=doc["body"][0],
score=json["score"]
)
# force wide mode # force wide mode
st.set_page_config(layout="wide") st.set_page_config(layout="wide")
query = st.text_input("Query") query = st.text_input("Query")
if query: if query:
r = requests.get("http://localhost:3000/api", params=dict(q=query)) r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
hits = r.json()["hits"] hits = r.json()["hits"]
for x in hits: for x in hits:
doc = Doc.from_json(x) st.write(x)
st.write(doc.name + " : " + str(doc.score))
st.code(doc.body)