feat(ee): implement auth claims (#932)

* feat(ee): implement auth claims

* fix test

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
add-signin-page
Meng Zhang 2023-12-01 20:46:01 +08:00 committed by GitHub
parent 8d3be2ea36
commit 1a9cbdcc3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 94 additions and 38 deletions

View File

@ -6,7 +6,7 @@ use axum::{
async_trait, async_trait,
body::Body, body::Body,
extract::{FromRequest, FromRequestParts, Query}, extract::{FromRequest, FromRequestParts, Query},
http::{HeaderValue, Method, Request, StatusCode}, http::{request::Parts, HeaderValue, Method, Request, StatusCode},
response::{IntoResponse as _, Response}, response::{IntoResponse as _, Response},
Json, RequestExt as _, Json, RequestExt as _,
}; };
@ -16,6 +16,46 @@ use juniper::{
}; };
use serde::Deserialize; use serde::Deserialize;
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct AuthBearer(pub Option<String>);
pub type Rejection = (StatusCode, &'static str);
#[async_trait]
impl<B> FromRequestParts<B> for AuthBearer
where
B: Send + Sync,
{
type Rejection = Rejection;
async fn from_request_parts(req: &mut Parts, _: &B) -> Result<Self, Self::Rejection> {
// Get authorization header
let authorization = req
.headers
.get("authorization")
.map(HeaderValue::to_str)
.transpose()
.map_err(|_| {
(
StatusCode::BAD_REQUEST,
"authorization contains invalid characters",
)
})?;
let Some(authorization) = authorization else {
return Ok(Self(None));
};
// Check that its a well-formed bearer and return
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some((name, contents)) if name == "Bearer" => Ok(Self(Some(contents.to_owned()))),
_ => Ok(Self(None)),
}
}
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>) pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where where

View File

@ -1,26 +1,34 @@
pub mod extract; pub mod extract;
pub mod response; pub mod response;
use std::{future, sync::Arc}; use std::{future};
use axum::{ use axum::{
extract::{Extension, State}, extract::{Extension, State},
response::{Html, IntoResponse}, response::{Html, IntoResponse},
}; };
use extract::AuthBearer;
use juniper_graphql_ws::Schema; use juniper_graphql_ws::Schema;
use self::{extract::JuniperRequest, response::JuniperResponse}; use self::{extract::JuniperRequest, response::JuniperResponse};
pub trait FromAuth<S> {
fn build(state: S, bearer: Option<String>) -> Self;
}
#[cfg_attr(text, axum::debug_handler)] #[cfg_attr(text, axum::debug_handler)]
pub async fn graphql<S>( pub async fn graphql<S, C>(
State(state): State<Arc<S::Context>>, State(state): State<C>,
Extension(schema): Extension<S>, Extension(schema): Extension<S>,
AuthBearer(bearer): AuthBearer,
JuniperRequest(req): JuniperRequest<S::ScalarValue>, JuniperRequest(req): JuniperRequest<S::ScalarValue>,
) -> impl IntoResponse ) -> impl IntoResponse
where where
S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here. S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here.
S::Context: FromAuth<C>,
{ {
JuniperResponse(req.execute(schema.root_node(), &state).await).into_response() let ctx = S::Context::build(state, bearer);
JuniperResponse(req.execute(schema.root_node(), &ctx).await).into_response()
} }
/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL]. /// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL].

View File

