feat: implement /v1beta/search interface (#516)
* feat: implement /v1beta/search interface
* update
* update
* improve debugger

wsxiaoys-patch-1
parent
fd2a1ab865
commit
8497fb1372
|
|
@ -3142,6 +3142,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde-jsonlines",
|
||||
"serdeconv",
|
||||
"tantivy",
|
||||
"tokio",
|
||||
"uuid 1.4.1",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ serde-jsonlines = { workspace = true }
|
|||
reqwest = { workspace = true, features = [ "json" ] }
|
||||
tokio = { workspace = true, features = ["rt", "macros"] }
|
||||
uuid = { version = "1.4.1", features = ["v4"] }
|
||||
tantivy.workspace = true
|
||||
|
||||
[features]
|
||||
testutils = []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
use tantivy::{
|
||||
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
|
||||
Index,
|
||||
};
|
||||
|
||||
/// Extension trait adding tokenizer registration to an index type.
///
/// In this module it is implemented for `tantivy::Index`.
pub trait IndexExt {
    /// Registers the tokenizer(s) this crate relies on with the receiver.
    fn register_tokenizer(&self);
}
|
||||
|
||||
/// Name under which the code tokenizer is registered with an index's
/// tokenizer manager.
pub static CODE_TOKENIZER: &str = "code";
|
||||
|
||||
impl IndexExt for Index {
|
||||
fn register_tokenizer(&self) {
|
||||
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w+)").unwrap())
|
||||
.filter(RemoveLongFilter::limit(128))
|
||||
.build();
|
||||
|
||||
self.tokenizers().register(CODE_TOKENIZER, code_tokenizer);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod config;
|
||||
pub mod events;
|
||||
pub mod index;
|
||||
pub mod path;
|
||||
pub mod usage;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
use std::fs;
|
||||
|
||||
use anyhow::Result;
|
||||
use tabby_common::{config::Config, path::index_dir, SourceFile};
|
||||
use tabby_common::{
|
||||
config::Config,
|
||||
index::{IndexExt, CODE_TOKENIZER},
|
||||
path::index_dir,
|
||||
SourceFile,
|
||||
};
|
||||
use tantivy::{
|
||||
directory::MmapDirectory,
|
||||
doc,
|
||||
schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
|
||||
tokenizer::{RegexTokenizer, RemoveLongFilter, TextAnalyzer},
|
||||
Index,
|
||||
};
|
||||
|
||||
|
|
@ -18,7 +22,7 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
|
|||
let mut builder = Schema::builder();
|
||||
|
||||
let code_indexing_options = TextFieldIndexing::default()
|
||||
.set_tokenizer("code")
|
||||
.set_tokenizer(CODE_TOKENIZER)
|
||||
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
|
||||
let code_options = TextOptions::default()
|
||||
.set_indexing_options(code_indexing_options)
|
||||
|
|
@ -36,11 +40,8 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
|
|||
fs::create_dir_all(index_dir())?;
|
||||
let directory = MmapDirectory::open(index_dir())?;
|
||||
let index = Index::open_or_create(directory, schema)?;
|
||||
let code_tokenizer = TextAnalyzer::builder(RegexTokenizer::new(r"(?:\w*)").unwrap())
|
||||
.filter(RemoveLongFilter::limit(40))
|
||||
.build();
|
||||
index.register_tokenizer();
|
||||
|
||||
index.tokenizers().register("code", code_tokenizer);
|
||||
let mut writer = index.writer(10_000_000)?;
|
||||
writer.delete_all_documents()?;
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
ctranslate2-bindings = { path = "../ctranslate2-bindings" }
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
tabby-scheduler = { path = "../tabby-scheduler", optional = true }
|
||||
tabby-scheduler = { path = "../tabby-scheduler" }
|
||||
tabby-download = { path = "../tabby-download" }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
axum = "0.6"
|
||||
|
|
@ -53,9 +53,7 @@ features = [
|
|||
]
|
||||
|
||||
[features]
|
||||
default = ["scheduler"]
|
||||
link_shared = ["ctranslate2-bindings/link_shared"]
|
||||
scheduler = ["tabby-scheduler"]
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ pub enum Commands {
|
|||
Download(download::DownloadArgs),
|
||||
|
||||
/// Run scheduler progress for cron jobs integrating external code repositories.
|
||||
#[cfg(feature = "scheduler")]
|
||||
Scheduler(SchedulerArgs),
|
||||
}
|
||||
|
||||
|
|
@ -53,7 +52,6 @@ async fn main() {
|
|||
match &cli.command {
|
||||
Commands::Serve(args) => serve::main(&config, args).await,
|
||||
Commands::Download(args) => download::main(args).await,
|
||||
#[cfg(feature = "scheduler")]
|
||||
Commands::Scheduler(args) => tabby_scheduler::scheduler(args.now)
|
||||
.await
|
||||
.unwrap_or_else(|err| fatal!("Scheduler failed due to '{}'", err)),
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ mod engine;
|
|||
mod events;
|
||||
mod health;
|
||||
mod playground;
|
||||
mod search;
|
||||
|
||||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
|
|
@ -28,6 +29,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
|||
use self::{
|
||||
engine::{create_engine, EngineInfo},
|
||||
health::HealthState,
|
||||
search::IndexServer,
|
||||
};
|
||||
use crate::fatal;
|
||||
|
||||
|
|
@ -48,7 +50,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
|
|||
servers(
|
||||
(url = "/", description = "Server"),
|
||||
),
|
||||
paths(events::log_event, completions::completions, chat::completions, health::health),
|
||||
paths(events::log_event, completions::completions, chat::completions, health::health, search::search),
|
||||
components(schemas(
|
||||
events::LogEventRequest,
|
||||
completions::CompletionRequest,
|
||||
|
|
@ -90,6 +92,10 @@ pub struct ServeArgs {
|
|||
#[clap(long)]
|
||||
chat_model: Option<String>,
|
||||
|
||||
/// When set to `true`, the search API route will be enabled.
|
||||
#[clap(long, default_value_t = false)]
|
||||
enable_search: bool,
|
||||
|
||||
#[clap(long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
|
|
@ -195,6 +201,15 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
|
|||
routing::post(completions::completions).with_state(completion_state),
|
||||
);
|
||||
|
||||
let router = if args.enable_search {
|
||||
router.route(
|
||||
"/v1beta/search",
|
||||
routing::get(search::search).with_state(Arc::new(IndexServer::new())),
|
||||
)
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
let router = if let Some(chat_state) = chat_state {
|
||||
router.route(
|
||||
"/v1beta/chat/completions",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,144 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
Json,
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::{index::IndexExt, path};
|
||||
use tantivy::{
|
||||
collector::{Count, TopDocs},
|
||||
query::QueryParser,
|
||||
schema::{Field, FieldType, NamedFieldDocument, Schema},
|
||||
DocAddress, Document, Index, IndexReader, Score,
|
||||
};
|
||||
use tracing::instrument;
|
||||
use utoipa::IntoParams;
|
||||
|
||||
#[derive(Deserialize, IntoParams)]
|
||||
pub struct SearchQuery {
|
||||
#[param(default = "get")]
|
||||
q: String,
|
||||
|
||||
#[param(default = 20)]
|
||||
limit: Option<usize>,
|
||||
|
||||
#[param(default = 0)]
|
||||
offset: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SearchResponse {
|
||||
q: String,
|
||||
num_hits: usize,
|
||||
hits: Vec<Hit>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct Hit {
|
||||
score: Score,
|
||||
doc: NamedFieldDocument,
|
||||
id: u32,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
params(SearchQuery),
|
||||
path = "/v1beta/search",
|
||||
operation_id = "search",
|
||||
tag = "v1beta",
|
||||
responses(
|
||||
(status = 200, description = "Success" , content_type = "application/json"),
|
||||
(status = 405, description = "When code search is not enabled, the endpoint will returns 405 Method Not Allowed"),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(state, query))]
|
||||
pub async fn search(
|
||||
State(state): State<Arc<IndexServer>>,
|
||||
query: Query<SearchQuery>,
|
||||
) -> Result<Json<SearchResponse>, StatusCode> {
|
||||
let Ok(serp) = state.search(
|
||||
&query.q,
|
||||
query.limit.unwrap_or(20),
|
||||
query.offset.unwrap_or(0),
|
||||
) else {
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
};
|
||||
|
||||
Ok(Json(serp))
|
||||
}
|
||||
|
||||
pub struct IndexServer {
|
||||
reader: IndexReader,
|
||||
query_parser: QueryParser,
|
||||
schema: Schema,
|
||||
}
|
||||
|
||||
impl IndexServer {
|
||||
pub fn new() -> Self {
|
||||
Self::load().expect("Failed to load code state")
|
||||
}
|
||||
|
||||
fn load() -> Result<Self> {
|
||||
let index = Index::open_in_dir(path::index_dir())?;
|
||||
index.register_tokenizer();
|
||||
|
||||
let schema = index.schema();
|
||||
let default_fields: Vec<Field> = schema
|
||||
.fields()
|
||||
.filter(|&(_, field_entry)| match field_entry.field_type() {
|
||||
FieldType::Str(ref text_field_options) => {
|
||||
text_field_options.get_indexing_options().is_some()
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
.map(|(field, _)| field)
|
||||
.collect();
|
||||
let query_parser =
|
||||
QueryParser::new(schema.clone(), default_fields, index.tokenizers().clone());
|
||||
let reader = index.reader()?;
|
||||
Ok(Self {
|
||||
reader,
|
||||
query_parser,
|
||||
schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn search(&self, q: &str, limit: usize, offset: usize) -> tantivy::Result<SearchResponse> {
|
||||
let query = self
|
||||
.query_parser
|
||||
.parse_query(q)
|
||||
.expect("Parsing the query failed");
|
||||
let searcher = self.reader.searcher();
|
||||
let (top_docs, num_hits) = {
|
||||
searcher.search(
|
||||
&query,
|
||||
&(TopDocs::with_limit(limit).and_offset(offset), Count),
|
||||
)?
|
||||
};
|
||||
let hits: Vec<Hit> = {
|
||||
top_docs
|
||||
.iter()
|
||||
.map(|(score, doc_address)| {
|
||||
let doc = searcher.doc(*doc_address).unwrap();
|
||||
self.create_hit(*score, doc, *doc_address)
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
Ok(SearchResponse {
|
||||
q: q.to_owned(),
|
||||
num_hits,
|
||||
hits,
|
||||
})
|
||||
}
|
||||
|
||||
fn create_hit(&self, score: Score, doc: Document, doc_address: DocAddress) -> Hit {
|
||||
Hit {
|
||||
score,
|
||||
doc: self.schema.to_named_doc(&doc),
|
||||
id: doc_address.doc_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import re
|
||||
import requests
|
||||
import streamlit as st
|
||||
from typing import NamedTuple
|
||||
|
||||
class Doc(NamedTuple):
    """A search hit flattened into the fields the debugger displays."""

    name: str
    body: str
    score: float
    filepath: str

    @staticmethod
    def from_json(json: dict):
        """Build a ``Doc`` from one element of the response's ``hits`` list.

        ``json["doc"]`` maps each field name to a list of values; the first
        value of ``name``, ``body`` and ``filepath`` is used. ``score`` is
        read from the top level of the hit.
        """
        fields = json["doc"]
        name = fields["name"][0]
        body = fields["body"][0]
        score = json["score"]
        filepath = fields["filepath"][0]
        return Doc(name, body, score, filepath)
|
||||
|
||||
# force wide mode
st.set_page_config(layout="wide")

language = st.text_input("Language", "rust")
query = st.text_area("Query", "get")

# Keep only word tokens and drop the boolean operators, so user input cannot
# change the structure of the query built below.
tokens = [x for x in re.findall(r"\w+", query) if x not in ("AND", "OR", "NOT")]

# Search only when at least one term remains. The assembled query string is
# never empty, so testing it (rather than the tokens) would send "()" to the
# server for blank input.
if tokens:
    query = "(" + " ".join(tokens) + ")"
    # Skip the language filter when no language is given instead of emitting
    # a dangling "language:" clause.
    if language.strip():
        query = query + " " + "AND language:" + language

    r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
    if r.ok:
        hits = r.json()["hits"]
        for x in hits:
            doc = Doc.from_json(x)
            st.write(doc.name + "@" + doc.filepath + " : " + str(doc.score))
            st.code(doc.body)
    else:
        # Show the failure rather than crashing while decoding a non-JSON
        # error body.
        st.error("Search request failed with status " + str(r.status_code))
|
||||
|
|
@ -2,29 +2,13 @@ import requests
|
|||
import streamlit as st
|
||||
from typing import NamedTuple
|
||||
|
||||
class Doc(NamedTuple):
|
||||
name: str
|
||||
body: str
|
||||
score: float
|
||||
|
||||
@staticmethod
|
||||
def from_json(json: dict):
|
||||
doc = json["doc"]
|
||||
return Doc(
|
||||
name=doc["name"][0],
|
||||
body=doc["body"][0],
|
||||
score=json["score"]
|
||||
)
|
||||
|
||||
# force wide mode
|
||||
st.set_page_config(layout="wide")
|
||||
|
||||
query = st.text_input("Query")
|
||||
|
||||
if query:
|
||||
r = requests.get("http://localhost:3000/api", params=dict(q=query))
|
||||
r = requests.get("http://localhost:8080/v1beta/search", params=dict(q=query))
|
||||
hits = r.json()["hits"]
|
||||
for x in hits:
|
||||
doc = Doc.from_json(x)
|
||||
st.write(doc.name + " : " + str(doc.score))
|
||||
st.code(doc.body)
|
||||
st.write(x)
|
||||
|
|
|
|||
Loading…
Reference in New Issue