feat: implement /v1beta/search interface (#516)

* feat: implement /v1beta/search interface

* update

* update

* improve debugger
wsxiaoys-patch-1
Meng Zhang 2023-10-06 11:54:12 -07:00 committed by GitHub
parent fd2a1ab865
commit 8497fb1372
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 232 additions and 31 deletions

1
Cargo.lock generated
View File

@ -3142,6 +3142,7 @@ dependencies = [
"serde",
"serde-jsonlines",
"serdeconv",
"tantivy",
"tokio",
"uuid 1.4.1",
]

View File

@ -13,6 +13,7 @@ serde-jsonlines = { workspace = true }
reqwest = { workspace = true, features = [ "json" ] }
tokio = { workspace = true, features = ["rt", "macros"] }
uuid = { version = "1.4.1", features = ["v4"] }
tantivy.workspace = true
[features]
testutils = []

View File

@ -0,0 +1,20 @@
use tantivy::{
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
Index,
};
/// Extension trait for [`tantivy::Index`] that wires up Tabby's custom tokenizers.
pub trait IndexExt {
    /// Registers the code tokenizer on this index's tokenizer manager.
    /// Must be called after opening/creating an index and before indexing or
    /// querying fields configured with [`CODE_TOKENIZER`].
    fn register_tokenizer(&self);
}
/// Name under which the source-code tokenizer is registered in the index.
pub static CODE_TOKENIZER: &str = "code";
impl IndexExt for Index {
    /// Registers the `code` tokenizer: splits text on runs of word characters
    /// (`\w+`) and drops tokens longer than 128 bytes.
    fn register_tokenizer(&self) {
        // The pattern is a compile-time constant, so construction can only
        // fail on a programming error — state that invariant instead of a
        // bare unwrap.
        let code_tokenizer = TextAnalyzer::builder(
            RegexTokenizer::new(r"(?:\w+)").expect("static tokenizer regex is valid"),
        )
        .filter(RemoveLongFilter::limit(128))
        .build();
        self.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
    }
}

View File

@ -1,5 +1,6 @@
pub mod config;
pub mod events;
pub mod index;
pub mod path;
pub mod usage;

View File

@ -1,12 +1,16 @@
use std::fs;
use anyhow::Result;
use tabby_common::{config::Config, path::index_dir, SourceFile};
use tabby_common::{
config::Config,
index::{IndexExt, CODE_TOKENIZER},
path::index_dir,
SourceFile,
};
use tantivy::{
directory::MmapDirectory,
doc,
schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
Index,
};
@ -18,7 +22,7 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
let mut builder = Schema::builder();
let code_indexing_options = TextFieldIndexing::default()
.set_tokenizer("code")
.set_tokenizer(CODE_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let code_options = TextOptions::default()
.set_indexing_options(code_indexing_options)
@ -36,11 +40,8 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
fs::create_dir_all(index_dir())?;
let directory = MmapDirectory::open(index_dir())?;
let index = Index::open_or_create(directory, schema)?;
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w*)").unwrap())
.filter(RemoveLongFilter::limit(40))
.build();
index.register_tokenizer();
index.tokenizers().register("code", code_tokenizer);
let mut writer = index.writer(10_000_000)?;
writer.delete_all_documents()?;

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
ctranslate2-bindings = { path = "../ctranslate2-bindings" }
tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler", optional = true }
tabby-scheduler = { path = "../tabby-scheduler" }
tabby-download = { path = "../tabby-download" }
tabby-inference = { path = "../tabby-inference" }
axum = "0.6"
@ -53,9 +53,7 @@ features = [
]
[features]
default = ["scheduler"]
link_shared = ["ctranslate2-bindings/link_shared"]
scheduler = ["tabby-scheduler"]
[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

View File

@ -32,7 +32,6 @@ pub enum Commands {
Download(download::DownloadArgs),
/// Run scheduler progress for cron jobs integrating external code repositories.
#[cfg(feature = "scheduler")]
Scheduler(SchedulerArgs),
}
@ -53,7 +52,6 @@ async fn main() {
match &cli.command {
Commands::Serve(args) => serve::main(&config, args).await,
Commands::Download(args) => download::main(args).await,
#[cfg(feature = "scheduler")]
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
.await
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),

View File

@ -4,6 +4,7 @@ mod engine;
mod events;
mod health;
mod playground;
mod search;
use std::{
net::{Ipv4Addr, SocketAddr},
@ -28,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
use self::{
engine::{create_engine, EngineInfo},
health::HealthState,
search::IndexServer,
};
use crate::fatal;
@ -48,7 +50,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
servers(
(url = "/", description = "Server"),
),
paths(events::log_event, completions::completions, chat::completions, health::health),
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
components(schemas(
events::LogEventRequest,
completions::CompletionRequest,
@ -90,6 +92,10 @@ pub struct ServeArgs {
#[clap(long)]
chat_model: Option<String>,
/// When set to `true`, the search API route will be enabled.
#[clap(long, default_value_t = false)]
enable_search: bool,
#[clap(long, default_value_t = 8080)]
port: u16,
@ -195,6 +201,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
routing::post(completions::completions).with_state(completion_state),
);
let router = if args.enable_search {
router.route(
"/v1beta/search",
routing::get(search::search).with_state(Arc::new(IndexServer::new())),
)
} else {
router
};
let router = if let Some(chat_state) = chat_state {
router.route(
"/v1beta/chat/completions",

View File

@ -0,0 +1,144 @@
use std::sync::Arc;
use anyhow::Result;
use axum::{
extract::{Query, State},
Json,
};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::{index::IndexExt, path};
use tantivy::{
collector::{Count, TopDocs},
query::QueryParser,
schema::{Field, FieldType, NamedFieldDocument, Schema},
DocAddress, Document, Index, IndexReader, Score,
};
use tracing::instrument;
use utoipa::IntoParams;
/// Query-string parameters accepted by the `GET /v1beta/search` endpoint.
#[derive(Deserialize, IntoParams)]
pub struct SearchQuery {
    // The search query string (tantivy query syntax).
    #[param(default = "get")]
    q: String,
    // Maximum number of hits to return; the handler falls back to 20 when omitted.
    #[param(default = 20)]
    limit: Option<usize>,
    // Number of leading hits to skip (pagination); the handler falls back to 0.
    #[param(default = 0)]
    offset: Option<usize>,
}
/// JSON body returned by the search endpoint.
#[derive(Serialize)]
pub struct SearchResponse {
    // Echo of the query string that produced these results.
    q: String,
    // Total number of matching documents (from the Count collector), not just
    // the number returned in this page.
    num_hits: usize,
    // The current page of results.
    hits: Vec<Hit>,
}
/// A single search result.
#[derive(Serialize)]
pub struct Hit {
    // Relevance score assigned by tantivy for this document.
    score: Score,
    // The document's stored fields, keyed by schema field name.
    doc: NamedFieldDocument,
    // Document id taken from the doc address (id within its segment).
    id: u32,
}
/// Handler for `GET /v1beta/search`: runs the query against the code index
/// and returns one page of hits.
///
/// Responds with 500 Internal Server Error when the underlying index search
/// fails (e.g. a malformed query).
#[utoipa::path(
    get,
    params(SearchQuery),
    path = "/v1beta/search",
    operation_id = "search",
    tag = "v1beta",
    responses(
        (status = 200, description = "Success" , content_type = "application/json"),
        // Fixed grammar in the user-facing OpenAPI description ("will returns" -> "will return").
        (status = 405, description = "When code search is not enabled, the endpoint will return 405 Method Not Allowed"),
    )
)]
#[instrument(skip(state, query))]
pub async fn search(
    State(state): State<Arc<IndexServer>>,
    query: Query<SearchQuery>,
) -> Result<Json<SearchResponse>, StatusCode> {
    // Apply the documented defaults when pagination parameters are omitted.
    let Ok(serp) = state.search(
        &query.q,
        query.limit.unwrap_or(20),
        query.offset.unwrap_or(0),
    ) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };
    Ok(Json(serp))
}
/// Shared, read-only handle to the on-disk tantivy code index.
pub struct IndexServer {
    // Reader used to obtain a fresh searcher per request.
    reader: IndexReader,
    // Parser configured with every indexed text field as a default field.
    query_parser: QueryParser,
    // Schema used to render documents in named (field-name-keyed) form.
    schema: Schema,
}
impl IndexServer {
    /// Opens the code index, panicking if it cannot be loaded. Intended to be
    /// called once at server startup, where a missing index is fatal.
    pub fn new() -> Self {
        Self::load().expect("Failed to load code state")
    }

    /// Opens the index from the standard index directory, registers the code
    /// tokenizer, and builds a query parser over every indexed text field.
    fn load() -> Result<Self> {
        let index = Index::open_in_dir(path::index_dir())?;
        index.register_tokenizer();
        let schema = index.schema();

        // Use every indexed string field as a default search field, so bare
        // terms in a query match across all of them.
        let default_fields: Vec<Field> = schema
            .fields()
            .filter(|&(_, field_entry)| match field_entry.field_type() {
                FieldType::Str(ref text_field_options) => {
                    text_field_options.get_indexing_options().is_some()
                }
                _ => false,
            })
            .map(|(field, _)| field)
            .collect();
        let query_parser =
            QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
        let reader = index.reader()?;
        Ok(Self {
            reader,
            query_parser,
            schema,
        })
    }

    /// Runs `q` against the index, returning up to `limit` hits starting at
    /// `offset`, together with the total number of matching documents.
    ///
    /// Returns an error — instead of panicking — when the query string is
    /// malformed or document retrieval fails, so the HTTP layer can respond
    /// with a status code rather than crashing the server.
    fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
        // `q` is user-supplied and may be malformed; surface parse failures
        // as an error (previously this was an `expect`, which would panic and
        // take down the request on any bad query).
        let query = self
            .query_parser
            .parse_query(q)
            .map_err(|err| tantivy::TantivyError::InvalidArgument(err.to_string()))?;
        let searcher = self.reader.searcher();
        let (top_docs, num_hits) = searcher.search(
            &query,
            &(TopDocs::with_limit(limit).and_offset(offset), Count),
        )?;
        let mut hits = Vec::with_capacity(top_docs.len());
        for (score, doc_address) in top_docs {
            // Propagate retrieval failures instead of unwrapping per document.
            let doc = searcher.doc(doc_address)?;
            hits.push(self.create_hit(score, doc, doc_address));
        }
        Ok(SearchResponse {
            q: q.to_owned(),
            num_hits,
            hits,
        })
    }

    /// Converts a scored document into the serializable `Hit` shape.
    fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit {
        Hit {
            score,
            doc: self.schema.to_named_doc(&doc),
            id: doc_address.doc_id,
        }
    }
}

View File

@ -0,0 +1,38 @@
import re
import requests
import streamlit as st
from typing import NamedTuple
class Doc(NamedTuple):
    """A single search hit: the document's fields plus its relevance score."""

    name: str
    body: str
    score: float
    filepath: str

    @staticmethod
    def from_json(json: dict):
        """Build a Doc from one hit object of the search API's JSON response.

        Each field in the hit's "doc" mapping is a single-element list; the
        scalar value is taken from index 0.
        """
        fields = json["doc"]
        name = fields["name"][0]
        body = fields["body"][0]
        filepath = fields["filepath"][0]
        return Doc(name=name, body=body, score=json["score"], filepath=filepath)
# force wide mode
st.set_page_config(layout="wide")

# User inputs: a language filter and a free-form code query.
language = st.text_input("Language", "rust")
query = st.text_area("Query", "get")

# Split the query into word tokens, dropping bare boolean operators so they
# are not sent to the search backend as literal terms.
tokens = re.findall(r"\w+", query)
tokens = [x for x in tokens if x != "AND" and x != "OR" and x != "NOT"]

# Rebuild the query as "(tok1 tok2 ...) AND language:<lang>".
# NOTE(review): after this rebuild the string is always non-empty (at minimum
# "() AND language:..."), so the `if query:` guard below never skips; if the
# intent was to skip empty input, checking `tokens` would do that — confirm.
query = "(" + " ".join(tokens) + ")" + " " + "AND language:" + language
if query:
    # Query the local tabby server's search endpoint and render each hit.
    r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
    hits = r.json()["hits"]
    for x in hits:
        doc = Doc.from_json(x)
        st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
        st.code(doc.body)

View File

@ -2,29 +2,13 @@ import requests
import streamlit as st
from typing import NamedTuple
class Doc(NamedTuple):
name: str
body: str
score: float
@staticmethod
def from_json(json: dict):
doc = json["doc"]
return Doc(
name=doc["name"][0],
body=doc["body"][0],
score=json["score"]
)
# force wide mode
st.set_page_config(layout="wide")
query = st.text_input("Query")
if query:
r = requests.get("http://localhost:3000/api", params=dict(q=query))
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
hits = r.json()["hits"]
for x in hits:
doc = Doc.from_json(x)
st.write(doc.name + " : " + str(doc.score))
st.code(doc.body)
st.write(x)