feat: add graphql interface to tabby-webserver (#770)
feat: add graphql interface to tabby-webserverextract-routes
parent
4d6dc626c0
commit
3a9b4d9ef5
|
|
@ -23,7 +23,7 @@ version = "0.7.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"getrandom 0.2.9",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
]
|
||||
|
|
@ -186,6 +186,12 @@ version = "0.7.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
|
||||
|
||||
[[package]]
|
||||
name = "ascii"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eab1c04a571841102f5345a8fc0f6bb3d31c315dec879b5c6e42e40ce7ffa34e"
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
|
|
@ -602,6 +608,23 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bson"
|
||||
version = "1.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de0aa578035b938855a710ba58d43cfb4d435f3619f99236fb35922a574d6cb1"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"hex",
|
||||
"lazy_static",
|
||||
"linked-hash-map",
|
||||
"rand 0.7.3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"uuid 0.8.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bstr"
|
||||
version = "1.7.0"
|
||||
|
|
@ -708,7 +731,7 @@ dependencies = [
|
|||
"bitflags 1.3.2",
|
||||
"clap_derive 3.2.25",
|
||||
"clap_lex 0.2.4",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"once_cell",
|
||||
"strsim 0.10.0",
|
||||
"termcolor",
|
||||
|
|
@ -802,6 +825,19 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "3.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da3da6baa321ec19e1cc41d31bf599f00c783d0517095cdaf0332e3fe8d20680"
|
||||
dependencies = [
|
||||
"ascii",
|
||||
"byteorder",
|
||||
"either",
|
||||
"memchr",
|
||||
"unreachable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "concurrent-queue"
|
||||
version = "2.3.0"
|
||||
|
|
@ -1197,6 +1233,17 @@ dependencies = [
|
|||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_utils"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "532b4c15dccee12c7044f1fcad956e98410860b22231e44a3b827464797ca7bf"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
|
|
@ -1314,6 +1361,12 @@ dependencies = [
|
|||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.1"
|
||||
|
|
@ -1489,6 +1542,17 @@ version = "0.3.28"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
|
||||
|
||||
[[package]]
|
||||
name = "futures-enum"
|
||||
version = "0.1.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3422d14de7903a52e9dbc10ae05a7e14445ec61890100e098754e120b2bd7b1e"
|
||||
dependencies = [
|
||||
"derive_utils",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.28"
|
||||
|
|
@ -1594,6 +1658,17 @@ dependencies = [
|
|||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi 0.9.0+wasi-snapshot-preview1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.9"
|
||||
|
|
@ -1602,7 +1677,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
|
|||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1636,6 +1711,16 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphql-parser"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1abd4ce5247dfc04a03ccde70f87a048458c9356c7e41d21ad8c407b3dde6f2"
|
||||
dependencies = [
|
||||
"combine",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.19"
|
||||
|
|
@ -1648,7 +1733,7 @@ dependencies = [
|
|||
"futures-sink",
|
||||
"futures-util",
|
||||
"http",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
|
|
@ -1934,6 +2019,16 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.17.7"
|
||||
|
|
@ -2041,6 +2136,73 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "juniper"
|
||||
version = "0.15.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52adf17d43d0b526eed31fac15d9312941c5c2558ffbfb105811690b96d6e2f1"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bson",
|
||||
"chrono",
|
||||
"fnv",
|
||||
"futures",
|
||||
"futures-enum",
|
||||
"graphql-parser",
|
||||
"indexmap 1.9.3",
|
||||
"juniper_codegen",
|
||||
"serde",
|
||||
"smartstring",
|
||||
"static_assertions",
|
||||
"url",
|
||||
"uuid 0.8.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "juniper-axum"
|
||||
version = "0.6.0-dev"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"juniper",
|
||||
"juniper_graphql_ws",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "juniper_codegen"
|
||||
version = "0.15.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aee97671061ad50301ba077d054d295e01d31a1868fbd07902db651f987e71db"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "juniper_graphql_ws"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed5526c2f2a9c40f08841dc559971641fdd71c008a265745d18bb0c8b7e105b3"
|
||||
dependencies = [
|
||||
"juniper",
|
||||
"juniper_subscriptions",
|
||||
"serde",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "juniper_subscriptions"
|
||||
version = "0.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2983b26a1e12b691c17432aee3881d8bec4a94d6c64bc933c0eaf6d9e3429f13"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"juniper",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kdam"
|
||||
version = "0.5.0"
|
||||
|
|
@ -2123,6 +2285,12 @@ dependencies = [
|
|||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.3.8"
|
||||
|
|
@ -2374,7 +2542,7 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
|
|||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"wasi",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
|
|
@ -2681,7 +2849,7 @@ dependencies = [
|
|||
"fnv",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"js-sys",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
|
|
@ -2704,7 +2872,7 @@ dependencies = [
|
|||
"once_cell",
|
||||
"opentelemetry_api",
|
||||
"percent-encoding",
|
||||
"rand",
|
||||
"rand 0.8.5",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
|
|
@ -2824,7 +2992,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4"
|
||||
dependencies = [
|
||||
"fixedbitset",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3054,6 +3222,19 @@ dependencies = [
|
|||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
|
||||
dependencies = [
|
||||
"getrandom 0.1.16",
|
||||
"libc",
|
||||
"rand_chacha 0.2.2",
|
||||
"rand_core 0.5.1",
|
||||
"rand_hc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
|
|
@ -3061,8 +3242,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
"rand_chacha 0.3.1",
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3072,7 +3263,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
dependencies = [
|
||||
"getrandom 0.1.16",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3081,7 +3281,16 @@ version = "0.6.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"getrandom 0.2.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
dependencies = [
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3142,7 +3351,7 @@ version = "0.4.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"getrandom 0.2.9",
|
||||
"redox_syscall 0.2.16",
|
||||
"thiserror",
|
||||
]
|
||||
|
|
@ -3253,7 +3462,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "9babe80d5c16becf6594aa32ad2be8fe08498e7ae60b77de8df700e67f191d7e"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"getrandom",
|
||||
"getrandom 0.2.9",
|
||||
"libc",
|
||||
"spin 0.9.8",
|
||||
"untrusted 0.9.0",
|
||||
|
|
@ -3654,6 +3863,7 @@ version = "1.0.107"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
|
||||
dependencies = [
|
||||
"indexmap 2.0.1",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
|
|
@ -4187,9 +4397,12 @@ dependencies = [
|
|||
"anyhow",
|
||||
"axum",
|
||||
"hyper",
|
||||
"juniper",
|
||||
"juniper-axum",
|
||||
"lazy_static",
|
||||
"mime_guess",
|
||||
"rust-embed 8.0.0",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"unicase",
|
||||
|
|
@ -4557,7 +4770,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
|
||||
dependencies = [
|
||||
"pin-project",
|
||||
"rand",
|
||||
"rand 0.8.5",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
|
|
@ -4639,7 +4852,7 @@ version = "0.19.10"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2380d56e8670370eee6566b0bfd4265f65b3f432e8c6d85623f728d4fa31f739"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
|
|
@ -4699,10 +4912,10 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
|
|||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"pin-project",
|
||||
"pin-project-lite",
|
||||
"rand",
|
||||
"rand 0.8.5",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
|
|
@ -5036,7 +5249,7 @@ dependencies = [
|
|||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
|
|
@ -5091,6 +5304,15 @@ version = "0.1.10"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
|
||||
|
||||
[[package]]
|
||||
name = "unreachable"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56"
|
||||
dependencies = [
|
||||
"void",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "untildify"
|
||||
version = "0.1.1"
|
||||
|
|
@ -5162,7 +5384,7 @@ version = "3.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"indexmap 1.9.3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"utoipa-gen",
|
||||
|
|
@ -5202,7 +5424,7 @@ version = "0.8.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"getrandom 0.2.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -5211,8 +5433,8 @@ version = "1.4.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"rand",
|
||||
"getrandom 0.2.9",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"uuid-macro-internal",
|
||||
]
|
||||
|
|
@ -5269,6 +5491,12 @@ version = "0.9.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "void"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
|
||||
|
||||
[[package]]
|
||||
name = "vte"
|
||||
version = "0.11.1"
|
||||
|
|
@ -5366,6 +5594,12 @@ dependencies = [
|
|||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.9.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ members = [
|
|||
"crates/tabby-inference",
|
||||
"crates/llama-cpp-bindings",
|
||||
"crates/http-api-bindings",
|
||||
"crates/juniper-axum",
|
||||
"ee/tabby-webserver",
|
||||
]
|
||||
|
||||
|
|
@ -38,3 +39,4 @@ thiserror = "1.0.49"
|
|||
utoipa = "3.3"
|
||||
axum = "0.6"
|
||||
hyper = "0.14"
|
||||
juniper = "0.15"
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "juniper-axum"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
axum.workspace = true
|
||||
juniper.workspace = true
|
||||
juniper_graphql_ws = "0.3.0"
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
BSD 2-Clause License
|
||||
|
||||
Adapted from https://github.com/graphql-rust/juniper/blob/master/juniper_axum
|
||||
|
||||
Copyright (c) 2023, The TabbyML team
|
||||
Copyright (c) 2022-2023, Benno Tielen, Kai Ren
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# juniper_axum
|
||||
|
||||
Adopted from https://github.com/graphql-rust/juniper/tree/master/juniper_axum for juniper 15
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
//! Types and traits for extracting data from [`Request`]s.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
body::Body,
|
||||
extract::{FromRequest, FromRequestParts, Query},
|
||||
http::{HeaderValue, Method, Request, StatusCode},
|
||||
response::{IntoResponse as _, Response},
|
||||
Json, RequestExt as _,
|
||||
};
|
||||
use juniper::{
|
||||
http::{GraphQLBatchRequest, GraphQLRequest},
|
||||
DefaultScalarValue, ScalarValue,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Extractor for [`axum`] to extract a [`JuniperRequest`].
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// use axum::{routing::post, Extension, Json, Router};
|
||||
/// use juniper::{
|
||||
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
|
||||
/// };
|
||||
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
|
||||
///
|
||||
/// #[derive(Clone, Copy, Debug)]
|
||||
/// pub struct Context;
|
||||
///
|
||||
/// impl juniper::Context for Context {}
|
||||
///
|
||||
/// #[derive(Clone, Copy, Debug)]
|
||||
/// pub struct Query;
|
||||
///
|
||||
/// #[graphql_object(context = Context)]
|
||||
/// impl Query {
|
||||
/// fn add(a: i32, b: i32) -> i32 {
|
||||
/// a + b
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
|
||||
///
|
||||
/// let schema = Schema::new(
|
||||
/// Query,
|
||||
/// EmptyMutation::<Context>::new(),
|
||||
/// EmptySubscription::<Context>::new()
|
||||
/// );
|
||||
///
|
||||
/// let app: Router = Router::new()
|
||||
/// .route("/graphql", post(graphql))
|
||||
/// .layer(Extension(Arc::new(schema)))
|
||||
/// .layer(Extension(Context));
|
||||
///
|
||||
/// # #[axum::debug_handler]
|
||||
/// async fn graphql(
|
||||
/// Extension(schema): Extension<Arc<Schema>>,
|
||||
/// Extension(context): Extension<Context>,
|
||||
/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
|
||||
/// ) -> JuniperResponse {
|
||||
/// JuniperResponse(req.execute(&*schema, &context).await)
|
||||
/// }
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
|
||||
where
|
||||
S: ScalarValue;
|
||||
|
||||
#[async_trait]
|
||||
impl<S, State> FromRequest<State, Body> for JuniperRequest<S>
|
||||
where
|
||||
S: ScalarValue,
|
||||
State: Sync,
|
||||
Query<GetRequest>: FromRequestParts<State>,
|
||||
Json<GraphQLBatchRequest<S>>: FromRequest<State, Body>,
|
||||
<Json<GraphQLBatchRequest<S>> as FromRequest<State, Body>>::Rejection: fmt::Display,
|
||||
String: FromRequest<State, Body>,
|
||||
{
|
||||
type Rejection = Response;
|
||||
|
||||
async fn from_request(mut req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
|
||||
let content_type = req
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.map(HeaderValue::to_str)
|
||||
.transpose()
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Content-Type` header is not a valid HTTP header string",
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
|
||||
match (req.method(), content_type) {
|
||||
(&Method::GET, _) => req
|
||||
.extract_parts::<Query<GetRequest>>()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid request query string: {e}"),
|
||||
)
|
||||
.into_response()
|
||||
})
|
||||
.and_then(|query| {
|
||||
query
|
||||
.0
|
||||
.try_into()
|
||||
.map(|q| Self(GraphQLBatchRequest::Single(q)))
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid request query `variables`: {e}"),
|
||||
)
|
||||
.into_response()
|
||||
})
|
||||
}),
|
||||
(&Method::POST, Some("application/json")) => {
|
||||
Json::<GraphQLBatchRequest<S>>::from_request(req, state)
|
||||
.await
|
||||
.map(|req| Self(req.0))
|
||||
.map_err(|e| {
|
||||
(StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response()
|
||||
})
|
||||
}
|
||||
(&Method::POST, Some("application/graphql")) => String::from_request(req, state)
|
||||
.await
|
||||
.map(|body| {
|
||||
Self(GraphQLBatchRequest::Single(GraphQLRequest::new(
|
||||
body, None, None,
|
||||
)))
|
||||
})
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()),
|
||||
(&Method::POST, _) => Err((
|
||||
StatusCode::UNSUPPORTED_MEDIA_TYPE,
|
||||
"`Content-Type` header is expected to be either `application/json` or \
|
||||
`application/graphql`",
|
||||
)
|
||||
.into_response()),
|
||||
_ => Err((
|
||||
StatusCode::METHOD_NOT_ALLOWED,
|
||||
"HTTP method is expected to be either GET or POST",
|
||||
)
|
||||
.into_response()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string,
|
||||
/// containing `variables` in JSON format.
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
struct GetRequest {
|
||||
query: String,
|
||||
#[serde(rename = "operationName")]
|
||||
operation_name: Option<String>,
|
||||
variables: Option<String>,
|
||||
}
|
||||
|
||||
impl<S: ScalarValue> TryFrom<GetRequest> for GraphQLRequest<S> {
|
||||
type Error = serde_json::Error;
|
||||
fn try_from(req: GetRequest) -> Result<Self, Self::Error> {
|
||||
let GetRequest {
|
||||
query,
|
||||
operation_name,
|
||||
variables,
|
||||
} = req;
|
||||
Ok(Self::new(
|
||||
query,
|
||||
operation_name,
|
||||
variables.map(|v| serde_json::from_str(&v)).transpose()?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
pub mod extract;
|
||||
pub mod response;
|
||||
|
||||
use std::{future, net::SocketAddr};
|
||||
|
||||
use axum::{
|
||||
extract::{ConnectInfo, Extension, State},
|
||||
response::{Html, IntoResponse},
|
||||
};
|
||||
use juniper_graphql_ws::Schema;
|
||||
|
||||
use self::{extract::JuniperRequest, response::JuniperResponse};
|
||||
|
||||
pub trait FromStateAndClientAddr<C, S> {
|
||||
fn build(state: S, client_addr: SocketAddr) -> C;
|
||||
}
|
||||
|
||||
/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing
|
||||
/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one.
|
||||
///
|
||||
/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need
|
||||
/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s
|
||||
/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a
|
||||
/// > [`JuniperRequest`] (see its documentation for examples).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// use axum::{routing::post, Extension, Json, Router};
|
||||
/// use juniper::{
|
||||
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
|
||||
/// };
|
||||
/// use juniper_axum::graphql;
|
||||
///
|
||||
/// #[derive(Clone, Copy, Debug, Default)]
|
||||
/// pub struct Context;
|
||||
///
|
||||
/// impl juniper::Context for Context {}
|
||||
///
|
||||
/// #[derive(Clone, Copy, Debug)]
|
||||
/// pub struct Query;
|
||||
///
|
||||
/// #[graphql_object(context = Context)]
|
||||
/// impl Query {
|
||||
/// fn add(a: i32, b: i32) -> i32 {
|
||||
/// a + b
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
|
||||
///
|
||||
/// let schema = Schema::new(
|
||||
/// Query,
|
||||
/// EmptyMutation::<Context>::new(),
|
||||
/// EmptySubscription::<Context>::new()
|
||||
/// );
|
||||
///
|
||||
/// let app: Router = Router::new()
|
||||
/// .route("/graphql", post(graphql::<Arc<Schema>>))
|
||||
/// .layer(Extension(Arc::new(schema)));
|
||||
/// ```
|
||||
///
|
||||
/// [`extract`]: axum::extract
|
||||
/// [`Handler`]: axum::handler::Handler
|
||||
#[cfg_attr(text, axum::debug_handler)]
|
||||
pub async fn graphql<S, C>(
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
State(state): State<C>,
|
||||
Extension(schema): Extension<S>,
|
||||
JuniperRequest(req): JuniperRequest<S::ScalarValue>,
|
||||
) -> impl IntoResponse
|
||||
where
|
||||
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
|
||||
S::Context: FromStateAndClientAddr<S::Context, C>,
|
||||
C: Clone,
|
||||
{
|
||||
let context = S::Context::build(state.clone(), addr);
|
||||
JuniperResponse(req.execute(schema.root_node(), &context).await).into_response()
|
||||
}
|
||||
|
||||
/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].
|
||||
///
|
||||
/// This does not handle routing, so you can mount it on any endpoint.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::{routing::get, Router};
|
||||
/// use juniper_axum::graphiql;
|
||||
///
|
||||
/// let app: Router = Router::new()
|
||||
/// .route("/", get(graphiql("/graphql", "/subscriptions")));
|
||||
/// ```
|
||||
///
|
||||
/// [`Handler`]: axum::handler::Handler
|
||||
/// [GraphiQL]: https://github.com/graphql/graphiql
|
||||
pub fn graphiql<'a>(
|
||||
graphql_endpoint_url: &str,
|
||||
subscriptions_endpoint_url: impl Into<Option<&'a str>>,
|
||||
) -> impl FnOnce() -> future::Ready<Html<String>> + Clone + Send {
|
||||
let html = Html(juniper::http::graphiql::graphiql_source(
|
||||
graphql_endpoint_url,
|
||||
subscriptions_endpoint_url.into(),
|
||||
));
|
||||
|
||||
|| future::ready(html)
|
||||
}
|
||||
|
||||
/// Creates a [`Handler`] that replies with an HTML page containing [GraphQL Playground].
|
||||
///
|
||||
/// This does not handle routing, so you can mount it on any endpoint.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::{routing::get, Router};
|
||||
/// use juniper_axum::playground;
|
||||
///
|
||||
/// let app: Router = Router::new()
|
||||
/// .route("/", get(playground("/graphql", "/subscriptions")));
|
||||
/// ```
|
||||
///
|
||||
/// [`Handler`]: axum::handler::Handler
|
||||
/// [GraphQL Playground]: https://github.com/prisma/graphql-playground
|
||||
pub fn playground<'a>(
|
||||
graphql_endpoint_url: &str,
|
||||
subscriptions_endpoint_url: impl Into<Option<&'a str>>,
|
||||
) -> impl FnOnce() -> future::Ready<Html<String>> + Clone + Send {
|
||||
let html = Html(juniper::http::playground::playground_source(
|
||||
graphql_endpoint_url,
|
||||
subscriptions_endpoint_url.into(),
|
||||
));
|
||||
|
||||
|| future::ready(html)
|
||||
}
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
//! [`JuniperResponse`] definition.
|
||||
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use juniper::{http::GraphQLBatchResponse, DefaultScalarValue, ScalarValue};
|
||||
|
||||
/// Wrapper around a [`GraphQLBatchResponse`], implementing [`IntoResponse`], so it can be returned
|
||||
/// from [`axum`] handlers.
|
||||
pub struct JuniperResponse<'a, S = DefaultScalarValue>(pub GraphQLBatchResponse<'a, S>)
|
||||
where
|
||||
S: ScalarValue;
|
||||
|
||||
impl<S: ScalarValue> IntoResponse for JuniperResponse<'_, S> {
|
||||
fn into_response(self) -> Response {
|
||||
if self.0.is_ok() {
|
||||
Json(self.0).into_response()
|
||||
} else {
|
||||
(StatusCode::BAD_REQUEST, Json(self.0)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -9,9 +9,12 @@ homepage.workspace = true
|
|||
anyhow.workspace = true
|
||||
axum.workspace = true
|
||||
hyper = { workspace = true, features=["client"]}
|
||||
juniper.workspace = true
|
||||
juniper-axum = { path = "../../crates/juniper-axum" }
|
||||
lazy_static = "1.4.0"
|
||||
mime_guess = "2.0.4"
|
||||
rust-embed = "8.0.0"
|
||||
thiserror.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
unicase = "2.7.0"
|
||||
|
|
|
|||
|
|
@ -1,70 +1,37 @@
|
|||
mod proxy;
|
||||
mod schema;
|
||||
mod ui;
|
||||
mod webserver;
|
||||
mod worker;
|
||||
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::Request,
|
||||
middleware::{from_fn_with_state, Next},
|
||||
response::IntoResponse,
|
||||
Router,
|
||||
routing, Extension, Router,
|
||||
};
|
||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Webserver {
|
||||
client: Client<HttpConnector>,
|
||||
completion: worker::WorkerGroup,
|
||||
chat: worker::WorkerGroup,
|
||||
}
|
||||
|
||||
impl Webserver {
|
||||
async fn dispatch_request(
|
||||
&self,
|
||||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> axum::response::Response {
|
||||
let path = request.uri().path();
|
||||
|
||||
let remote_addr = request
|
||||
.extensions()
|
||||
.get::<axum::extract::ConnectInfo<SocketAddr>>()
|
||||
.map(|ci| ci.0)
|
||||
.expect("Unable to extract remote addr");
|
||||
|
||||
let worker = if path.starts_with("/v1/completions") {
|
||||
self.completion.select().await
|
||||
} else if path.starts_with("/v1beta/chat/completions") {
|
||||
self.chat.select().await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(worker) = worker {
|
||||
match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await {
|
||||
Ok(res) => res.into_response(),
|
||||
Err(err) => {
|
||||
warn!("Failed to proxy request {}", err);
|
||||
axum::response::Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
next.run(request).await
|
||||
}
|
||||
}
|
||||
}
|
||||
use hyper::Body;
|
||||
use juniper::EmptySubscription;
|
||||
use juniper_axum::{graphiql, graphql, playground};
|
||||
use schema::{Mutation, Query, Schema};
|
||||
use webserver::Webserver;
|
||||
|
||||
pub fn attach_webserver(router: Router) -> Router {
|
||||
let ws = Arc::new(Webserver::default());
|
||||
let schema = Arc::new(Schema::new(Query, Mutation, EmptySubscription::new()));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/graphql", routing::get(playground("/graphql", None)))
|
||||
.route("/graphiql", routing::get(graphiql("/graphql", None)))
|
||||
.route(
|
||||
"/graphql",
|
||||
routing::post(graphql::<Arc<Schema>, Arc<Webserver>>).with_state(ws.clone()),
|
||||
)
|
||||
.layer(Extension(schema));
|
||||
|
||||
router
|
||||
.merge(app)
|
||||
.fallback(ui::handler)
|
||||
.layer(from_fn_with_state(ws, distributed_tabby_layer))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use juniper::{
|
||||
graphql_object, graphql_value, EmptySubscription, FieldError, GraphQLEnum, GraphQLObject,
|
||||
IntoFieldError, RootNode, ScalarValue, Value,
|
||||
};
|
||||
use juniper_axum::FromStateAndClientAddr;
|
||||
|
||||
use crate::webserver::{Webserver, WebserverError};
|
||||
|
||||
pub struct Request {
|
||||
ws: Arc<Webserver>,
|
||||
client_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl FromStateAndClientAddr<Request, Arc<Webserver>> for Request {
|
||||
fn build(ws: Arc<Webserver>, client_addr: SocketAddr) -> Request {
|
||||
Request { ws, client_addr }
|
||||
}
|
||||
}
|
||||
|
||||
// To make our context usable by Juniper, we have to implement a marker trait.
|
||||
impl juniper::Context for Request {}
|
||||
|
||||
#[derive(GraphQLEnum, Clone, Debug)]
|
||||
pub enum WorkerKind {
|
||||
Completion,
|
||||
Chat,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Clone, Debug)]
|
||||
pub struct Worker {
|
||||
kind: WorkerKind,
|
||||
addr: String,
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
pub fn new(kind: WorkerKind, addr: String) -> Self {
|
||||
Self { kind, addr }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Query;
|
||||
|
||||
#[graphql_object(context = Request)]
|
||||
impl Query {
|
||||
async fn workers(request: &Request) -> Vec<Worker> {
|
||||
request.ws.list_workers().await
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mutation;
|
||||
|
||||
#[graphql_object(context = Request)]
|
||||
impl Mutation {
|
||||
async fn register_worker(
|
||||
request: &Request,
|
||||
token: String,
|
||||
kind: WorkerKind,
|
||||
port: i32,
|
||||
) -> Result<Worker, WebserverError> {
|
||||
let ws = &request.ws;
|
||||
ws.register_worker(token, request.client_addr, kind, port)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Request>>;
|
||||
|
||||
impl<S: ScalarValue> IntoFieldError<S> for WebserverError {
|
||||
fn into_field_error(self) -> FieldError<S> {
|
||||
let msg = format!("{}", &self);
|
||||
match self {
|
||||
WebserverError::InvalidToken(token) => FieldError::new(
|
||||
msg,
|
||||
graphql_value!({
|
||||
"token": token
|
||||
}),
|
||||
),
|
||||
_ => FieldError::new(msg, Value::Null),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
mod proxy;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use axum::{http::Request, middleware::Next, response::IntoResponse};
|
||||
use hyper::{client::HttpConnector, Body, Client, StatusCode};
|
||||
use thiserror::Error;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
schema::{Worker, WorkerKind},
|
||||
worker,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum WebserverError {
|
||||
#[error("Invalid worker token")]
|
||||
InvalidToken(String),
|
||||
|
||||
#[error("Feature requires enterprise license")]
|
||||
RequiresEnterpriseLicense,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Webserver {
|
||||
client: Client<HttpConnector>,
|
||||
completion: worker::WorkerGroup,
|
||||
chat: worker::WorkerGroup,
|
||||
}
|
||||
|
||||
// FIXME: generate token and support refreshing in database.
|
||||
static WORKER_TOKEN: &str = "4c749fad-2be7-45a3-849e-7714ccade382";
|
||||
|
||||
impl Webserver {
|
||||
pub async fn register_worker(
|
||||
&self,
|
||||
token: String,
|
||||
client_addr: SocketAddr,
|
||||
kind: WorkerKind,
|
||||
port: i32,
|
||||
) -> Result<Worker, WebserverError> {
|
||||
if token != WORKER_TOKEN {
|
||||
return Err(WebserverError::InvalidToken(token));
|
||||
}
|
||||
|
||||
let addr = SocketAddr::new(client_addr.ip(), port as u16);
|
||||
let addr = match kind {
|
||||
WorkerKind::Completion => self.completion.register(addr).await,
|
||||
WorkerKind::Chat => self.chat.register(addr).await,
|
||||
};
|
||||
|
||||
if let Some(addr) = addr {
|
||||
info!("registering <{:?}> worker running at {}", kind, addr);
|
||||
Ok(Worker::new(kind, addr))
|
||||
} else {
|
||||
Err(WebserverError::RequiresEnterpriseLicense)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_workers(&self) -> Vec<Worker> {
|
||||
let make_workers = |x: WorkerKind, lst: Vec<String>| -> Vec<Worker> {
|
||||
lst.into_iter()
|
||||
.map(|addr| Worker::new(x.clone(), addr))
|
||||
.collect()
|
||||
};
|
||||
|
||||
[
|
||||
make_workers(WorkerKind::Completion, self.completion.list().await),
|
||||
make_workers(WorkerKind::Chat, self.chat.list().await),
|
||||
]
|
||||
.concat()
|
||||
}
|
||||
|
||||
pub async fn dispatch_request(
|
||||
&self,
|
||||
request: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> axum::response::Response {
|
||||
let path = request.uri().path();
|
||||
|
||||
let remote_addr = request
|
||||
.extensions()
|
||||
.get::<axum::extract::ConnectInfo<SocketAddr>>()
|
||||
.map(|ci| ci.0)
|
||||
.expect("Unable to extract remote addr");
|
||||
|
||||
let worker = if path.starts_with("/v1/completions") {
|
||||
self.completion.select().await
|
||||
} else if path.starts_with("/v1beta/chat/completions") {
|
||||
self.chat.select().await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(worker) = worker {
|
||||
match proxy::call(self.client.clone(), remote_addr.ip(), &worker, request).await {
|
||||
Ok(res) => res.into_response(),
|
||||
Err(err) => {
|
||||
warn!("Failed to proxy request {}", err);
|
||||
axum::response::Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
next.run(request).await
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue