feat: use remote code search in workers (#833)

* refactor: lift tabby::api::code to tabby_common::api::code

* feat: use remote code search in workers

* update

* handle errors
release-fix-intellij-update-support-version-range
Meng Zhang 2023-11-18 15:45:00 -08:00 committed by GitHub
parent b862d9d100
commit a12f741565
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 152 additions and 33 deletions

3
Cargo.lock generated
View File

@ -4469,6 +4469,7 @@ name = "tabby-common"
version = "0.6.0-dev"
dependencies = [
"anyhow",
"async-trait",
"filenamify",
"glob",
"lazy_static",
@ -4478,6 +4479,7 @@ dependencies = [
"serde_json",
"serdeconv",
"tantivy",
"thiserror",
"utoipa",
"uuid 1.4.1",
]
@ -4541,6 +4543,7 @@ name = "tabby-webserver"
version = "0.6.0-dev"
dependencies = [
"anyhow",
"async-trait",
"axum",
"bincode",
"chrono",

View File

@ -16,6 +16,8 @@ anyhow.workspace = true
glob = "0.3.1"
utoipa.workspace = true
serde_json.workspace = true
async-trait.workspace = true
thiserror.workspace = true
[features]
testutils = []

View File

@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use utoipa::ToSchema;
#[derive(Serialize, Deserialize, Debug, ToSchema)]
#[derive(Default, Serialize, Deserialize, Debug, ToSchema)]
pub struct SearchResponse {
pub num_hits: usize,
pub hits: Vec<Hit>,
@ -31,10 +31,10 @@ pub enum CodeSearchError {
#[error("index not ready")]
NotReady,
#[error("{0}")]
#[error(transparent)]
QueryParserError(#[from] tantivy::query::QueryParserError),
#[error("{0}")]
#[error(transparent)]
TantivyError(#[from] tantivy::TantivyError),
}

View File

@ -1 +1,2 @@
pub mod code;
pub mod event;

View File

@ -1 +0,0 @@
pub mod code;

View File

@ -1,4 +1,3 @@
mod api;
mod routes;
mod services;

View File

@ -7,11 +7,10 @@ use axum::{
};
use hyper::StatusCode;
use serde::Deserialize;
use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse};
use tracing::{instrument, warn};
use utoipa::IntoParams;
use crate::api::code::{CodeSearch, CodeSearchError, SearchResponse};
#[derive(Deserialize, IntoParams)]
pub struct SearchQuery {
#[param(default = "get")]

View File

@ -7,7 +7,12 @@ use std::{
use axum::{routing, Router, Server};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use clap::Args;
use tabby_common::{api::event::EventLogger, config::Config, usage};
use tabby_common::{
api,
api::{code::CodeSearch, event::EventLogger},
config::Config,
usage,
};
use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::info;
@ -15,10 +20,10 @@ use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
use crate::{
api::{self},
fatal, routes,
services::{
chat::{self, create_chat_service},
code::create_code_search,
completion::{self, create_completion_service},
event::create_logger,
health,
@ -46,7 +51,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
),
paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search),
components(schemas(
tabby_common::api::event::LogEventRequest,
api::event::LogEventRequest,
completion::CompletionRequest,
completion::CompletionResponse,
completion::Segments,
@ -102,13 +107,14 @@ pub async fn main(config: &Config, args: &ServeArgs) {
info!("Starting server, this might takes a few minutes...");
let logger = Arc::new(create_logger());
let code = Arc::new(create_code_search());
let app = Router::new()
.merge(api_router(args, config, logger.clone()).await)
.merge(api_router(args, config, logger.clone(), code.clone()).await)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()));
#[cfg(feature = "ee")]
let app = tabby_webserver::attach_webserver(app, logger).await;
let app = tabby_webserver::attach_webserver(app, logger, code).await;
#[cfg(not(feature = "ee"))]
let app = app.fallback(|| async { axum::response::Redirect::permanent("/swagger-ui") });
@ -133,9 +139,12 @@ async fn load_model(args: &ServeArgs) {
}
}
async fn api_router(args: &ServeArgs, config: &Config, logger: Arc<dyn EventLogger>) -> Router {
let code = Arc::new(crate::services::code::create_code_search());
async fn api_router(
args: &ServeArgs,
config: &Config,
logger: Arc<dyn EventLogger>,
code: Arc<dyn CodeSearch>,
) -> Router {
let completion_state = if let Some(model) = &args.model {
Some(Arc::new(
create_completion_service(

View File

@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration};
use anyhow::Result;
use async_trait::async_trait;
use tabby_common::{
api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse},
index::{self, register_tokenizers, CodeSearchSchema},
path,
};
@ -16,8 +17,6 @@ use tantivy::{
use tokio::{sync::Mutex, time::sleep};
use tracing::{debug, log::info};
use crate::api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse};
struct CodeSearchImpl {
reader: IndexReader,
query_parser: QueryParser,

View File

@ -5,7 +5,10 @@ use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tabby_common::{
api,
api::event::{Event, EventLogger},
api::{
code::CodeSearch,
event::{Event, EventLogger},
},
languages::get_language,
};
use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder};
@ -14,7 +17,7 @@ use tracing::debug;
use utoipa::ToSchema;
use super::model;
use crate::{api::code::CodeSearch, Device};
use crate::Device;
#[derive(Error, Debug)]
pub enum CompletionError {

View File

@ -3,12 +3,14 @@ use std::sync::Arc;
use lazy_static::lazy_static;
use regex::Regex;
use strfmt::strfmt;
use tabby_common::languages::get_language;
use tabby_common::{
api::code::{CodeSearch, CodeSearchError},
languages::get_language,
};
use textdistance::Algorithm;
use tracing::warn;
use super::{Segments, Snippet};
use crate::api::code::{CodeSearch, CodeSearchError};
static MAX_SNIPPETS_TO_FETCH: usize = 20;
static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768;

View File

@ -8,7 +8,6 @@ use anyhow::Result;
use axum::{routing, Router};
use clap::Args;
use hyper::Server;
use tabby_common::api::event::EventLogger;
use tabby_webserver::api::{tracing_context, HubClient, WorkerKind};
use tracing::{info, warn};
@ -16,7 +15,6 @@ use crate::{
fatal, routes,
services::{
chat::create_chat_service,
code,
completion::create_completion_service,
health::{read_cpu_info, read_cuda_devices},
model::download_model_if_needed,
@ -66,8 +64,8 @@ async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router {
async fn make_completion_route(context: WorkerContext, args: &WorkerArgs) -> Router {
context.register(WorkerKind::Completion, args).await;
let code = Arc::new(code::create_code_search());
let logger: Arc<dyn EventLogger> = Arc::new(context.client);
let code = Arc::new(context.client.clone());
let logger = Arc::new(context.client);
let completion_state = Arc::new(
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
);

View File

@ -7,6 +7,7 @@ homepage.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
axum = { workspace = true, features = ["ws"] }
bincode = "1.3.3"
chrono = "0.4"

View File

@ -1,6 +1,10 @@
use async_trait::async_trait;
use juniper::{GraphQLEnum, GraphQLObject};
use serde::{Deserialize, Serialize};
use tabby_common::api::event::RawEventLogger;
use tabby_common::api::{
code::{CodeSearch, CodeSearchError, SearchResponse},
event::RawEventLogger,
};
use thiserror::Error;
use tokio_tungstenite::connect_async;
@ -48,6 +52,15 @@ pub trait Hub {
) -> Result<Worker, HubError>;
async fn log_event(content: String);
async fn search(q: String, limit: usize, offset: usize) -> SearchResponse;
async fn search_in_language(
language: String,
tokens: Vec<String>,
limit: usize,
offset: usize,
) -> SearchResponse;
}
pub fn tracing_context() -> tarpc::context::Context {
@ -68,3 +81,43 @@ impl RawEventLogger for HubClient {
tokio::spawn(async move { client.log_event(context, content).await });
}
}
/// [`CodeSearch`] implementation that answers queries through a [`HubClient`].
///
/// Each trait method forwards its arguments to the `HubClient` method of the same
/// name (the variant that additionally takes a context from `tracing_context()`),
/// converting borrowed arguments into owned values first. A successful response is
/// returned unchanged; any error from the client call is discarded and reported as
/// [`CodeSearchError::NotReady`].
#[async_trait]
impl CodeSearch for HubClient {
    /// Runs the query `q` via the client, passing `limit` and `offset` through as-is.
    ///
    /// # Errors
    ///
    /// Returns [`CodeSearchError::NotReady`] whenever the client call fails.
    async fn search(
        &self,
        q: &str,
        limit: usize,
        offset: usize,
    ) -> Result<SearchResponse, CodeSearchError> {
        self.search(tracing_context(), q.to_owned(), limit, offset)
            .await
            .map_err(|_| CodeSearchError::NotReady)
    }

    /// Runs a search for `tokens` scoped to `language` via the client, passing
    /// `limit` and `offset` through as-is.
    ///
    /// # Errors
    ///
    /// Returns [`CodeSearchError::NotReady`] whenever the client call fails.
    async fn search_in_language(
        &self,
        language: &str,
        tokens: &[String],
        limit: usize,
        offset: usize,
    ) -> Result<SearchResponse, CodeSearchError> {
        self.search_in_language(
            tracing_context(),
            language.to_owned(),
            tokens.to_owned(),
            limit,
            offset,
        )
        .await
        .map_err(|_| CodeSearchError::NotReady)
    }
}

View File

@ -2,8 +2,11 @@ pub mod api;
mod schema;
pub use schema::create_schema;
use tabby_common::api::event::RawEventLogger;
use tracing::error;
use tabby_common::api::{
code::{CodeSearch, SearchResponse},
event::RawEventLogger,
};
use tracing::{error, warn};
use websocket::WebSocketTransport;
mod db;
@ -27,9 +30,13 @@ use schema::Schema;
use server::ServerContext;
use tarpc::server::{BaseChannel, Channel};
pub async fn attach_webserver(router: Router, logger: Arc<dyn RawEventLogger>) -> Router {
pub async fn attach_webserver(
router: Router,
logger: Arc<dyn RawEventLogger>,
code: Arc<dyn CodeSearch>,
) -> Router {
let conn = db::DbConn::new().await.unwrap();
let ctx = Arc::new(ServerContext::new(conn, logger));
let ctx = Arc::new(ServerContext::new(conn, logger, code));
let schema = Arc::new(create_schema());
let app = Router::new()
@ -129,4 +136,42 @@ impl Hub for Arc<HubImpl> {
async fn log_event(self, _context: tarpc::context::Context, content: String) {
self.ctx.logger.log(content)
}
/// Serves the hub's `search` call by delegating to the code search handle stored
/// on the server context.
///
/// The incoming context argument is not used. If the underlying search fails, the
/// error is logged at warn level and `SearchResponse::default()` is returned
/// instead, so the caller always receives a response.
async fn search(
    self,
    _context: tarpc::context::Context,
    q: String,
    limit: usize,
    offset: usize,
) -> SearchResponse {
    let outcome = self.ctx.code.search(&q, limit, offset).await;
    outcome.unwrap_or_else(|err| {
        warn!("Failed to search: {}", err);
        SearchResponse::default()
    })
}
/// Serves the hub's `search_in_language` call by delegating to the code search
/// handle stored on the server context.
///
/// The incoming context argument is not used. If the underlying search fails, the
/// error is logged at warn level and `SearchResponse::default()` is returned
/// instead, so the caller always receives a response.
async fn search_in_language(
    self,
    _context: tarpc::context::Context,
    language: String,
    tokens: Vec<String>,
    limit: usize,
    offset: usize,
) -> SearchResponse {
    let code = &self.ctx.code;
    code.search_in_language(&language, &tokens, limit, offset)
        .await
        .unwrap_or_else(|err| {
            warn!("Failed to search: {}", err);
            SearchResponse::default()
        })
}
}

View File

@ -6,7 +6,7 @@ use std::{net::SocketAddr, sync::Arc};
use anyhow::Result;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tabby_common::api::event::RawEventLogger;
use tabby_common::api::{code::CodeSearch, event::RawEventLogger};
use tracing::{info, warn};
use crate::{
@ -21,16 +21,22 @@ pub struct ServerContext {
db_conn: DbConn,
pub logger: Arc<dyn RawEventLogger>,
pub code: Arc<dyn CodeSearch>,
}
impl ServerContext {
pub fn new(db_conn: DbConn, logger: Arc<dyn RawEventLogger>) -> Self {
/// Builds a `ServerContext` from the supplied database connection, event logger
/// and code search handle.
///
/// The HTTP client and the `completion` / `chat` worker groups are not taken from
/// the caller; each is created with its `Default` implementation.
pub fn new(
    db_conn: DbConn,
    logger: Arc<dyn RawEventLogger>,
    code: Arc<dyn CodeSearch>,
) -> Self {
    // Default-constructed members, created in the same order as the fields are listed.
    let client = Client::default();
    let completion = worker::WorkerGroup::default();
    let chat = worker::WorkerGroup::default();

    Self {
        client,
        completion,
        chat,
        db_conn,
        logger,
        code,
    }
}