diff --git a/Cargo.lock b/Cargo.lock index 48808e7..1bf5e4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4469,6 +4469,7 @@ name = "tabby-common" version = "0.6.0-dev" dependencies = [ "anyhow", + "async-trait", "filenamify", "glob", "lazy_static", @@ -4478,6 +4479,7 @@ dependencies = [ "serde_json", "serdeconv", "tantivy", + "thiserror", "utoipa", "uuid 1.4.1", ] @@ -4541,6 +4543,7 @@ name = "tabby-webserver" version = "0.6.0-dev" dependencies = [ "anyhow", + "async-trait", "axum", "bincode", "chrono", diff --git a/crates/tabby-common/Cargo.toml b/crates/tabby-common/Cargo.toml index c39af3a..755559d 100644 --- a/crates/tabby-common/Cargo.toml +++ b/crates/tabby-common/Cargo.toml @@ -16,6 +16,8 @@ anyhow.workspace = true glob = "0.3.1" utoipa.workspace = true serde_json.workspace = true +async-trait.workspace = true +thiserror.workspace = true [features] testutils = [] @@ -24,4 +26,4 @@ testutils = [] ignored = [ # required in utoipa ToSchema. "serde_json" -] \ No newline at end of file +] diff --git a/crates/tabby/src/api/code.rs b/crates/tabby-common/src/api/code.rs similarity index 91% rename from crates/tabby/src/api/code.rs rename to crates/tabby-common/src/api/code.rs index 2f07329..c44c650 100644 --- a/crates/tabby/src/api/code.rs +++ b/crates/tabby-common/src/api/code.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use utoipa::ToSchema; -#[derive(Serialize, Deserialize, Debug, ToSchema)] +#[derive(Default, Serialize, Deserialize, Debug, ToSchema)] pub struct SearchResponse { pub num_hits: usize, pub hits: Vec, @@ -31,10 +31,10 @@ pub enum CodeSearchError { #[error("index not ready")] NotReady, - #[error("{0}")] + #[error(transparent)] QueryParserError(#[from] tantivy::query::QueryParserError), - #[error("{0}")] + #[error(transparent)] TantivyError(#[from] tantivy::TantivyError), } diff --git a/crates/tabby-common/src/api/mod.rs b/crates/tabby-common/src/api/mod.rs index 53f1126..cebf170 100644 --- a/crates/tabby-common/src/api/mod.rs +++ b/crates/tabby-common/src/api/mod.rs @@ -1 +1,2 @@ +pub mod code; pub mod event; diff --git a/crates/tabby/src/api/mod.rs b/crates/tabby/src/api/mod.rs deleted file mode 100644 index 9de50d4..0000000 --- a/crates/tabby/src/api/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod code; diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs index 9ef0e88..fe7c4a4 100644 --- a/crates/tabby/src/main.rs +++ b/crates/tabby/src/main.rs @@ -1,4 +1,3 @@ -mod api; mod routes; mod services; diff --git a/crates/tabby/src/routes/search.rs b/crates/tabby/src/routes/search.rs index 89932b1..99edade 100644 --- a/crates/tabby/src/routes/search.rs +++ b/crates/tabby/src/routes/search.rs @@ -7,11 +7,10 @@ use axum::{ }; use hyper::StatusCode; use serde::Deserialize; +use tabby_common::api::code::{CodeSearch, CodeSearchError, SearchResponse}; use tracing::{instrument, warn}; use utoipa::IntoParams; -use crate::api::code::{CodeSearch, CodeSearchError, SearchResponse}; - #[derive(Deserialize, IntoParams)] pub struct SearchQuery { #[param(default = "get")] diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs index ab15b68..a439cf8 100644 --- a/crates/tabby/src/serve.rs +++ b/crates/tabby/src/serve.rs @@ -7,7 +7,12 @@ use std::{ use axum::{routing, Router, Server}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use clap::Args; -use tabby_common::{api::event::EventLogger, config::Config, usage}; +use tabby_common::{ + api, + api::{code::CodeSearch, event::EventLogger}, + config::Config, + usage, +}; use tokio::time::sleep; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; use tracing::info; @@ -15,10 +20,10 @@ use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; use crate::{ - api::{self}, fatal, routes, services::{ chat::{self, create_chat_service}, + code::create_code_search, completion::{self, create_completion_service}, event::create_logger, health, @@ -46,7 +51,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi ), paths(routes::log_event, routes::completions, routes::completions, routes::health, routes::search), components(schemas( - tabby_common::api::event::LogEventRequest, + api::event::LogEventRequest, completion::CompletionRequest, completion::CompletionResponse, completion::Segments, @@ -102,13 +107,14 @@ pub async fn main(config: &Config, args: &ServeArgs) { info!("Starting server, this might takes a few minutes..."); let logger = Arc::new(create_logger()); + let code = Arc::new(create_code_search()); let app = Router::new() - .merge(api_router(args, config, logger.clone()).await) + .merge(api_router(args, config, logger.clone(), code.clone()).await) .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())); #[cfg(feature = "ee")] - let app = tabby_webserver::attach_webserver(app, logger).await; + let app = tabby_webserver::attach_webserver(app, logger, code).await; #[cfg(not(feature = "ee"))] let app = app.fallback(|| async { axum::response::Redirect::permanent("/swagger-ui") }); @@ -133,9 +139,12 @@ async fn load_model(args: &ServeArgs) { } } -async fn api_router(args: &ServeArgs, config: &Config, logger: Arc) -> Router { - let code = Arc::new(crate::services::code::create_code_search()); - +async fn api_router( + args: &ServeArgs, + config: &Config, + logger: Arc, + code: Arc, +) -> Router { let completion_state = if let Some(model) = &args.model { Some(Arc::new( create_completion_service( diff --git a/crates/tabby/src/services/code.rs b/crates/tabby/src/services/code.rs index 1abc586..7ca3188 100644 --- a/crates/tabby/src/services/code.rs +++ b/crates/tabby/src/services/code.rs @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration}; use anyhow::Result; use async_trait::async_trait; use tabby_common::{ + api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}, index::{self, register_tokenizers, CodeSearchSchema}, path, }; @@ -16,8 +17,6 @@ use tantivy::{ use tokio::{sync::Mutex, time::sleep}; use tracing::{debug, log::info}; -use crate::api::code::{CodeSearch, CodeSearchError, Hit, HitDocument, SearchResponse}; - struct CodeSearchImpl { reader: IndexReader, query_parser: QueryParser, diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index 11aca1c..366b421 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -5,7 +5,10 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use tabby_common::{ api, - api::event::{Event, EventLogger}, + api::{ + code::CodeSearch, + event::{Event, EventLogger}, + }, languages::get_language, }; use tabby_inference::{TextGeneration, TextGenerationOptions, TextGenerationOptionsBuilder}; @@ -14,7 +17,7 @@ use tracing::debug; use utoipa::ToSchema; use super::model; -use crate::{api::code::CodeSearch, Device}; +use crate::Device; #[derive(Error, Debug)] pub enum CompletionError { diff --git a/crates/tabby/src/services/completion/completion_prompt.rs b/crates/tabby/src/services/completion/completion_prompt.rs index 627916f..742e5bf 100644 --- a/crates/tabby/src/services/completion/completion_prompt.rs +++ b/crates/tabby/src/services/completion/completion_prompt.rs @@ -3,12 +3,14 @@ use std::sync::Arc; use lazy_static::lazy_static; use regex::Regex; use strfmt::strfmt; -use tabby_common::languages::get_language; +use tabby_common::{ + api::code::{CodeSearch, CodeSearchError}, + languages::get_language, +}; use textdistance::Algorithm; use tracing::warn; use super::{Segments, Snippet}; -use crate::api::code::{CodeSearch, CodeSearchError}; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs index 4991f78..fc58e3f 100644 --- a/crates/tabby/src/worker.rs +++ b/crates/tabby/src/worker.rs @@ -8,7 +8,6 @@ use anyhow::Result; use axum::{routing, Router}; use clap::Args; use hyper::Server; -use tabby_common::api::event::EventLogger; use tabby_webserver::api::{tracing_context, HubClient, WorkerKind}; use tracing::{info, warn}; @@ -16,7 +15,6 @@ use crate::{ fatal, routes, services::{ chat::create_chat_service, - code, completion::create_completion_service, health::{read_cpu_info, read_cuda_devices}, model::download_model_if_needed, @@ -66,8 +64,8 @@ async fn make_chat_route(context: WorkerContext, args: &WorkerArgs) -> Router { async fn make_completion_route(context: WorkerContext, args: &WorkerArgs) -> Router { context.register(WorkerKind::Completion, args).await; - let code = Arc::new(code::create_code_search()); - let logger: Arc = Arc::new(context.client); + let code = Arc::new(context.client.clone()); + let logger = Arc::new(context.client); let completion_state = Arc::new( create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await, ); diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 99be2be..fdb4ecf 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -7,6 +7,7 @@ homepage.workspace = true [dependencies] anyhow.workspace = true +async-trait.workspace = true axum = { workspace = true, features = ["ws"] } bincode = "1.3.3" chrono = "0.4" diff --git a/ee/tabby-webserver/src/api.rs b/ee/tabby-webserver/src/api.rs index e05bce2..512238a 100644 --- a/ee/tabby-webserver/src/api.rs +++ b/ee/tabby-webserver/src/api.rs @@ -1,6 +1,10 @@ +use async_trait::async_trait; use juniper::{GraphQLEnum, GraphQLObject}; use serde::{Deserialize, Serialize}; -use tabby_common::api::event::RawEventLogger; +use tabby_common::api::{ + code::{CodeSearch, CodeSearchError, SearchResponse}, + event::RawEventLogger, +}; use thiserror::Error; use tokio_tungstenite::connect_async; @@ -48,6 +52,15 @@ pub trait Hub { ) -> Result; async fn log_event(content: String); + + async fn search(q: String, limit: usize, offset: usize) -> SearchResponse; + + async fn search_in_language( + language: String, + tokens: Vec, + limit: usize, + offset: usize, + ) -> SearchResponse; } pub fn tracing_context() -> tarpc::context::Context { @@ -68,3 +81,43 @@ impl RawEventLogger for HubClient { tokio::spawn(async move { client.log_event(context, content).await }); } } + +#[async_trait] +impl CodeSearch for HubClient { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result { + match self + .search(tracing_context(), q.to_owned(), limit, offset) + .await + { + Ok(serp) => Ok(serp), + Err(_) => Err(CodeSearchError::NotReady), + } + } + + async fn search_in_language( + &self, + language: &str, + tokens: &[String], + limit: usize, + offset: usize, + ) -> Result { + match self + .search_in_language( + tracing_context(), + language.to_owned(), + tokens.to_owned(), + limit, + offset, + ) + .await + { + Ok(serp) => Ok(serp), + Err(_) => Err(CodeSearchError::NotReady), + } + } +} diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 54553b6..b69a1ec 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -2,8 +2,11 @@ pub mod api; mod schema; pub use schema::create_schema; -use tabby_common::api::event::RawEventLogger; -use tracing::error; +use tabby_common::api::{ + code::{CodeSearch, SearchResponse}, + event::RawEventLogger, +}; +use tracing::{error, warn}; use websocket::WebSocketTransport; mod db; @@ -27,9 +30,13 @@ use schema::Schema; use server::ServerContext; use tarpc::server::{BaseChannel, Channel}; -pub async fn attach_webserver(router: Router, logger: Arc) -> Router { +pub async fn attach_webserver( + router: Router, + logger: Arc, + code: Arc, +) -> Router { let conn = db::DbConn::new().await.unwrap(); - let ctx = Arc::new(ServerContext::new(conn, logger)); + let ctx = Arc::new(ServerContext::new(conn, logger, code)); let schema = Arc::new(create_schema()); let app = Router::new() @@ -129,4 +136,42 @@ impl Hub for Arc { async fn log_event(self, _context: tarpc::context::Context, content: String) { self.ctx.logger.log(content) } + + async fn search( + self, + _context: tarpc::context::Context, + q: String, + limit: usize, + offset: usize, + ) -> SearchResponse { + match self.ctx.code.search(&q, limit, offset).await { + Ok(serp) => serp, + Err(err) => { + warn!("Failed to search: {}", err); + SearchResponse::default() + } + } + } + + async fn search_in_language( + self, + _context: tarpc::context::Context, + language: String, + tokens: Vec, + limit: usize, + offset: usize, + ) -> SearchResponse { + match self + .ctx + .code + .search_in_language(&language, &tokens, limit, offset) + .await + { + Ok(serp) => serp, + Err(err) => { + warn!("Failed to search: {}", err); + SearchResponse::default() + } + } + } } diff --git a/ee/tabby-webserver/src/server.rs b/ee/tabby-webserver/src/server.rs index b108fa0..435539f 100644 --- a/ee/tabby-webserver/src/server.rs +++ b/ee/tabby-webserver/src/server.rs @@ -6,7 +6,7 @@ use std::{net::SocketAddr, sync::Arc}; use anyhow::Result; use axum::{http::Request, middleware::Next, response::IntoResponse}; use hyper::{client::HttpConnector, Body, Client, StatusCode}; -use tabby_common::api::event::RawEventLogger; +use tabby_common::api::{code::CodeSearch, event::RawEventLogger}; use tracing::{info, warn}; use crate::{ @@ -21,16 +21,22 @@ pub struct ServerContext { db_conn: DbConn, pub logger: Arc, + pub code: Arc, } impl ServerContext { - pub fn new(db_conn: DbConn, logger: Arc) -> Self { + pub fn new( + db_conn: DbConn, + logger: Arc, + code: Arc, + ) -> Self { Self { client: Client::default(), completion: worker::WorkerGroup::default(), chat: worker::WorkerGroup::default(), db_conn, logger, + code, } }