feat: make all features in webserver requires auth (#992)

* feat: require auth for webserver features

* cleanup frontend

* feat: add flag --webserver

* implement admin

* update format
r0.7
Meng Zhang 2023-12-09 01:49:10 +08:00 committed by GitHub
parent 8c02c22373
commit 6c6a2c803f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 91 additions and 106 deletions

View File

@ -92,6 +92,10 @@ pub struct ServeArgs {
/// memory requirement e.g., GPU vRAM. /// memory requirement e.g., GPU vRAM.
#[clap(long, default_value_t = 1)] #[clap(long, default_value_t = 1)]
parallelism: u8, parallelism: u8,
#[cfg(feature = "ee")]
#[clap(hide = true, long, default_value_t = false)]
webserver: bool,
} }
pub async fn main(config: &Config, args: &ServeArgs) { pub async fn main(config: &Config, args: &ServeArgs) {
@ -114,7 +118,12 @@ pub async fn main(config: &Config, args: &ServeArgs) {
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())); .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()));
#[cfg(feature = "ee")] #[cfg(feature = "ee")]
let (api, ui) = tabby_webserver::attach_webserver(api, ui, logger, code).await; let (api, ui) = if args.webserver {
tabby_webserver::attach_webserver(api, ui, logger, code).await
} else {
let ui = ui.fallback(|| async { axum::response::Redirect::permanent("/swagger-ui") });
(api, ui)
};
#[cfg(not(feature = "ee"))] #[cfg(not(feature = "ee"))]
let ui = ui.fallback(|| async { axum::response::Redirect::permanent("/swagger-ui") }); let ui = ui.fallback(|| async { axum::response::Redirect::permanent("/swagger-ui") });

View File

