feat: add tabby-webserver as distribution layer of tabby #769

extract-routes
Meng Zhang 2023-11-09 10:51:07 -08:00
parent 56ec7c05f6
commit 15f768a971
58 changed files with 301 additions and 59 deletions

30
Cargo.lock generated
View File

@ -430,9 +430,9 @@ dependencies = [
[[package]]
name = "axum"
version = "0.6.18"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core",
@ -1820,9 +1820,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
version = "0.14.26"
version = "0.14.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468"
dependencies = [
"bytes",
"futures-channel",
@ -4078,14 +4078,12 @@ dependencies = [
"hyper",
"lazy_static",
"llama-cpp-bindings",
"mime_guess",
"minijinja",
"nvml-wrapper",
"opentelemetry",
"opentelemetry-otlp",
"regex",
"reqwest",
"rust-embed 8.0.0",
"serde",
"serde_json",
"serdeconv",
@ -4096,6 +4094,7 @@ dependencies = [
"tabby-download",
"tabby-inference",
"tabby-scheduler",
"tabby-webserver",
"tantivy",
"textdistance",
"tokio",
@ -4181,6 +4180,21 @@ dependencies = [
"tree-sitter-typescript",
]
[[package]]
name = "tabby-webserver"
version = "0.6.0-dev"
dependencies = [
"anyhow",
"axum",
"hyper",
"lazy_static",
"mime_guess",
"rust-embed 8.0.0",
"tokio",
"tracing",
"unicase",
]
[[package]]
name = "tantivy"
version = "0.21.0"
@ -5037,9 +5051,9 @@ checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "unicase"
version = "2.6.0"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89"
dependencies = [
"version_check",
]

View File

@ -8,6 +8,7 @@ members = [
"crates/tabby-inference",
"crates/llama-cpp-bindings",
"crates/http-api-bindings",
"ee/tabby-webserver",
]
[workspace.package]
@ -35,3 +36,5 @@ async-stream = "0.3.5"
regex = "1.10.0"
thiserror = "1.0.49"
utoipa = "3.3"
axum = "0.6"
hyper = "0.14"

View File

@ -13,8 +13,7 @@ fix-ui:
update-ui:
cd ee/tabby-ui && yarn build
rm -rf crates/tabby/ui && cp -R ee/tabby-ui/out crates/tabby/ui
cp ee/LICENSE crates/tabby/ui/
rm -rf ee/tabby-webserver/ui && cp -R ee/tabby-ui/out ee/tabby-webserver/ui
bump-version:
cargo ws version --no-git-tag --force "*"

View File

@ -12,8 +12,8 @@ tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler" }
tabby-download = { path = "../tabby-download" }
tabby-inference = { path = "../tabby-inference" }
axum = "0.6"
hyper = { version = "0.14", features = ["full"] }
axum.workspace = true
hyper = { workspace = true }
tokio = { workspace = true, features = ["full"] }
utoipa = { workspace= true, features = ["axum_extras", "preserve_order"] }
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
@ -23,8 +23,6 @@ serde_json = { workspace = true }
tower-http = { version = "0.4.0", features = ["cors", "timeout"] }
clap = { version = "4.3.0", features = ["derive"] }
lazy_static = { workspace = true }
rust-embed = "8.0.0"
mime_guess = "2.0.4"
strum = { version = "0.24", features = ["derive"] }
strfmt = "0.2.4"
tracing = { workspace = true }
@ -46,6 +44,7 @@ regex.workspace = true
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
futures.workspace = true
async-trait.workspace = true
tabby-webserver = { path = "../../ee/tabby-webserver" }
[dependencies.uuid]
version = "1.3.3"

View File

@ -4,7 +4,6 @@ mod engine;
mod events;
mod health;
mod search;
mod ui;
use std::{
fs,
@ -22,6 +21,7 @@ use tabby_common::{
usage,
};
use tabby_download::download_model;
use tabby_webserver::attach_webserver;
use tokio::time::sleep;
use tower_http::{cors::CorsLayer, timeout::TimeoutLayer};
use tracing::info;
@ -147,17 +147,17 @@ pub async fn main(config: &Config, args: &ServeArgs) {
doc.override_doc(args);
let app = Router::new()
.route("/", routing::get(ui::handler))
.merge(api_router(args, config).await)
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc))
.fallback(ui::handler);
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc));
let app = attach_webserver(app);
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, args.port));
info!("Listening at {}", address);
start_heartbeat(args);
Server::bind(&address)
.serve(app.into_make_service())
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}

