From 1a9cbdcc3c793273edaf6610d410a262d765fa73 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Fri, 1 Dec 2023 20:46:01 +0800 Subject: [PATCH] feat(ee): implement auth claims (#932) * feat(ee): implement auth claims * fix test * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- crates/juniper-axum/src/extract.rs | 42 +++++++++++++++++- crates/juniper-axum/src/lib.rs | 16 +++++-- ee/tabby-webserver/graphql/schema.graphql | 10 +---- ee/tabby-webserver/src/lib.rs | 2 +- ee/tabby-webserver/src/schema.rs | 53 ++++++++++++++--------- ee/tabby-webserver/src/schema/auth.rs | 4 +- ee/tabby-webserver/src/server/auth.rs | 5 ++- 7 files changed, 94 insertions(+), 38 deletions(-) diff --git a/crates/juniper-axum/src/extract.rs b/crates/juniper-axum/src/extract.rs index 82207ca..d32e852 100644 --- a/crates/juniper-axum/src/extract.rs +++ b/crates/juniper-axum/src/extract.rs @@ -6,7 +6,7 @@ use axum::{ async_trait, body::Body, extract::{FromRequest, FromRequestParts, Query}, - http::{HeaderValue, Method, Request, StatusCode}, + http::{request::Parts, HeaderValue, Method, Request, StatusCode}, response::{IntoResponse as _, Response}, Json, RequestExt as _, }; @@ -16,6 +16,46 @@ use juniper::{ }; use serde::Deserialize; +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct AuthBearer(pub Option); + +pub type Rejection = (StatusCode, &'static str); + +#[async_trait] +impl FromRequestParts for AuthBearer +where + B: Send + Sync, +{ + type Rejection = Rejection; + + async fn from_request_parts(req: &mut Parts, _: &B) -> Result { + // Get authorization header + let authorization = req + .headers + .get("authorization") + .map(HeaderValue::to_str) + .transpose() + .map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "authorization contains invalid characters", + ) + })?; + + let Some(authorization) = authorization else { + return Ok(Self(None)); + }; + + // Check that its a well-formed bearer and return + let split = authorization.split_once(' '); + match split { + // Found proper bearer + Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))), + _ => Ok(Self(None)), + } + } +} + #[derive(Debug, PartialEq)] pub struct JuniperRequest(pub GraphQLBatchRequest) where diff --git a/crates/juniper-axum/src/lib.rs b/crates/juniper-axum/src/lib.rs index e80b900..f64ac2e 100644 --- a/crates/juniper-axum/src/lib.rs +++ b/crates/juniper-axum/src/lib.rs @@ -1,26 +1,34 @@ pub mod extract; pub mod response; -use std::{future, sync::Arc}; +use std::{future}; use axum::{ extract::{Extension, State}, response::{Html, IntoResponse}, }; +use extract::AuthBearer; use juniper_graphql_ws::Schema; use self::{extract::JuniperRequest, response::JuniperResponse}; +pub trait FromAuth { + fn build(state: S, bearer: Option) -> Self; +} + #[cfg_attr(text, axum::debug_handler)] -pub async fn graphql( - State(state): State>, +pub async fn graphql( + State(state): State, Extension(schema): Extension, + AuthBearer(bearer): AuthBearer, JuniperRequest(req): JuniperRequest, ) -> impl IntoResponse where S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here. + S::Context: FromAuth, { - JuniperResponse(req.execute(schema.root_node(), &state).await).into_response() + let ctx = S::Context::build(state, bearer); + JuniperResponse(req.execute(schema.root_node(), &ctx).await).into_response() } /// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL]. diff --git a/ee/tabby-webserver/graphql/schema.graphql b/ee/tabby-webserver/graphql/schema.graphql index 3a95548..92d6e56 100644 --- a/ee/tabby-webserver/graphql/schema.graphql +++ b/ee/tabby-webserver/graphql/schema.graphql @@ -1,12 +1,6 @@ type RegisterResponse { accessToken: String! refreshToken: String! - errors: [AuthError!]! -} - -type AuthError { - message: String! - code: String! } enum WorkerKind { @@ -15,7 +9,7 @@ enum WorkerKind { } type Mutation { - resetRegistrationToken(token: String): String! + resetRegistrationToken: String! register(email: String!, password1: String!, password2: String!): RegisterResponse! tokenAuth(email: String!, password: String!): TokenAuthResponse! verifyToken(token: String!): VerifyTokenResponse! @@ -27,7 +21,6 @@ type UserInfo { } type VerifyTokenResponse { - errors: [AuthError!]! claims: Claims! } @@ -56,7 +49,6 @@ type Worker { type TokenAuthResponse { accessToken: String! refreshToken: String! - errors: [AuthError!]! } schema { diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 8a5ad9e..07bb961 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -46,7 +46,7 @@ pub async fn attach_webserver( .layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer)) .route( "/graphql", - routing::post(graphql::>).with_state(ctx.clone()), + routing::post(graphql::, Arc>).with_state(ctx.clone()), ) .route("/graphql", routing::get(playground("/graphql", None))) .layer(Extension(schema)) diff --git a/ee/tabby-webserver/src/schema.rs b/ee/tabby-webserver/src/schema.rs index c6088b4..f02ba98 100644 --- a/ee/tabby-webserver/src/schema.rs +++ b/ee/tabby-webserver/src/schema.rs @@ -1,8 +1,12 @@ pub mod auth; +use std::sync::Arc; + + use juniper::{ graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode, }; +use juniper_axum::FromAuth; use crate::{ api::Worker, @@ -13,20 +17,32 @@ use crate::{ }, }; +pub struct Context { + claims: Option, + server: Arc, +} + +impl FromAuth> for Context { + fn build(server: Arc, bearer: Option) -> Self { + let claims = bearer.and_then(|token| validate_jwt(&token).ok()); + Self { claims, server } + } +} + // To make our context usable by Juniper, we have to implement a marker trait. -impl juniper::Context for ServerContext {} +impl juniper::Context for Context {} #[derive(Default)] pub struct Query; -#[graphql_object(context = ServerContext)] +#[graphql_object(context = Context)] impl Query { - async fn workers(ctx: &ServerContext) -> Vec { - ctx.list_workers().await + async fn workers(ctx: &Context) -> Vec { + ctx.server.list_workers().await } - async fn registration_token(ctx: &ServerContext) -> FieldResult { - let token = ctx.read_registration_token().await?; + async fn registration_token(ctx: &Context) -> FieldResult { + let token = ctx.server.read_registration_token().await?; Ok(token) } } @@ -34,15 +50,12 @@ impl Query { #[derive(Default)] pub struct Mutation; -#[graphql_object(context = ServerContext)] +#[graphql_object(context = Context)] impl Mutation { - async fn reset_registration_token( - ctx: &ServerContext, - token: Option, - ) -> FieldResult { - if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) { + async fn reset_registration_token(ctx: &Context) -> FieldResult { + if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() { - let reg_token = ctx.reset_registration_token().await?; + let reg_token = ctx.server.reset_registration_token().await?; return Ok(reg_token); } } @@ -53,7 +66,7 @@ impl Mutation { } async fn register( - ctx: &ServerContext, + ctx: &Context, email: String, password1: String, password2: String, @@ -63,24 +76,24 @@ impl Mutation { password1, password2, }; - ctx.auth().register(input).await + ctx.server.auth().register(input).await } async fn token_auth( - ctx: &ServerContext, + ctx: &Context, email: String, password: String, ) -> FieldResult { let input = TokenAuthInput { email, password }; - ctx.auth().token_auth(input).await + ctx.server.auth().token_auth(input).await } - async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult { - ctx.auth().verify_token(token).await + async fn verify_token(ctx: &Context, token: String) -> FieldResult { + ctx.server.auth().verify_token(token).await } } -pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; +pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription>; pub fn create_schema() -> Schema { Schema::new(Query, Mutation, EmptySubscription::new()) diff --git a/ee/tabby-webserver/src/schema/auth.rs b/ee/tabby-webserver/src/schema/auth.rs index 1f478f8..9e5aa61 100644 --- a/ee/tabby-webserver/src/schema/auth.rs +++ b/ee/tabby-webserver/src/schema/auth.rs @@ -123,7 +123,7 @@ impl Claims { } } - pub fn user_info(self) -> UserInfo { - self.user + pub fn user_info(&self) -> &UserInfo { + &self.user } } diff --git a/ee/tabby-webserver/src/server/auth.rs b/ee/tabby-webserver/src/server/auth.rs index 8d09b50..df54e88 100644 --- a/ee/tabby-webserver/src/server/auth.rs +++ b/ee/tabby-webserver/src/server/auth.rs @@ -268,6 +268,9 @@ mod tests { let claims = Claims::new(UserInfo::new("test".to_string(), false)); let token = generate_jwt(claims).unwrap(); let claims = validate_jwt(&token).unwrap(); - assert_eq!(claims.user_info(), UserInfo::new("test".to_string(), false)); + assert_eq!( + claims.user_info(), + &UserInfo::new("test".to_string(), false) + ); } }