feat: add graphql interface to tabby-webserver (#770)

feat: add graphql interface to tabby-webserver
extract-routes
Meng Zhang 2023-11-12 14:52:28 -08:00 committed by GitHub
parent 4d6dc626c0
commit 3a9b4d9ef5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 861 additions and 76 deletions

280
Cargo.lock generated
View File

@ -23,7 +23,7 @@ version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd"
dependencies = [
"getrandom",
"getrandom 0.2.9",
"once_cell",
"version_check",
]
@ -186,6 +186,12 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
[[package]]
name = "ascii"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e"
[[package]]
name = "assert-json-diff"
version = "2.0.2"
@ -602,6 +608,23 @@ dependencies = [
"tracing",
]
[[package]]
name = "bson"
version = "1.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de0aa578035b938855a710ba58d43cfb4d435f3619f99236fb35922a574d6cb1"
dependencies = [
"base64 0.13.1",
"chrono",
"hex",
"lazy_static",
"linked-hash-map",
"rand 0.7.3",
"serde",
"serde_json",
"uuid 0.8.2",
]
[[package]]
name = "bstr"
version = "1.7.0"
@ -708,7 +731,7 @@ dependencies = [
"bitflags 1.3.2",
"clap_derive 3.2.25",
"clap_lex 0.2.4",
"indexmap",
"indexmap 1.9.3",
"once_cell",
"strsim 0.10.0",
"termcolor",
@ -802,6 +825,19 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "combine"
version = "3.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da3da6baa321ec19e1cc41d31bf599f00c783d0517095cdaf0332e3fe8d20680"
dependencies = [
"ascii",
"byteorder",
"either",
"memchr",
"unreachable",
]
[[package]]
name = "concurrent-queue"
version = "2.3.0"
@ -1197,6 +1233,17 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_utils"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "532b4c15dccee12c7044f1fcad956e98410860b22231e44a3b827464797ca7bf"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "digest"
version = "0.10.7"
@ -1314,6 +1361,12 @@ dependencies = [
"termcolor",
]
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.1"
@ -1489,6 +1542,17 @@ version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
[[package]]
name = "futures-enum"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3422d14de7903a52e9dbc10ae05a7e14445ec61890100e098754e120b2bd7b1e"
dependencies = [
"derive_utils",
"quote",
"syn 1.0.109",
]
[[package]]
name = "futures-executor"
version = "0.3.28"
@ -1594,6 +1658,17 @@ dependencies = [
"version_check",
]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
]
[[package]]
name = "getrandom"
version = "0.2.9"
@ -1602,7 +1677,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [
"cfg-if",
"libc",
"wasi",
"wasi 0.11.0+wasi-snapshot-preview1",
]
[[package]]
@ -1636,6 +1711,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "graphql-parser"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1abd4ce5247dfc04a03ccde70f87a048458c9356c7e41d21ad8c407b3dde6f2"
dependencies = [
"combine",
"thiserror",
]
[[package]]
name = "h2"
version = "0.3.19"
@ -1648,7 +1733,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http",
"indexmap",
"indexmap 1.9.3",
"slab",
"tokio",
"tokio-util",
@ -1934,6 +2019,16 @@ dependencies = [
"serde",
]
[[package]]
name = "indexmap"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e"
dependencies = [
"equivalent",
"hashbrown 0.14.0",
]
[[package]]
name = "indicatif"
version = "0.17.7"
@ -2041,6 +2136,73 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "juniper"
version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52adf17d43d0b526eed31fac15d9312941c5c2558ffbfb105811690b96d6e2f1"
dependencies = [
"async-trait",
"bson",
"chrono",
"fnv",
"futures",
"futures-enum",
"graphql-parser",
"indexmap 1.9.3",
"juniper_codegen",
"serde",
"smartstring",
"static_assertions",
"url",
"uuid 0.8.2",
]
[[package]]
name = "juniper-axum"
version = "0.6.0-dev"
dependencies = [
"axum",
"juniper",
"juniper_graphql_ws",
"serde",
"serde_json",
]
[[package]]
name = "juniper_codegen"
version = "0.15.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aee97671061ad50301ba077d054d295e01d31a1868fbd07902db651f987e71db"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "juniper_graphql_ws"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed5526c2f2a9c40f08841dc559971641fdd71c008a265745d18bb0c8b7e105b3"
dependencies = [
"juniper",
"juniper_subscriptions",
"serde",
"tokio",
]
[[package]]
name = "juniper_subscriptions"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2983b26a1e12b691c17432aee3881d8bec4a94d6c64bc933c0eaf6d9e3429f13"
dependencies = [
"futures",
"juniper",
]
[[package]]
name = "kdam"
version = "0.5.0"
@ -2123,6 +2285,12 @@ dependencies = [
"cc",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.3.8"
@ -2374,7 +2542,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
dependencies = [
"libc",
"log",
"wasi",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.45.0",
]
@ -2681,7 +2849,7 @@ dependencies = [
"fnv",
"futures-channel",
"futures-util",
"indexmap",
"indexmap 1.9.3",
"js-sys",
"once_cell",
"pin-project-lite",
@ -2704,7 +2872,7 @@ dependencies = [
"once_cell",
"opentelemetry_api",
"percent-encoding",
"rand",
"rand 0.8.5",
"thiserror",
"tokio",
"tokio-stream",
@ -2824,7 +2992,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4"
dependencies = [
"fixedbitset",
"indexmap",
"indexmap 1.9.3",
]
[[package]]
@ -3054,6 +3222,19 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.16",
"libc",
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
]
[[package]]
name = "rand"
version = "0.8.5"
@ -3061,8 +3242,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
]
[[package]]
@ -3072,7 +3263,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.6.4",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.16",
]
[[package]]
@ -3081,7 +3281,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
"getrandom 0.2.9",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
@ -3142,7 +3351,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom",
"getrandom 0.2.9",
"redox_syscall 0.2.16",
"thiserror",
]
@ -3253,7 +3462,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e"
dependencies = [
"cc",
"getrandom",
"getrandom 0.2.9",
"libc",
"spin 0.9.8",
"untrusted 0.9.0",
@ -3654,6 +3863,7 @@ version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
dependencies = [
"indexmap 2.0.1",
"itoa",
"ryu",
"serde",
@ -4187,9 +4397,12 @@ dependencies = [
"anyhow",
"axum",
"hyper",
"juniper",
"juniper-axum",
"lazy_static",
"mime_guess",
"rust-embed 8.0.0",
"thiserror",
"tokio",
"tracing",
"unicase",
@ -4557,7 +4770,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
dependencies = [
"pin-project",
"rand",
"rand 0.8.5",
"tokio",
]
@ -4639,7 +4852,7 @@ version = "0.19.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2380d56e8670370eee6566b0bfd4265f65b3f432e8c6d85623f728d4fa31f739"
dependencies = [
"indexmap",
"indexmap 1.9.3",
"serde",
"serde_spanned",
"toml_datetime",
@ -4699,10 +4912,10 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"indexmap",
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
"rand",
"rand 0.8.5",
"slab",
"tokio",
"tokio-util",
@ -5036,7 +5249,7 @@ dependencies = [
"http",
"httparse",
"log",
"rand",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
@ -5091,6 +5304,15 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
[[package]]
name = "unreachable"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56"
dependencies = [
"void",
]
[[package]]
name = "untildify"
version = "0.1.1"
@ -5162,7 +5384,7 @@ version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
dependencies = [
"indexmap",
"indexmap 1.9.3",
"serde",
"serde_json",
"utoipa-gen",
@ -5202,7 +5424,7 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
dependencies = [
"getrandom",
"getrandom 0.2.9",
]
[[package]]
@ -5211,8 +5433,8 @@ version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d"
dependencies = [
"getrandom",
"rand",
"getrandom 0.2.9",
"rand 0.8.5",
"serde",
"uuid-macro-internal",
]
@ -5269,6 +5491,12 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "void"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
[[package]]
name = "vte"
version = "0.11.1"
@ -5366,6 +5594,12 @@ dependencies = [
"warp",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"

View File

@ -8,6 +8,7 @@ members = [
"crates/tabby-inference",
"crates/llama-cpp-bindings",
"crates/http-api-bindings",
"crates/juniper-axum",
"ee/tabby-webserver",
]
@ -38,3 +39,4 @@ thiserror = "1.0.49"
utoipa = "3.3"
axum = "0.6"
hyper = "0.14"
juniper = "0.15"

View File

@ -0,0 +1,13 @@
[package]
name = "juniper-axum"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
axum.workspace = true
juniper.workspace = true
juniper_graphql_ws = "0.3.0"
serde.workspace = true
serde_json.workspace = true

View File

@ -0,0 +1,28 @@
BSD 2-Clause License
Adapted from https://github.com/graphql-rust/juniper/blob/master/juniper_axum
Copyright (c) 2023, The TabbyML team
Copyright (c) 2022-2023, Benno Tielen, Kai Ren
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,3 @@
# juniper_axum
Adopted from https://github.com/graphql-rust/juniper/tree/master/juniper_axum for juniper 15

View File

@ -0,0 +1,179 @@
//! Types and traits for extracting data from [`Request`]s.
use std::fmt;
use axum::{
async_trait,
body::Body,
extract::{FromRequest, FromRequestParts, Query},
http::{HeaderValue, Method, Request, StatusCode},
response::{IntoResponse as _, Response},
Json, RequestExt as _,
};
use juniper::{
http::{GraphQLBatchRequest, GraphQLRequest},
DefaultScalarValue, ScalarValue,
};
use serde::Deserialize;
/// Extractor for [`axum`] to extract a [`JuniperRequest`].
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
///
/// use axum::{routing::post, Extension, Json, Router};
/// use juniper::{
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
/// };
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Context;
///
/// impl juniper::Context for Context {}
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Query;
///
/// #[graphql_object(context = Context)]
/// impl Query {
/// fn add(a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
///
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
///
/// let schema = Schema::new(
/// Query,
/// EmptyMutation::<Context>::new(),
/// EmptySubscription::<Context>::new()
/// );
///
/// let app: Router = Router::new()
/// .route("/graphql", post(graphql))
/// .layer(Extension(Arc::new(schema)))
/// .layer(Extension(Context));
///
/// # #[axum::debug_handler]
/// async fn graphql(
/// Extension(schema): Extension<Arc<Schema>>,
/// Extension(context): Extension<Context>,
/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
/// ) -> JuniperResponse {
/// JuniperResponse(req.execute(&*schema, &context).await)
/// }
#[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where
S: ScalarValue;
#[async_trait]
impl<S, State> FromRequest<State, Body> for JuniperRequest<S>
where
S: ScalarValue,
State: Sync,
Query<GetRequest>: FromRequestParts<State>,
Json<GraphQLBatchRequest<S>>: FromRequest<State, Body>,
<Json<GraphQLBatchRequest<S>> as FromRequest<State, Body>>::Rejection: fmt::Display,
String: FromRequest<State, Body>,
{
type Rejection = Response;
async fn from_request(mut req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
let content_type = req
.headers()
.get("content-type")
.map(HeaderValue::to_str)
.transpose()
.map_err(|_| {
(
StatusCode::BAD_REQUEST,
"`Content-Type` header is not a valid HTTP header string",
)
.into_response()
})?;
match (req.method(), content_type) {
(&Method::GET, _) => req
.extract_parts::<Query<GetRequest>>()
.await
.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Invalid request query string: {e}"),
)
.into_response()
})
.and_then(|query| {
query
.0
.try_into()
.map(|q| Self(GraphQLBatchRequest::Single(q)))
.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Invalid request query `variables`: {e}"),
)
.into_response()
})
}),
(&Method::POST, Some("application/json")) => {
Json::<GraphQLBatchRequest<S>>::from_request(req, state)
.await
.map(|req| Self(req.0))
.map_err(|e| {
(StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response()
})
}
(&Method::POST, Some("application/graphql")) => String::from_request(req, state)
.await
.map(|body| {
Self(GraphQLBatchRequest::Single(GraphQLRequest::new(
body, None, None,
)))
})
.map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()),
(&Method::POST, _) => Err((
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"`Content-Type` header is expected to be either `application/json` or \
`application/graphql`",
)
.into_response()),
_ => Err((
StatusCode::METHOD_NOT_ALLOWED,
"HTTP method is expected to be either GET or POST",
)
.into_response()),
}
}
}
/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string,
/// containing `variables` in JSON format.
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
struct GetRequest {
query: String,
#[serde(rename = "operationName")]
operation_name: Option<String>,
variables: Option<String>,
}
impl<S: ScalarValue> TryFrom<GetRequest> for GraphQLRequest<S> {
type Error = serde_json::Error;
fn try_from(req: GetRequest) -> Result<Self, Self::Error> {
let GetRequest {
query,
operation_name,
variables,
} = req;
Ok(Self::new(
query,
operation_name,
variables.map(|v| serde_json::from_str(&v)).transpose()?,
))
}
}

View File

@ -0,0 +1,137 @@
pub mod extract;
pub mod response;
use std::{future, net::SocketAddr};
use axum::{
extract::{ConnectInfo, Extension, State},
response::{Html, IntoResponse},
};
use juniper_graphql_ws::Schema;
use self::{extract::JuniperRequest, response::JuniperResponse};
pub trait FromStateAndClientAddr<C, S> {
fn build(state: S, client_addr: SocketAddr) -> C;
}
/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing
/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one.
///
/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need
/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s
/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a
/// > [`JuniperRequest`] (see its documentation for examples).
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
///
/// use axum::{routing::post, Extension, Json, Router};
/// use juniper::{
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
/// };
/// use juniper_axum::graphql;
///
/// #[derive(Clone, Copy, Debug, Default)]
/// pub struct Context;
///
/// impl juniper::Context for Context {}
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Query;
///
/// #[graphql_object(context = Context)]
/// impl Query {
/// fn add(a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
///
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
///
/// let schema = Schema::new(
/// Query,
/// EmptyMutation::<Context>::new(),
/// EmptySubscription::<Context>::new()
/// );
///
/// let app: Router = Router::new()
/// .route("/graphql", post(graphql::<Arc<Schema>>))
/// .layer(Extension(Arc::new(schema)));
/// ```
///
/// [`extract`]: axum::extract
/// [`Handler`]: axum::handler::Handler
#[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S, C>(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(state): State<C>,
Extension(schema): Extension<S>,
JuniperRequest(req): JuniperRequest<S::ScalarValue>,
) -> impl IntoResponse
where
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
S::Context: FromStateAndClientAddr<S::Context, C>,
C: Clone,
{
let context = S::Context::build(state.clone(), addr);
JuniperResponse(req.execute(schema.root_node(), &context).await).into_response()
}
/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].
///
/// This does not handle routing, so you can mount it on any endpoint.
///
/// # Example
///
/// ```rust
/// use axum::{routing::get, Router};
/// use juniper_axum::graphiql;
///
/// let app: Router = Router::new()
/// .route("/", get(graphiql("/graphql", "/subscriptions")));
/// ```
///
/// [`Handler`]: axum::handler::Handler
/// [GraphiQL]: https://github.com/graphql/graphiql
pub fn graphiql<'a>(
graphql_endpoint_url: &str,
subscriptions_endpoint_url: impl Into<Option<&'a str>>,
) -> impl FnOnce() -> future::Ready<Html<String>> + Clone + Send {
let html = Html(juniper::http::graphiql::graphiql_source(
graphql_endpoint_url,
subscriptions_endpoint_url.into(),
));
|| future::ready(html)
}
/// Creates a [`Handler`] that replies with an HTML page containing [GraphQL Playground].
///
/// This does not handle routing, so you can mount it on any endpoint.
///
/// # Example
///
/// ```rust
/// use axum::{routing::get, Router};
/// use juniper_axum::playground;
///
/// let app: Router = Router::new()
/// .route("/", get(playground("/graphql", "/subscriptions")));
/// ```
///
/// [`Handler`]: axum::handler::Handler
/// [GraphQL Playground]: https://github.com/prisma/graphql-playground
pub fn playground<'a>(
graphql_endpoint_url: &str,
subscriptions_endpoint_url: impl Into<Option<&'a str>>,
) -> impl FnOnce() -> future::Ready<Html<String>> + Clone + Send {
let html = Html(juniper::http::playground::playground_source(
graphql_endpoint_url,
subscriptions_endpoint_url.into(),
));
|| future::ready(html)
}

