diff --git a/Cargo.lock b/Cargo.lock index fffd01f..5673e24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,7 +23,7 @@ version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd" dependencies = [ - "getrandom", + "getrandom 0.2.9", "once_cell", "version_check", ] @@ -186,6 +186,12 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "ascii" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e" + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -602,6 +608,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "bson" +version = "1.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de0aa578035b938855a710ba58d43cfb4d435f3619f99236fb35922a574d6cb1" +dependencies = [ + "base64 0.13.1", + "chrono", + "hex", + "lazy_static", + "linked-hash-map", + "rand 0.7.3", + "serde", + "serde_json", + "uuid 0.8.2", +] + [[package]] name = "bstr" version = "1.7.0" @@ -708,7 +731,7 @@ dependencies = [ "bitflags 1.3.2", "clap_derive 3.2.25", "clap_lex 0.2.4", - "indexmap", + "indexmap 1.9.3", "once_cell", "strsim 0.10.0", "termcolor", @@ -802,6 +825,19 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "combine" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3da6baa321ec19e1cc41d31bf599f00c783d0517095cdaf0332e3fe8d20680" +dependencies = [ + "ascii", + "byteorder", + "either", + "memchr", + "unreachable", +] + [[package]] name = "concurrent-queue" version = "2.3.0" @@ -1197,6 +1233,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_utils" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "532b4c15dccee12c7044f1fcad956e98410860b22231e44a3b827464797ca7bf" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -1314,6 +1361,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.1" @@ -1489,6 +1542,17 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +[[package]] +name = "futures-enum" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3422d14de7903a52e9dbc10ae05a7e14445ec61890100e098754e120b2bd7b1e" +dependencies = [ + "derive_utils", + "quote", + "syn 1.0.109", +] + [[package]] name = "futures-executor" version = "0.3.28" @@ -1594,6 +1658,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.2.9" @@ -1602,7 +1677,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", ] [[package]] @@ -1636,6 +1711,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "graphql-parser" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1abd4ce5247dfc04a03ccde70f87a048458c9356c7e41d21ad8c407b3dde6f2" +dependencies = [ + "combine", + "thiserror", +] + [[package]] name = "h2" version = "0.3.19" @@ -1648,7 +1733,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -1934,6 +2019,16 @@ dependencies = [ "serde", ] +[[package]] +name = "indexmap" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + [[package]] name = "indicatif" version = "0.17.7" @@ -2041,6 +2136,73 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "juniper" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52adf17d43d0b526eed31fac15d9312941c5c2558ffbfb105811690b96d6e2f1" +dependencies = [ + "async-trait", + "bson", + "chrono", + "fnv", + "futures", + "futures-enum", + "graphql-parser", + "indexmap 1.9.3", + "juniper_codegen", + "serde", + "smartstring", + "static_assertions", + "url", + "uuid 0.8.2", +] + +[[package]] +name = "juniper-axum" +version = "0.6.0-dev" +dependencies = [ + "axum", + "juniper", + "juniper_graphql_ws", + "serde", + "serde_json", +] + +[[package]] +name = "juniper_codegen" +version = "0.15.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aee97671061ad50301ba077d054d295e01d31a1868fbd07902db651f987e71db" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "juniper_graphql_ws" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5526c2f2a9c40f08841dc559971641fdd71c008a265745d18bb0c8b7e105b3" +dependencies = [ + "juniper", + "juniper_subscriptions", + "serde", + "tokio", +] + +[[package]] +name = "juniper_subscriptions" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2983b26a1e12b691c17432aee3881d8bec4a94d6c64bc933c0eaf6d9e3429f13" +dependencies = [ + "futures", + "juniper", +] + [[package]] name = "kdam" version = "0.5.0" @@ -2123,6 +2285,12 @@ dependencies = [ "cc", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.3.8" @@ -2374,7 +2542,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.45.0", ] @@ -2681,7 +2849,7 @@ dependencies = [ "fnv", "futures-channel", "futures-util", - "indexmap", + "indexmap 1.9.3", "js-sys", "once_cell", "pin-project-lite", @@ -2704,7 +2872,7 @@ dependencies = [ "once_cell", "opentelemetry_api", "percent-encoding", - "rand", + "rand 0.8.5", "thiserror", "tokio", "tokio-stream", @@ -2824,7 +2992,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 1.9.3", ] [[package]] @@ -3054,6 +3222,19 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + [[package]] name = "rand" version = "0.8.5" @@ -3061,8 +3242,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", ] [[package]] @@ -3072,7 +3263,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", ] [[package]] @@ -3081,7 +3281,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.9", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", ] [[package]] @@ -3142,7 +3351,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom", + "getrandom 0.2.9", "redox_syscall 0.2.16", "thiserror", ] @@ -3253,7 +3462,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e" dependencies = [ "cc", - "getrandom", + "getrandom 0.2.9", "libc", "spin 0.9.8", "untrusted 0.9.0", @@ -3654,6 +3863,7 @@ version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ + "indexmap 2.0.1", "itoa", "ryu", "serde", @@ -4187,9 +4397,12 @@ dependencies = [ "anyhow", "axum", "hyper", + "juniper", + "juniper-axum", "lazy_static", "mime_guess", "rust-embed 8.0.0", + "thiserror", "tokio", "tracing", "unicase", @@ -4557,7 +4770,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" dependencies = [ "pin-project", - "rand", + "rand 0.8.5", "tokio", ] @@ -4639,7 +4852,7 @@ version = "0.19.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2380d56e8670370eee6566b0bfd4265f65b3f432e8c6d85623f728d4fa31f739" dependencies = [ - "indexmap", + "indexmap 1.9.3", "serde", "serde_spanned", "toml_datetime", @@ -4699,10 +4912,10 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap", + "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand", + "rand 0.8.5", "slab", "tokio", "tokio-util", @@ -5036,7 +5249,7 @@ dependencies = [ "http", "httparse", "log", - "rand", + "rand 0.8.5", "sha1", "thiserror", "url", @@ -5091,6 +5304,15 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + [[package]] name = "untildify" version = "0.1.1" @@ -5162,7 +5384,7 @@ version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98" dependencies = [ - "indexmap", + "indexmap 1.9.3", "serde", "serde_json", "utoipa-gen", @@ -5202,7 +5424,7 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" dependencies = [ - "getrandom", + "getrandom 0.2.9", ] [[package]] @@ -5211,8 +5433,8 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" dependencies = [ - "getrandom", - "rand", + "getrandom 0.2.9", + "rand 0.8.5", "serde", "uuid-macro-internal", ] @@ -5269,6 +5491,12 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + [[package]] name = "vte" version = "0.11.1" @@ -5366,6 +5594,12 @@ dependencies = [ "warp", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index cb257aa..1a11930 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/tabby-inference", "crates/llama-cpp-bindings", "crates/http-api-bindings", + "crates/juniper-axum", "ee/tabby-webserver", ] @@ -38,3 +39,4 @@ thiserror = "1.0.49" utoipa = "3.3" axum = "0.6" hyper = "0.14" +juniper = "0.15" \ No newline at end of file diff --git a/crates/juniper-axum/Cargo.toml b/crates/juniper-axum/Cargo.toml new file mode 100644 index 0000000..272915a --- /dev/null +++ b/crates/juniper-axum/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "juniper-axum" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +axum.workspace = true +juniper.workspace = true +juniper_graphql_ws = "0.3.0" +serde.workspace = true +serde_json.workspace = true diff --git a/crates/juniper-axum/LICENSE b/crates/juniper-axum/LICENSE new file mode 100644 index 0000000..a3fb4b0 --- /dev/null +++ b/crates/juniper-axum/LICENSE @@ -0,0 +1,28 @@ +BSD 2-Clause License + +Adapted from https://github.com/graphql-rust/juniper/blob/master/juniper_axum + +Copyright (c) 2023, The TabbyML team +Copyright (c) 2022-2023, Benno Tielen, Kai Ren +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/crates/juniper-axum/README.md b/crates/juniper-axum/README.md new file mode 100644 index 0000000..5bfad7e --- /dev/null +++ b/crates/juniper-axum/README.md @@ -0,0 +1,3 @@ +# juniper_axum + +Adopted from https://github.com/graphql-rust/juniper/tree/master/juniper_axum for juniper 15 \ No newline at end of file diff --git a/crates/juniper-axum/src/extract.rs b/crates/juniper-axum/src/extract.rs new file mode 100644 index 0000000..d8fbab6 --- /dev/null +++ b/crates/juniper-axum/src/extract.rs @@ -0,0 +1,179 @@ +//! Types and traits for extracting data from [`Request`]s. + +use std::fmt; + +use axum::{ + async_trait, + body::Body, + extract::{FromRequest, FromRequestParts, Query}, + http::{HeaderValue, Method, Request, StatusCode}, + response::{IntoResponse as _, Response}, + Json, RequestExt as _, +}; +use juniper::{ + http::{GraphQLBatchRequest, GraphQLRequest}, + DefaultScalarValue, ScalarValue, +}; +use serde::Deserialize; + +/// Extractor for [`axum`] to extract a [`JuniperRequest`]. +/// +/// # Example +/// +/// ```rust +/// use std::sync::Arc; +/// +/// use axum::{routing::post, Extension, Json, Router}; +/// use juniper::{ +/// RootNode, EmptySubscription, EmptyMutation, graphql_object, +/// }; +/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse}; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Context; +/// +/// impl juniper::Context for Context {} +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object(context = Context)] +/// impl Query { +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; +/// +/// let schema = Schema::new( +/// Query, +/// EmptyMutation::::new(), +/// EmptySubscription::::new() +/// ); +/// +/// let app: Router = Router::new() +/// .route("/graphql", post(graphql)) +/// .layer(Extension(Arc::new(schema))) +/// .layer(Extension(Context)); +/// +/// # #[axum::debug_handler] +/// async fn graphql( +/// Extension(schema): Extension>, +/// Extension(context): Extension, +/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request` +/// ) -> JuniperResponse { +/// JuniperResponse(req.execute(&*schema, &context).await) +/// } +#[derive(Debug, PartialEq)] +pub struct JuniperRequest(pub GraphQLBatchRequest) +where + S: ScalarValue; + +#[async_trait] +impl FromRequest for JuniperRequest +where + S: ScalarValue, + State: Sync, + Query: FromRequestParts, + Json>: FromRequest, + > as FromRequest>::Rejection: fmt::Display, + String: FromRequest, +{ + type Rejection = Response; + + async fn from_request(mut req: Request, state: &State) -> Result { + let content_type = req + .headers() + .get("content-type") + .map(HeaderValue::to_str) + .transpose() + .map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "`Content-Type` header is not a valid HTTP header string", + ) + .into_response() + })?; + + match (req.method(), content_type) { + (&Method::GET, _) => req + .extract_parts::>() + .await + .map_err(|e| { + ( + StatusCode::BAD_REQUEST, + format!("Invalid request query string: {e}"), + ) + .into_response() + }) + .and_then(|query| { + query + .0 + .try_into() + .map(|q| Self(GraphQLBatchRequest::Single(q))) + .map_err(|e| { + ( + StatusCode::BAD_REQUEST, + format!("Invalid request query `variables`: {e}"), + ) + .into_response() + }) + }), + (&Method::POST, Some("application/json")) => { + Json::>::from_request(req, state) + .await + .map(|req| Self(req.0)) + .map_err(|e| { + (StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response() + }) + } + (&Method::POST, Some("application/graphql")) => String::from_request(req, state) + .await + .map(|body| { + Self(GraphQLBatchRequest::Single(GraphQLRequest::new( + body, None, None, + ))) + }) + .map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()), + (&Method::POST, _) => Err(( + StatusCode::UNSUPPORTED_MEDIA_TYPE, + "`Content-Type` header is expected to be either `application/json` or \ + `application/graphql`", + ) + .into_response()), + _ => Err(( + StatusCode::METHOD_NOT_ALLOWED, + "HTTP method is expected to be either GET or POST", + ) + .into_response()), + } + } +} + +/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string, +/// containing `variables` in JSON format. +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +struct GetRequest { + query: String, + #[serde(rename = "operationName")] + operation_name: Option, + variables: Option, +} + +impl TryFrom for GraphQLRequest { + type Error = serde_json::Error; + fn try_from(req: GetRequest) -> Result { + let GetRequest { + query, + operation_name, + variables, + } = req; + Ok(Self::new( + query, + operation_name, + variables.map(|v| serde_json::from_str(&v)).transpose()?, + )) + } +} diff --git a/crates/juniper-axum/src/lib.rs b/crates/juniper-axum/src/lib.rs new file mode 100644 index 0000000..06d81e5 --- /dev/null +++ b/crates/juniper-axum/src/lib.rs @@ -0,0 +1,137 @@ +pub mod extract; +pub mod response; + +use std::{future, net::SocketAddr}; + +use axum::{ + extract::{ConnectInfo, Extension, State}, + response::{Html, IntoResponse}, +}; +use juniper_graphql_ws::Schema; + +use self::{extract::JuniperRequest, response::JuniperResponse}; + +pub trait FromStateAndClientAddr { + fn build(state: S, client_addr: SocketAddr) -> C; +} + +/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing +/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one. +/// +/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need +/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s +/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a +/// > [`JuniperRequest`] (see its documentation for examples). +/// +/// # Example +/// +/// ```rust +/// use std::sync::Arc; +/// +/// use axum::{routing::post, Extension, Json, Router}; +/// use juniper::{ +/// RootNode, EmptySubscription, EmptyMutation, graphql_object, +/// }; +/// use juniper_axum::graphql; +/// +/// #[derive(Clone, Copy, Debug, Default)] +/// pub struct Context; +/// +/// impl juniper::Context for Context {} +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object(context = Context)] +/// impl Query { +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; +/// +/// let schema = Schema::new( +/// Query, +/// EmptyMutation::::new(), +/// EmptySubscription::::new() +/// ); +/// +/// let app: Router = Router::new() +/// .route("/graphql", post(graphql::>)) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [`extract`]: axum::extract +/// [`Handler`]: axum::handler::Handler +#[cfg_attr(text, axum::debug_handler)] +pub async fn graphql( + ConnectInfo(addr): ConnectInfo, + State(state): State, + Extension(schema): Extension, + JuniperRequest(req): JuniperRequest, +) -> impl IntoResponse +where + S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here. + S::Context: FromStateAndClientAddr, + C: Clone, +{ + let context = S::Context::build(state.clone(), addr); + JuniperResponse(req.execute(schema.root_node(), &context).await).into_response() +} + +/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL]. +/// +/// This does not handle routing, so you can mount it on any endpoint. +/// +/// # Example +/// +/// ```rust +/// use axum::{routing::get, Router}; +/// use juniper_axum::graphiql; +/// +/// let app: Router = Router::new() +/// .route("/", get(graphiql("/graphql", "/subscriptions"))); +/// ``` +/// +/// [`Handler`]: axum::handler::Handler +/// [GraphiQL]: https://github.com/graphql/graphiql +pub fn graphiql<'a>( + graphql_endpoint_url: &str, + subscriptions_endpoint_url: impl Into>, +) -> impl FnOnce() -> future::Ready> + Clone + Send { + let html = Html(juniper::http::graphiql::graphiql_source( + graphql_endpoint_url, + subscriptions_endpoint_url.into(), + )); + + || future::ready(html) +} + +/// Creates a [`Handler`] that replies with an HTML page containing [GraphQL Playground]. +/// +/// This does not handle routing, so you can mount it on any endpoint. +/// +/// # Example +/// +/// ```rust +/// use axum::{routing::get, Router}; +/// use juniper_axum::playground; +/// +/// let app: Router = Router::new() +/// .route("/", get(playground("/graphql", "/subscriptions"))); +/// ``` +/// +/// [`Handler`]: axum::handler::Handler +/// [GraphQL Playground]: https://github.com/prisma/graphql-playground +pub fn playground<'a>( + graphql_endpoint_url: &str, + subscriptions_endpoint_url: impl Into>, +) -> impl FnOnce() -> future::Ready> + Clone + Send { + let html = Html(juniper::http::playground::playground_source( + graphql_endpoint_url, + subscriptions_endpoint_url.into(), + )); + + || future::ready(html) +} diff --git a/crates/juniper-axum/src/response.rs b/crates/juniper-axum/src/response.rs new file mode 100644 index 0000000..047e1c0 --- /dev/null +++ b/crates/juniper-axum/src/response.rs @@ -0,0 +1,24 @@ +//! [`JuniperResponse`] definition. + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use juniper::{http::GraphQLBatchResponse, DefaultScalarValue, ScalarValue}; + +/// Wrapper around a [`GraphQLBatchResponse`], implementing [`IntoResponse`], so it can be returned +/// from [`axum`] handlers. +pub struct JuniperResponse<'a, S = DefaultScalarValue>(pub GraphQLBatchResponse<'a, S>) +where + S: ScalarValue; + +impl IntoResponse for JuniperResponse<'_, S> { + fn into_response(self) -> Response { + if self.0.is_ok() { + Json(self.0).into_response() + } else { + (StatusCode::BAD_REQUEST, Json(self.0)).into_response() + } + } +} diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 18cd123..8fa01c0 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -9,9 +9,12 @@ homepage.workspace = true anyhow.workspace = true axum.workspace = true hyper = { workspace = true, features=["client"]} +juniper.workspace = true +juniper-axum = { path = "../../crates/juniper-axum" } lazy_static = "1.4.0" mime_guess = "2.0.4" rust-embed = "8.0.0" +thiserror.workspace = true tokio.workspace = true tracing.workspace = true unicase = "2.7.0" diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index d4b0cd3..1a77c89 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -1,70 +1,37 @@ -mod proxy; +mod schema; mod ui; +mod webserver; mod worker; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use axum::{ extract::State, http::Request, middleware::{from_fn_with_state, Next}, - response::IntoResponse, - Router, + routing, Extension, Router, }; -use hyper::{client::HttpConnector, Body, Client, StatusCode}; -use tracing::warn; - -#[derive(Default)] -pub struct Webserver { - client: Client, - completion: worker::WorkerGroup, - chat: worker::WorkerGroup, -} - -impl Webserver { - async fn dispatch_request( - &self, - request: Request, - next: Next, - ) -> axum::response::Response { - let path = request.uri().path(); - - let remote_addr = request - .extensions() - .get::>() - .map(|ci| ci.0) - .expect("Unable to extract remote addr"); - - let worker = if path.starts_with("/v1/completions") { - self.completion.select().await - } else if path.starts_with("/v1beta/chat/completions") { - self.chat.select().await - } else { - None - }; - - if let Some(worker) = worker { - match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await { - Ok(res) => res.into_response(), - Err(err) => { - warn!("Failed to proxy request {}", err); - axum::response::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap() - .into_response() - } - } - } else { - next.run(request).await - } - } -} +use hyper::Body; +use juniper::EmptySubscription; +use juniper_axum::{graphiql, graphql, playground}; +use schema::{Mutation, Query, Schema}; +use webserver::Webserver; pub fn attach_webserver(router: Router) -> Router { let ws = Arc::new(Webserver::default()); + let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new())); + + let app = Router::new() + .route("/graphql", routing::get(playground("/graphql", None))) + .route("/graphiql", routing::get(graphiql("/graphql", None))) + .route( + "/graphql", + routing::post(graphql::, Arc>).with_state(ws.clone()), + ) + .layer(Extension(schema)); router + .merge(app) .fallback(ui::handler) .layer(from_fn_with_state(ws, distributed_tabby_layer)) } diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs new file mode 100644 index 0000000..d4f9956 --- /dev/null +++ b/ee/tabby-webserver/src/schema.rs @@ -0,0 +1,84 @@ +use std::{net::SocketAddr, sync::Arc}; + +use juniper::{ + graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject, + IntoFieldError, RootNode, ScalarValue, Value, +}; +use juniper_axum::FromStateAndClientAddr; + +use crate::webserver::{Webserver, WebserverError}; + +pub struct Request { + ws: Arc, + client_addr: SocketAddr, +} + +impl FromStateAndClientAddr> for Request { + fn build(ws: Arc, client_addr: SocketAddr) -> Request { + Request { ws, client_addr } + } +} + +// To make our context usable by Juniper, we have to implement a marker trait. +impl juniper::Context for Request {} + +#[derive(GraphQLEnum, Clone, Debug)] +pub enum WorkerKind { + Completion, + Chat, +} + +#[derive(GraphQLObject, Clone, Debug)] +pub struct Worker { + kind: WorkerKind, + addr: String, +} + +impl Worker { + pub fn new(kind: WorkerKind, addr: String) -> Self { + Self { kind, addr } + } +} + +#[derive(Default)] +pub struct Query; + +#[graphql_object(context = Request)] +impl Query { + async fn workers(request: &Request) -> Vec { + request.ws.list_workers().await + } +} + +pub struct Mutation; + +#[graphql_object(context = Request)] +impl Mutation { + async fn register_worker( + request: &Request, + token: String, + kind: WorkerKind, + port: i32, + ) -> Result { + let ws = &request.ws; + ws.register_worker(token, request.client_addr, kind, port) + .await + } +} + +pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; + +impl IntoFieldError for WebserverError { + fn into_field_error(self) -> FieldError { + let msg = format!("{}", &self); + match self { + WebserverError::InvalidToken(token) => FieldError::new( + msg, + graphql_value!({ + "token": token + }), + ), + _ => FieldError::new(msg, Value::Null), + } + } +} diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs new file mode 100644 index 0000000..2cb0a56 --- /dev/null +++ b/ee/tabby-webserver/src/webserver.rs @@ -0,0 +1,111 @@ +mod proxy; + +use std::net::SocketAddr; + +use axum::{http::Request, middleware::Next, response::IntoResponse}; +use hyper::{client::HttpConnector, Body, Client, StatusCode}; +use thiserror::Error; +use tracing::{info, warn}; + +use crate::{ + schema::{Worker, WorkerKind}, + worker, +}; + +#[derive(Error, Debug)] +pub enum WebserverError { + #[error("Invalid worker token")] + InvalidToken(String), + + #[error("Feature requires enterprise license")] + RequiresEnterpriseLicense, +} + +#[derive(Default)] +pub struct Webserver { + client: Client, + completion: worker::WorkerGroup, + chat: worker::WorkerGroup, +} + +// FIXME: generate token and support refreshing in database. +static WORKER_TOKEN: &str = "4c749fad-2be7-45a3-849e-7714ccade382"; + +impl Webserver { + pub async fn register_worker( + &self, + token: String, + client_addr: SocketAddr, + kind: WorkerKind, + port: i32, + ) -> Result { + if token != WORKER_TOKEN { + return Err(WebserverError::InvalidToken(token)); + } + + let addr = SocketAddr::new(client_addr.ip(), port as u16); + let addr = match kind { + WorkerKind::Completion => self.completion.register(addr).await, + WorkerKind::Chat => self.chat.register(addr).await, + }; + + if let Some(addr) = addr { + info!("registering <{:?}> worker running at {}", kind, addr); + Ok(Worker::new(kind, addr)) + } else { + Err(WebserverError::RequiresEnterpriseLicense) + } + } + + pub async fn list_workers(&self) -> Vec { + let make_workers = |x: WorkerKind, lst: Vec| -> Vec { + lst.into_iter() + .map(|addr| Worker::new(x.clone(), addr)) + .collect() + }; + + [ + make_workers(WorkerKind::Completion, self.completion.list().await), + make_workers(WorkerKind::Chat, self.chat.list().await), + ] + .concat() + } + + pub async fn dispatch_request( + &self, + request: Request, + next: Next, + ) -> axum::response::Response { + let path = request.uri().path(); + + let remote_addr = request + .extensions() + .get::>() + .map(|ci| ci.0) + .expect("Unable to extract remote addr"); + + let worker = if path.starts_with("/v1/completions") { + self.completion.select().await + } else if path.starts_with("/v1beta/chat/completions") { + self.chat.select().await + } else { + None + }; + + if let Some(worker) = worker { + match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await { + Ok(res) => res.into_response(), + Err(err) => { + warn!("Failed to proxy request {}", err); + axum::response::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap() + .into_response() + } + } + } else { + next.run(request).await + } + } +} diff --git a/ee/tabby-webserver/src/proxy.rs b/ee/tabby-webserver/src/webserver/proxy.rs similarity index 100% rename from ee/tabby-webserver/src/proxy.rs rename to ee/tabby-webserver/src/webserver/proxy.rs