View File

@ -1,36 +0,0 @@
The Tabby Enterprise license (the “Enterprise License”)
Copyright (c) 2023 TabbyML, Inc.
With regard to the Tabby Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the Tabby Subscription Terms of Service, available
at https://tabby.tabbyml.com/terms (the “Enterprise Terms”), or other
agreement governing the use of the Software, as agreed by you and TabbyML,
and otherwise have a valid Tabby Enterprise license for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that TabbyML
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Tabby Enterprise license for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that Tabby and/or its licensors (as applicable) retain
all right, title and interest in and to all such modifications. You are not
granted any other rights beyond what is expressly stated herein. Subject to the
foregoing, it is forbidden to copy, merge, publish, distribute, sublicense,
and/or sell the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
For all third party components incorporated into the Tabby Software, those
components are licensed under the original license provided by the owner of the
applicable component.

View File

@ -0,0 +1,20 @@
[package]
name = "tabby-webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
anyhow.workspace = true
axum.workspace = true
hyper = { workspace = true, features=["client"]}
lazy_static = "1.4.0"
mime_guess = "2.0.4"
rust-embed = "8.0.0"
tokio.workspace = true
tracing.workspace = true
unicase = "2.7.0"
[dev-dependencies]
tokio = { workspace = true, features = ["macros"] }

View File

@ -0,0 +1,78 @@
mod proxy;
mod ui;
mod worker;
use std::{net::SocketAddr, sync::Arc};
use axum::{
extract::State,
http::Request,
middleware::{from_fn_with_state, Next},
response::IntoResponse,
Router,
};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tracing::warn;
#[derive(Default)]
pub struct Webserver {
client: Client<HttpConnector>,
completion: worker::WorkerGroup,
chat: worker::WorkerGroup,
}
impl Webserver {
async fn dispatch_request(
&self,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
let path = request.uri().path();
let remote_addr = request
.extensions()
.get::<axum::extract::ConnectInfo<SocketAddr>>()
.map(|ci| ci.0)
.expect("Unable to extract remote addr");
let worker = if path.starts_with("/v1/completions") {
self.completion.select().await
} else if path.starts_with("/v1beta/chat/completions") {
self.chat.select().await
} else {
None
};
if let Some(worker) = worker {
match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await {
Ok(res) => res.into_response(),
Err(err) => {
warn!("Failed to proxy request {}", err);
axum::response::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
.into_response()
}
}
} else {
next.run(request).await
}
}
}
pub fn attach_webserver(router: Router) -> Router {
let ws = Arc::new(Webserver::default());
router
.fallback(ui::handler)
.layer(from_fn_with_state(ws, distributed_tabby_layer))
}
async fn distributed_tabby_layer(
State(ws): State<Arc<Webserver>>,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
ws.dispatch_request(request, next).await
}

View File