@ -1,12 +1,6 @@
type RegisterResponse { type RegisterResponse {
accessToken: String! accessToken: String!
refreshToken: String! refreshToken: String!
errors: [AuthError!]!
}
type AuthError {
message: String!
code: String!
} }
enum WorkerKind { enum WorkerKind {
@ -15,7 +9,7 @@ enum WorkerKind {
} }
type Mutation { type Mutation {
resetRegistrationToken(token: String): String! resetRegistrationToken: String!
register(email: String!, password1: String!, password2: String!): RegisterResponse! register(email: String!, password1: String!, password2: String!): RegisterResponse!
tokenAuth(email: String!, password: String!): TokenAuthResponse! tokenAuth(email: String!, password: String!): TokenAuthResponse!
verifyToken(token: String!): VerifyTokenResponse! verifyToken(token: String!): VerifyTokenResponse!
@ -27,7 +21,6 @@ type UserInfo {
} }
type VerifyTokenResponse { type VerifyTokenResponse {
errors: [AuthError!]!
claims: Claims! claims: Claims!
} }
@ -56,7 +49,6 @@ type Worker {
type TokenAuthResponse { type TokenAuthResponse {
accessToken: String! accessToken: String!
refreshToken: String! refreshToken: String!
errors: [AuthError!]!
} }
schema { schema {

View File

@ -46,7 +46,7 @@ pub async fn attach_webserver(
.layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer)) .layer(from_fn_with_state(ctx.clone(), distributed_tabby_layer))
.route( .route(
"/graphql", "/graphql",
routing::post(graphql::<Arc<Schema>>).with_state(ctx.clone()), routing::post(graphql::<Arc<Schema>, Arc<ServerContext>>).with_state(ctx.clone()),
) )
.route("/graphql", routing::get(playground("/graphql", None))) .route("/graphql", routing::get(playground("/graphql", None)))
.layer(Extension(schema)) .layer(Extension(schema))

View File

@ -1,8 +1,12 @@
pub mod auth; pub mod auth;
use std::sync::Arc;
use juniper::{ use juniper::{
graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode, graphql_object, graphql_value, EmptySubscription, FieldError, FieldResult, RootNode,
}; };
use juniper_axum::FromAuth;
use crate::{ use crate::{
api::Worker, api::Worker,
@ -13,20 +17,32 @@ use crate::{
}, },
}; };
pub struct Context {
claims: Option<auth::Claims>,
server: Arc<ServerContext>,
}
impl FromAuth<Arc<ServerContext>> for Context {
fn build(server: Arc<ServerContext>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
}
}
// To make our context usable by Juniper, we have to implement a marker trait. // To make our context usable by Juniper, we have to implement a marker trait.
impl juniper::Context for ServerContext {} impl juniper::Context for Context {}
#[derive(Default)] #[derive(Default)]
pub struct Query; pub struct Query;
#[graphql_object(context = ServerContext)] #[graphql_object(context = Context)]
impl Query { impl Query {
async fn workers(ctx: &ServerContext) -> Vec<Worker> { async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.list_workers().await ctx.server.list_workers().await
} }
async fn registration_token(ctx: &ServerContext) -> FieldResult<String> { async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.read_registration_token().await?; let token = ctx.server.read_registration_token().await?;
Ok(token) Ok(token)
} }
} }
@ -34,15 +50,12 @@ impl Query {
#[derive(Default)] #[derive(Default)]
pub struct Mutation; pub struct Mutation;
#[graphql_object(context = ServerContext)] #[graphql_object(context = Context)]
impl Mutation { impl Mutation {
async fn reset_registration_token( async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
ctx: &ServerContext, if let Some(claims) = &ctx.claims {
token: Option<String>,
) -> FieldResult<String> {
if let Some(Ok(claims)) = token.map(|t| validate_jwt(&t)) {
if claims.user_info().is_admin() { if claims.user_info().is_admin() {
let reg_token = ctx.reset_registration_token().await?; let reg_token = ctx.server.reset_registration_token().await?;
return Ok(reg_token); return Ok(reg_token);
} }
} }
@ -53,7 +66,7 @@ impl Mutation {
} }
async fn register( async fn register(
ctx: &ServerContext, ctx: &Context,
email: String, email: String,
password1: String, password1: String,
password2: String, password2: String,
@ -63,24 +76,24 @@ impl Mutation {
password1, password1,
password2, password2,
}; };
ctx.auth().register(input).await ctx.server.auth().register(input).await
} }
async fn token_auth( async fn token_auth(
ctx: &ServerContext, ctx: &Context,
email: String, email: String,
password: String, password: String,
) -> FieldResult<TokenAuthResponse> { ) -> FieldResult<TokenAuthResponse> {
let input = TokenAuthInput { email, password }; let input = TokenAuthInput { email, password };
ctx.auth().token_auth(input).await ctx.server.auth().token_auth(input).await
} }
async fn verify_token(ctx: &ServerContext, token: String) -> FieldResult<VerifyTokenResponse> { async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.auth().verify_token(token).await ctx.server.auth().verify_token(token).await
} }
} }
pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<ServerContext>>; pub type Schema = RootNode<'static, Query, Mutation, EmptySubscription<Context>>;
pub fn create_schema() -> Schema { pub fn create_schema() -> Schema {
Schema::new(Query, Mutation, EmptySubscription::new()) Schema::new(Query, Mutation, EmptySubscription::new())

View File

@ -123,7 +123,7 @@ impl Claims {
} }
} }
pub fn user_info(self) -> UserInfo { pub fn user_info(&self) -> &UserInfo {
self.user &self.user
} }
} }

View File

@ -268,6 +268,9 @@ mod tests {
let claims = Claims::new(UserInfo::new("test".to_string(), false)); let claims = Claims::new(UserInfo::new("test".to_string(), false));
let token = generate_jwt(claims).unwrap(); let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap(); let claims = validate_jwt(&token).unwrap();
assert_eq!(claims.user_info(), UserInfo::new("test".to_string(), false)); assert_eq!(
claims.user_info(),
&UserInfo::new("test".to_string(), false)
);
} }
} }