@ -7,7 +7,7 @@ import { WorkerKind } from '@/lib/gql/generates/graphql'
import { useHealth } from '@/lib/hooks/use-health' import { useHealth } from '@/lib/hooks/use-health'
import { useWorkers } from '@/lib/hooks/use-workers' import { useWorkers } from '@/lib/hooks/use-workers'
import { useSession } from '@/lib/tabby/auth' import { useSession } from '@/lib/tabby/auth'
import { useGraphQLQuery } from '@/lib/tabby/gql' import { useAuthenticatedGraphQLQuery, useGraphQLQuery } from '@/lib/tabby/gql'
import { buttonVariants } from '@/components/ui/button' import { buttonVariants } from '@/components/ui/button'
import { import {
Dialog, Dialog,
@ -90,7 +90,7 @@ const getRegistrationTokenDocument = graphql(/* GraphQL */ `
function MainPanel() { function MainPanel() {
const { data: healthInfo } = useHealth() const { data: healthInfo } = useHealth()
const workers = useWorkers(healthInfo) const workers = useWorkers(healthInfo)
const { data: registrationTokenRes } = useGraphQLQuery( const { data: registrationTokenRes } = useAuthenticatedGraphQLQuery(
getRegistrationTokenDocument getRegistrationTokenDocument
) )

View File

@ -12,7 +12,7 @@ export default function Signup() {
const title = isAdmin ? 'Create an admin account' : 'Create an account' const title = isAdmin ? 'Create an admin account' : 'Create an account'
const description = isAdmin const description = isAdmin
? 'After creating an admin account, your instance is secured, and only registered users can access it.' ? 'Your instance will be secured, only registered users can access it.'
: 'Fill form below to create your account' : 'Fill form below to create your account'
if (isAdmin || invitationCode) { if (isAdmin || invitationCode) {

View File

@ -9,6 +9,7 @@ import { WorkerKind } from '@/lib/gql/generates/graphql'
import { useHealth } from '@/lib/hooks/use-health' import { useHealth } from '@/lib/hooks/use-health'
import { ReleaseInfo, useLatestRelease } from '@/lib/hooks/use-latest-release' import { ReleaseInfo, useLatestRelease } from '@/lib/hooks/use-latest-release'
import { useWorkers } from '@/lib/hooks/use-workers' import { useWorkers } from '@/lib/hooks/use-workers'
import { useAuthenticatedSession } from '@/lib/tabby/auth'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
import { buttonVariants } from '@/components/ui/button' import { buttonVariants } from '@/components/ui/button'
import { IconGitHub, IconNotice } from '@/components/ui/icons' import { IconGitHub, IconNotice } from '@/components/ui/icons'
@ -16,6 +17,9 @@ import { IconGitHub, IconNotice } from '@/components/ui/icons'
import { ThemeToggle } from './theme-toggle' import { ThemeToggle } from './theme-toggle'
export function Header() { export function Header() {
// Ensure login status.
useAuthenticatedSession()
const { data } = useHealth() const { data } = useHealth()
const workers = useWorkers(data) const workers = useWorkers(data)
const isChatEnabled = has(workers, WorkerKind.Chat) const isChatEnabled = has(workers, WorkerKind.Chat)

View File

@ -4,7 +4,7 @@ import { debounce, has, isEqual } from 'lodash-es'
import useSWR from 'swr' import useSWR from 'swr'
import { useEnterSubmit } from '@/lib/hooks/use-enter-submit' import { useEnterSubmit } from '@/lib/hooks/use-enter-submit'
import { useSession } from '@/lib/tabby/auth' import { useAuthenticatedApi, useSession } from '@/lib/tabby/auth'
import fetcher from '@/lib/tabby/fetcher' import fetcher from '@/lib/tabby/fetcher'
import type { ISearchHit, SearchReponse } from '@/lib/types' import type { ISearchHit, SearchReponse } from '@/lib/types'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
@ -56,9 +56,8 @@ function PromptFormRenderer(
Record<string, ISearchHit> Record<string, ISearchHit>
>({}) >({})
const { data } = useSession()
const { data: completionData } = useSWR<SearchReponse>( const { data: completionData } = useSWR<SearchReponse>(
[queryCompletionUrl, data?.accessToken], useAuthenticatedApi(queryCompletionUrl),
fetcher, fetcher,
{ {
revalidateOnFocus: false, revalidateOnFocus: false,

View File

@ -1,55 +1,23 @@
import React from 'react' import React from 'react'
import Link from 'next/link'
import { import { useAuthenticatedSession, useSignOut } from '@/lib/tabby/auth'
useAuthenticatedSession,
useIsAdminInitialized,
useSession,
useSignOut
} from '@/lib/tabby/auth'
import { cn } from '@/lib/utils'
import { IconLogout, IconUnlock } from './ui/icons' import { IconLogout } from './ui/icons'
export default function UserPanel() { export default function UserPanel() {
const isAdminInitialized = useIsAdminInitialized()
const Component = isAdminInitialized ? UserInfoPanel : EnableAdminPanel
return (
<div className="py-4 flex justify-center text-sm font-medium">
<Component className={cn('flex items-center gap-2')} />
</div>
)
}
function UserInfoPanel({ className }: React.ComponentProps<'span'>) {
const session = useAuthenticatedSession() const session = useAuthenticatedSession()
const signOut = useSignOut() const signOut = useSignOut()
return ( return (
session && ( session && (
<span className={className}> <div className="py-4 flex justify-center text-sm font-medium">
<span title="Sign out"> <span className="flex items-center gap-2">
<IconLogout className="cursor-pointer" onClick={signOut} /> <span title="Sign out">
<IconLogout className="cursor-pointer" onClick={signOut} />
</span>
{session.email}
</span> </span>
{session.email} </div>
</span>
) )
) )
} }
function EnableAdminPanel({ className }: React.ComponentProps<'span'>) {
return (
<Link
className={cn('cursor-pointer', className)}
title="Authentication is currently not enabled. Click to view details"
href={{
pathname: '/auth/signup',
query: { isAdmin: true }
}}
>
<IconUnlock /> Secure Access
</Link>
)
}

View File

@ -4,7 +4,7 @@ import useSWR, { SWRResponse } from 'swr'
import fetcher from '@/lib/tabby/fetcher' import fetcher from '@/lib/tabby/fetcher'
import { useSession } from '../tabby/auth' import { useAuthenticatedApi, useSession } from '../tabby/auth'
export interface HealthInfo { export interface HealthInfo {
device: 'metal' | 'cpu' | 'cuda' device: 'metal' | 'cpu' | 'cuda'
@ -20,6 +20,5 @@ export interface HealthInfo {
} }
export function useHealth(): SWRResponse<HealthInfo> { export function useHealth(): SWRResponse<HealthInfo> {
const { data } = useSession() return useSWR(useAuthenticatedApi('/v1/health'), fetcher)
return useSWR(['/v1/health', data?.accessToken], fetcher)
} }

View File

@ -3,7 +3,7 @@ import { findIndex, groupBy, slice } from 'lodash-es'
import { graphql } from '@/lib/gql/generates' import { graphql } from '@/lib/gql/generates'
import { Worker, WorkerKind } from '@/lib/gql/generates/graphql' import { Worker, WorkerKind } from '@/lib/gql/generates/graphql'
import { useGraphQLQuery } from '@/lib/tabby/gql' import { useAuthenticatedGraphQLQuery, useGraphQLQuery } from '@/lib/tabby/gql'
import type { HealthInfo } from './use-health' import type { HealthInfo } from './use-health'
@ -44,7 +44,7 @@ export const getAllWorkersDocument = graphql(/* GraphQL */ `
`) `)
function useWorkers(healthInfo?: HealthInfo) { function useWorkers(healthInfo?: HealthInfo) {
const { data } = useGraphQLQuery(getAllWorkersDocument) const { data } = useAuthenticatedGraphQLQuery(getAllWorkersDocument)
let workers = data?.workers let workers = data?.workers
const groupedWorkers = React.useMemo(() => { const groupedWorkers = React.useMemo(() => {

View File

@ -251,20 +251,15 @@ export const getIsAdminInitialized = graphql(/* GraphQL */ `
} }
`) `)
function useIsAdminInitialized() {
const { data } = useGraphQLQuery(getIsAdminInitialized)
return data?.isAdminInitialized
}
function useAuthenticatedSession() { function useAuthenticatedSession() {
const { data } = useGraphQLQuery(getIsAdminInitialized) const { data } = useGraphQLQuery(getIsAdminInitialized)
const router = useRouter() const router = useRouter()
const { data: session, status } = useSession() const { data: session, status } = useSession()
React.useEffect(() => { React.useEffect(() => {
if (!data?.isAdminInitialized) return if (data?.isAdminInitialized === false) {
router.replace('/auth/signup?isAdmin=true')
if (status === 'unauthenticated') { } else if (status === 'unauthenticated') {
router.replace('/auth/signin') router.replace('/auth/signin')
} }
}, [data, status]) }, [data, status])
@ -272,6 +267,11 @@ function useAuthenticatedSession() {
return session return session
} }
function useAuthenticatedApi(path: string | null): [string, string] | null {
const { data, status } = useSession()
return path && status === 'authenticated' ? [path, data.accessToken] : null
}
export type { AuthStore, User, Session } export type { AuthStore, User, Session }
export { export {
@ -279,6 +279,6 @@ export {
useSignIn, useSignIn,
useSignOut, useSignOut,
useSession, useSession,
useIsAdminInitialized, useAuthenticatedSession,
useAuthenticatedSession useAuthenticatedApi
} }

View File

@ -1,14 +1,13 @@
export default function tokenFetcher([url, token]: Array< export default function tokenFetcher([url, token]: [
string | undefined string,
>): Promise<any> { string
]): Promise<any> {
const headers = new Headers() const headers = new Headers()
if (token) { headers.append('authorization', `Bearer ${token}`)
headers.append('authorization', `Bearer ${token}`)
}
if (process.env.NODE_ENV !== 'production') { if (process.env.NODE_ENV !== 'production') {
url = `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}` url = `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`
} }
return fetch(url!, { headers }).then(x => x.json()) return fetch(url, { headers }).then(x => x.json())
} }

View File

@ -73,18 +73,37 @@ export function useGraphQLQuery<
variables?: TVariables, variables?: TVariables,
swrConfiguration?: SWRConfiguration<TResult> swrConfiguration?: SWRConfiguration<TResult>
): SWRResponse<TResult> { ): SWRResponse<TResult> {
const { data } = useSession()
return useSWR( return useSWR(
[document, variables, data?.accessToken], [document, variables],
([document, variables, accessToken]) => ([document, variables]) =>
gqlClient.request({ gqlClient.request({
document, document,
variables, variables
requestHeaders: accessToken }),
? { swrConfiguration
authorization: `Bearer ${accessToken}` )
} }
: undefined
export function useAuthenticatedGraphQLQuery<
TResult,
TVariables extends Variables | undefined
>(
document: TypedDocumentNode<TResult, TVariables>,
variables?: TVariables,
swrConfiguration?: SWRConfiguration<TResult>
): SWRResponse<TResult> {
const { data, status } = useSession()
return useSWR(
status === 'authenticated'
? [document, variables, data?.accessToken]
: null,
([document, variables, accessToken]) =>
gqlClient.request({
document,
variables,
requestHeaders: {
authorization: `Bearer ${accessToken}`
}
}), }),
swrConfiguration swrConfiguration
) )

View File

@ -72,36 +72,27 @@ pub struct Query;
#[graphql_object(context = Context)] #[graphql_object(context = Context)]
impl Query { impl Query {
async fn workers(ctx: &Context) -> Result<Vec<Worker>> { async fn workers(ctx: &Context) -> Result<Vec<Worker>> {
if ctx.locator.auth().is_admin_initialized().await? { if let Some(claims) = &ctx.claims {
if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() {
if claims.user_info().is_admin() { let workers = ctx.locator.worker().list_workers().await;
let workers = ctx.locator.worker().list_workers().await; return Ok(workers);
return Ok(workers);
}
} }
Err(CoreError::Unauthorized(
"Only admin is able to read workers",
))
} else {
Ok(ctx.locator.worker().list_workers().await)
} }
Err(CoreError::Unauthorized(
"Only admin is able to read workers",
))
} }
async fn registration_token(ctx: &Context) -> Result<String> { async fn registration_token(ctx: &Context) -> Result<String> {
if ctx.locator.auth().is_admin_initialized().await? { if let Some(claims) = &ctx.claims {
if let Some(claims) = &ctx.claims { if claims.user_info().is_admin() {
if claims.user_info().is_admin() { let token = ctx.locator.worker().read_registration_token().await?;
let token = ctx.locator.worker().read_registration_token().await?; return Ok(token);
return Ok(token);
}
} }
Err(CoreError::Unauthorized(
"Only admin is able to read registeration_token",
))
} else {
let token = ctx.locator.worker().read_registration_token().await?;
Ok(token)
} }
Err(CoreError::Unauthorized(
"Only admin is able to read registeration_token",
))
} }
async fn is_admin_initialized(ctx: &Context) -> Result<bool> { async fn is_admin_initialized(ctx: &Context) -> Result<bool> {

View File

@ -48,10 +48,7 @@ impl ServerContext {
async fn authorize_request(&self, request: &Request<Body>) -> bool { async fn authorize_request(&self, request: &Request<Body>) -> bool {
let path = request.uri().path(); let path = request.uri().path();
if (path.starts_with("/v1/") || path.starts_with("/v1beta/")) if path.starts_with("/v1/") || path.starts_with("/v1beta/") {
// Authorization is enabled
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
{
let token = { let token = {
let authorization = request let authorization = request
.headers() .headers()