@ -0,0 +1,95 @@
use std::{net::IpAddr, str::FromStr};
use anyhow::Result;
use hyper::{
client::HttpConnector,
header::{HeaderMap, HeaderValue},
Body, Client, Request, Response, Uri,
};
use lazy_static::lazy_static;
fn is_hop_header(name: &str) -> bool {
use unicase::Ascii;
// A list of the headers, using `unicase` to help us compare without
// worrying about the case, and `lazy_static!` to prevent reallocation
// of the vector.
lazy_static! {
static ref HOP_HEADERS: Vec<Ascii<&'static str>> = vec![
Ascii::new("Connection"),
Ascii::new("Keep-Alive"),
Ascii::new("Proxy-Authenticate"),
Ascii::new("Proxy-Authorization"),
Ascii::new("Te"),
Ascii::new("Trailers"),
Ascii::new("Transfer-Encoding"),
Ascii::new("Upgrade"),
];
}
HOP_HEADERS.iter().any(|h| h == &name)
}
/// Returns a clone of the headers without the [hop-by-hop headers].
///
/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
let mut result = HeaderMap::new();
for (k, v) in headers.iter() {
if !is_hop_header(k.as_str()) {
result.insert(k.clone(), v.clone());
}
}
result
}
fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
*response.headers_mut() = remove_hop_headers(response.headers());
response
}
fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri> {
let forward_uri = match req.uri().query() {
Some(query) => format!("{}{}?{}", forward_url, req.uri().path(), query),
None => format!("{}{}", forward_url, req.uri().path()),
};
Ok(Uri::from_str(forward_uri.as_str())?)
}
fn create_proxied_request<B>(
client_ip: IpAddr,
forward_url: &str,
mut request: Request<B>,
) -> Result<Request<B>> {
*request.headers_mut() = remove_hop_headers(request.headers());
*request.uri_mut() = forward_uri(forward_url, &request)?;
let x_forwarded_for_header_name = "x-forwarded-for";
// Add forwarding information in the headers
match request.headers_mut().entry(x_forwarded_for_header_name) {
hyper::header::Entry::Vacant(entry) => {
entry.insert(client_ip.to_string().parse()?);
}
hyper::header::Entry::Occupied(mut entry) => {
let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
entry.insert(addr.parse()?);
}
}
Ok(request)
}
pub async fn call(
client: Client<HttpConnector>,
client_ip: IpAddr,
forward_uri: &str,
request: Request<Body>,
) -> Result<Response<Body>> {
let proxied_request = create_proxied_request(client_ip, forward_uri, request)?;
let response = client.request(proxied_request).await?;
let proxied_response = create_proxied_response(response);
Ok(proxied_response)
}

View File

@ -4,8 +4,6 @@ use axum::{
response::{IntoResponse, Response},
};
use crate::fatal;
#[derive(rust_embed::RustEmbed)]
#[folder = "./ui"]
struct WebAssets;
@ -25,12 +23,12 @@ where
Response::builder()
.header(header::CONTENT_TYPE, mime.as_ref())
.body(body)
.unwrap_or_else(|_| fatal!("Invalid response"))
.unwrap_or_else(|_| panic!("Invalid response"))
}
None => Response::builder()
.status(StatusCode::NOT_FOUND)
.body(boxed(Full::from(WebAssets::get("404.html").unwrap().data)))
.unwrap_or_else(|_| fatal!("Invalid response")),
.unwrap_or_else(|_| panic!("Invalid response")),
}
}
}

View File

@ -0,0 +1,72 @@
use std::{
net::SocketAddr,
time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::RwLock;
use tracing::error;
#[derive(Default)]
pub struct WorkerGroup {
workers: RwLock<Vec<String>>,
}
impl WorkerGroup {
pub async fn select(&self) -> Option<String> {
let workers = self.workers.read().await;
if workers.len() > 0 {
Some(workers[random_index(workers.len())].clone())
} else {
None
}
}
pub async fn list(&self) -> Vec<String> {
self.workers.read().await.clone()
}
pub async fn register(&self, addr: SocketAddr) -> Option<String> {
let addr = format!("http://{}", addr);
let mut workers = self.workers.write().await;
if workers.len() >= 1 {
error!("You need enterprise license to utilize more than 1 workers, please contact hi@tabbyml.com for information.");
return None;
}
if !workers.contains(&addr) {
workers.push(addr.clone());
}
Some(addr)
}
}
fn random_index(size: usize) -> usize {
let unix_timestamp = (SystemTime::now().duration_since(UNIX_EPOCH))
.unwrap()
.as_nanos();
let index = unix_timestamp % (size as u128);
index as usize
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_worker_group() {
let wg = WorkerGroup::default();
let addr1 = "127.0.0.1:8080".parse().unwrap();
let addr2 = "127.0.0.2:8080".parse().unwrap();
// Register success.
assert!(wg.register(addr1).await.is_some());
// Register failed, as > 1 workers requires enterprise license.
assert!(wg.register(addr2).await.is_none());
let workers = wg.list().await;
assert_eq!(workers.len(), 1);
assert_eq!(workers[0], format!("http://{}", addr1));
}
}