View File

@ -0,0 +1,24 @@
//! [`JuniperResponse`] definition.
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use juniper::{http::GraphQLBatchResponse, DefaultScalarValue, ScalarValue};
/// Wrapper around a [`GraphQLBatchResponse`], implementing [`IntoResponse`], so it can be returned
/// from [`axum`] handlers.
pub struct JuniperResponse<'a, S = DefaultScalarValue>(pub GraphQLBatchResponse<'a, S>)
where
S: ScalarValue;
impl<S: ScalarValue> IntoResponse for JuniperResponse<'_, S> {
fn into_response(self) -> Response {
if self.0.is_ok() {
Json(self.0).into_response()
} else {
(StatusCode::BAD_REQUEST, Json(self.0)).into_response()
}
}
}

View File

@ -9,9 +9,12 @@ homepage.workspace = true
anyhow.workspace = true
axum.workspace = true
hyper = { workspace = true, features=["client"]}
juniper.workspace = true
juniper-axum = { path = "../../crates/juniper-axum" }
lazy_static = "1.4.0"
mime_guess = "2.0.4"
rust-embed = "8.0.0"
thiserror.workspace = true
tokio.workspace = true
tracing.workspace = true
unicase = "2.7.0"

View File

@ -1,70 +1,37 @@
mod proxy;
mod schema;
mod ui;
mod webserver;
mod worker;
use std::{net::SocketAddr, sync::Arc};
use std::sync::Arc;
use axum::{
extract::State,
http::Request,
middleware::{from_fn_with_state, Next},
response::IntoResponse,
Router,
routing, Extension, Router,
};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use tracing::warn;
#[derive(Default)]
pub struct Webserver {
client: Client<HttpConnector>,
completion: worker::WorkerGroup,
chat: worker::WorkerGroup,
}
impl Webserver {
async fn dispatch_request(
&self,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
let path = request.uri().path();
let remote_addr = request
.extensions()
.get::<axum::extract::ConnectInfo<SocketAddr>>()
.map(|ci| ci.0)
.expect("Unable to extract remote addr");
let worker = if path.starts_with("/v1/completions") {
self.completion.select().await
} else if path.starts_with("/v1beta/chat/completions") {
self.chat.select().await
} else {
None
};
if let Some(worker) = worker {
match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await {
Ok(res) => res.into_response(),
Err(err) => {
warn!("Failed to proxy request {}", err);
axum::response::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
.into_response()
}
}
} else {
next.run(request).await
}
}
}
use hyper::Body;
use juniper::EmptySubscription;
use juniper_axum::{graphiql, graphql, playground};
use schema::{Mutation, Query, Schema};
use webserver::Webserver;
pub fn attach_webserver(router: Router) -> Router {
let ws = Arc::new(Webserver::default());
let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new()));
let app = Router::new()
.route("/graphql", routing::get(playground("/graphql", None)))
.route("/graphiql", routing::get(graphiql("/graphql", None)))
.route(
"/graphql",
routing::post(graphql::<Arc<Schema>, Arc<Webserver>>).with_state(ws.clone()),
)
.layer(Extension(schema));
router
.merge(app)
.fallback(ui::handler)
.layer(from_fn_with_state(ws, distributed_tabby_layer))
}

View File

@ -0,0 +1,84 @@
use std::{net::SocketAddr, sync::Arc};
use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject,
IntoFieldError, RootNode, ScalarValue, Value,
};
use juniper_axum::FromStateAndClientAddr;
use crate::webserver::{Webserver, WebserverError};
pub struct Request {
ws: Arc<Webserver>,
client_addr: SocketAddr,
}
impl FromStateAndClientAddr<Request, Arc<Webserver>> for Request {
fn build(ws: Arc<Webserver>, client_addr: SocketAddr) -> Request {
Request { ws, client_addr }
}
}
// To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for Request {}
#[derive(GraphQLEnum, Clone, Debug)]
pub enum WorkerKind {
Completion,
Chat,
}
#[derive(GraphQLObject, Clone, Debug)]
pub struct Worker {
kind: WorkerKind,
addr: String,
}
impl Worker {
pub fn new(kind: WorkerKind, addr: String) -> Self {
Self { kind, addr }
}
}
#[derive(Default)]
pub struct Query;
#[graphql_object(context = Request)]
impl Query {
async fn workers(request: &Request) -> Vec<Worker> {
request.ws.list_workers().await
}
}
pub struct Mutation;
#[graphql_object(context = Request)]
impl Mutation {
async fn register_worker(
request: &Request,
token: String,
kind: WorkerKind,
port: i32,
) -> Result<Worker, WebserverError> {
let ws = &request.ws;
ws.register_worker(token, request.client_addr, kind, port)
.await
}
}
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Request>>;
impl<S: ScalarValue> IntoFieldError<S> for WebserverError {
fn into_field_error(self) -> FieldError<S> {
let msg = format!("{}", &self);
match self {
WebserverError::InvalidToken(token) => FieldError::new(
msg,
graphql_value!({
"token": token
}),
),
_ => FieldError::new(msg, Value::Null),
}
}
}

View File

@ -0,0 +1,111 @@
mod proxy;
use std::net::SocketAddr;
use axum::{http::Request, middleware::Next, response::IntoResponse};
use hyper::{client::HttpConnector, Body, Client, StatusCode};
use thiserror::Error;
use tracing::{info, warn};
use crate::{
schema::{Worker, WorkerKind},
worker,
};
#[derive(Error, Debug)]
pub enum WebserverError {
#[error("Invalid worker token")]
InvalidToken(String),
#[error("Feature requires enterprise license")]
RequiresEnterpriseLicense,
}
#[derive(Default)]
pub struct Webserver {
client: Client<HttpConnector>,
completion: worker::WorkerGroup,
chat: worker::WorkerGroup,
}
// FIXME: generate token and support refreshing in database.
static WORKER_TOKEN: &str = "4c749fad-2be7-45a3-849e-7714ccade382";
impl Webserver {
pub async fn register_worker(
&self,
token: String,
client_addr: SocketAddr,
kind: WorkerKind,
port: i32,
) -> Result<Worker, WebserverError> {
if token != WORKER_TOKEN {
return Err(WebserverError::InvalidToken(token));
}
let addr = SocketAddr::new(client_addr.ip(), port as u16);
let addr = match kind {
WorkerKind::Completion => self.completion.register(addr).await,
WorkerKind::Chat => self.chat.register(addr).await,
};
if let Some(addr) = addr {
info!("registering <{:?}> worker running at {}", kind, addr);
Ok(Worker::new(kind, addr))
} else {
Err(WebserverError::RequiresEnterpriseLicense)
}
}
pub async fn list_workers(&self) -> Vec<Worker> {
let make_workers = |x: WorkerKind, lst: Vec<String>| -> Vec<Worker> {
lst.into_iter()
.map(|addr| Worker::new(x.clone(), addr))
.collect()
};
[
make_workers(WorkerKind::Completion, self.completion.list().await),
make_workers(WorkerKind::Chat, self.chat.list().await),
]
.concat()
}
pub async fn dispatch_request(
&self,
request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
let path = request.uri().path();
let remote_addr = request
.extensions()
.get::<axum::extract::ConnectInfo<SocketAddr>>()
.map(|ci| ci.0)
.expect("Unable to extract remote addr");
let worker = if path.starts_with("/v1/completions") {
self.completion.select().await
} else if path.starts_with("/v1beta/chat/completions") {
self.chat.select().await
} else {
None
};
if let Some(worker) = worker {
match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await {
Ok(res) => res.into_response(),
Err(err) => {
warn!("Failed to proxy request {}", err);
axum::response::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
.into_response()
}
}
} else {
next.run(request).await
}
}
}