support authentication

support-auth-token
Meng Zhang 2023-12-08 19:30:54 +08:00
parent 2b8a07baa0
commit 15c4d97437
11 changed files with 116 additions and 45 deletions

View File

@ -1,6 +1,6 @@
import * as React from 'react' import * as React from 'react'
import { UseChatHelpers } from 'ai/react' import { UseChatHelpers } from 'ai/react'
import { debounce, has } from 'lodash-es' import { debounce, has, isEqual } from 'lodash-es'
import useSWR from 'swr' import useSWR from 'swr'
import { useEnterSubmit } from '@/lib/hooks/use-enter-submit' import { useEnterSubmit } from '@/lib/hooks/use-enter-submit'
@ -26,6 +26,7 @@ import {
TooltipContent, TooltipContent,
TooltipTrigger TooltipTrigger
} from '@/components/ui/tooltip' } from '@/components/ui/tooltip'
import { useSession } from '@/lib/tabby/auth'
export interface PromptProps export interface PromptProps
extends Pick<UseChatHelpers, 'input' | 'setInput'> { extends Pick<UseChatHelpers, 'input' | 'setInput'> {
@ -45,7 +46,6 @@ function PromptFormRenderer(
const [queryCompletionUrl, setQueryCompletionUrl] = React.useState< const [queryCompletionUrl, setQueryCompletionUrl] = React.useState<
string | null string | null
>(null) >(null)
const latestFetchKey = React.useRef('')
const inputRef = React.useRef<HTMLTextAreaElement>(null) const inputRef = React.useRef<HTMLTextAreaElement>(null)
// store the input selection for replacing inputValue // store the input selection for replacing inputValue
const prevInputSelectionEnd = React.useRef<number>() const prevInputSelectionEnd = React.useRef<number>()
@ -56,11 +56,11 @@ function PromptFormRenderer(
Record<string, ISearchHit> Record<string, ISearchHit>
>({}) >({})
useSWR<SearchReponse>(queryCompletionUrl, fetcher, { const { data } = useSession();
useSWR<SearchReponse>([queryCompletionUrl, data?.accessToken], fetcher, {
revalidateOnFocus: false, revalidateOnFocus: false,
dedupingInterval: 0, dedupingInterval: 0,
onSuccess: (data, key) => { onSuccess: (data) => {
if (key !== latestFetchKey.current) return
setOptions(data?.hits ?? []) setOptions(data?.hits ?? [])
} }
}) })
@ -102,7 +102,6 @@ function PromptFormRenderer(
if (queryName) { if (queryName) {
const query = encodeURIComponent(`name:${queryName} AND kind:function`) const query = encodeURIComponent(`name:${queryName} AND kind:function`)
const url = `/v1beta/search?q=${query}` const url = `/v1beta/search?q=${query}`
latestFetchKey.current = url
setQueryCompletionUrl(url) setQueryCompletionUrl(url)
} else { } else {
setOptions([]) setOptions([])

View File

@ -1,9 +1,9 @@
'use client' 'use client'
import { SWRResponse } from 'swr' import useSWR, { SWRResponse } from 'swr'
import useSWRImmutable from 'swr/immutable'
import fetcher from '@/lib/tabby/fetcher' import fetcher from '@/lib/tabby/fetcher'
import { useSession } from '../tabby/auth'
export interface HealthInfo { export interface HealthInfo {
device: 'metal' | 'cpu' | 'cuda' device: 'metal' | 'cpu' | 'cuda'
@ -19,5 +19,6 @@ export interface HealthInfo {
} }
export function useHealth(): SWRResponse<HealthInfo> { export function useHealth(): SWRResponse<HealthInfo> {
return useSWRImmutable('/v1/health', fetcher) const { data } = useSession()
return useSWR(['/v1/health', data?.accessToken], fetcher)
} }

View File

@ -5,30 +5,43 @@ import {
StreamingTextResponse, StreamingTextResponse,
type AIStreamCallbacksAndOptions type AIStreamCallbacksAndOptions
} from 'ai' } from 'ai'
import { useSession } from '../tabby/auth'
const serverUrl = process.env.NEXT_PUBLIC_TABBY_SERVER_URL || '' const serverUrl = process.env.NEXT_PUBLIC_TABBY_SERVER_URL || ''
export function usePatchFetch() { export function usePatchFetch() {
const { data } = useSession()
useEffect(() => { useEffect(() => {
const fetch = window.fetch if (!(window as any)._originFetch) {
(window as any)._originFetch = window.fetch;
}
const fetch = (window as any)._originFetch as (typeof window.fetch);
window.fetch = async function (url, options) { window.fetch = async function (url, options) {
if (url !== '/api/chat') { if (url !== '/api/chat') {
return fetch(url, options) return fetch(url, options)
} }
const headers: HeadersInit = {
'Content-Type': 'application/json',
}
if (data?.accessToken) {
headers["Authorization"] = `Bearer ${data?.accessToken}`;
}
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, { const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
...options, ...options,
method: 'POST', method: 'POST',
headers: { headers,
'Content-Type': 'application/json'
}
}) })
const stream = StreamAdapter(res, undefined) const stream = StreamAdapter(res, undefined)
return new StreamingTextResponse(stream) return new StreamingTextResponse(stream)
} }
}, []) }, [data?.accessToken])
} }
const utf8Decoder = new TextDecoder('utf-8') const utf8Decoder = new TextDecoder('utf-8')

View File

@ -210,6 +210,7 @@ function useSignOut(): () => Promise<void> {
interface User { interface User {
email: string email: string
isAdmin: boolean isAdmin: boolean
accessToken: string
} }
type Session = type Session =
@ -231,7 +232,8 @@ function useSession(): Session {
return { return {
data: { data: {
email: user.email, email: user.email,
isAdmin: user.is_admin isAdmin: user.is_admin,
accessToken: authState.data.accessToken
}, },
status: authState.status status: authState.status
} }

View File

@ -1,9 +1,12 @@
export default function fetcher(url: string): Promise<any> { export default function tokenFetcher([url, token]: Array<string | undefined>): Promise<any> {
if (process.env.NODE_ENV === 'production') { const headers = new Headers();
return fetch(url).then(x => x.json()) if (token) {
} else { headers.append("authorization", `Bearer ${token}`)
return fetch(`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`).then(x =>
x.json()
)
} }
if (process.env.NODE_ENV !== 'production') {
url = `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`;
}
return fetch(url!, { headers }).then(x => x.json())
} }

View File

@ -2,6 +2,7 @@ import { TypedDocumentNode } from '@graphql-typed-document-node/core'
import { GraphQLClient, Variables } from 'graphql-request' import { GraphQLClient, Variables } from 'graphql-request'
import { GraphQLResponse } from 'graphql-request/build/esm/types' import { GraphQLResponse } from 'graphql-request/build/esm/types'
import useSWR, { SWRConfiguration, SWRResponse } from 'swr' import useSWR, { SWRConfiguration, SWRResponse } from 'swr'
import { useSession } from './auth'
export const gqlClient = new GraphQLClient( export const gqlClient = new GraphQLClient(
`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL ?? ''}/graphql` `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL ?? ''}/graphql`
@ -26,10 +27,18 @@ export function useGraphQLForm<
onError?: (path: string, message: string) => void onError?: (path: string, message: string) => void
} }
) { ) {
const onSubmit = async (values: TVariables) => { const { data } = useSession();
const accessToken = data?.accessToken;
const onSubmit = async (variables: TVariables) => {
let res let res
try { try {
res = await gqlClient.request(document, values) res = await gqlClient.request({
document,
variables,
requestHeaders: accessToken ? {
"authorization": `Bearer ${accessToken}`
} : undefined
})
} catch (err) { } catch (err) {
const { errors = [] } = (err as any).response as GraphQLResponse const { errors = [] } = (err as any).response as GraphQLResponse
for (const error of errors) { for (const error of errors) {
@ -61,9 +70,16 @@ export function useGraphQLQuery<
variables?: TVariables, variables?: TVariables,
swrConfiguration?: SWRConfiguration<TResult> swrConfiguration?: SWRConfiguration<TResult>
): SWRResponse<TResult> { ): SWRResponse<TResult> {
const { data } = useSession();
return useSWR( return useSWR(
[document, variables], [document, variables, data?.accessToken],
([document, variables]) => gqlClient.request(document, variables), ([document, variables, accessToken]) => gqlClient.request({
document,
variables,
requestHeaders: accessToken ? {
"authorization": `Bearer ${accessToken}`
} : undefined
}),
swrConfiguration swrConfiguration
) )
} }

View File

@ -42,8 +42,12 @@ fn jwt_token_secret() -> String {
let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") { let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") {
Ok(x) => x, Ok(x) => x,
Err(_) => { Err(_) => {
warn!( eprintln!("
r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes." \x1b[93;1mJWT secret is not set\x1b[0m
Tabby server will generate a one-time (non-persisted) JWT secret for the current process.
Please set the \x1b[94mTABBY_WEBSERVER_JWT_TOKEN_SECRET\x1b[0m environment variable for production usage.
"
); );
Uuid::new_v4().to_string() Uuid::new_v4().to_string()
} }
@ -51,6 +55,7 @@ fn jwt_token_secret() -> String {
if uuid::Uuid::parse_str(&jwt_secret).is_err() { if uuid::Uuid::parse_str(&jwt_secret).is_err() {
warn!("JWT token secret needs to be in standard uuid format to ensure its security, you might generate one at https://www.uuidgenerator.net"); warn!("JWT token secret needs to be in standard uuid format to ensure its security, you might generate one at https://www.uuidgenerator.net");
std::process::exit(1)
} }
jwt_secret jwt_secret
@ -280,7 +285,7 @@ pub trait AuthenticationService: Send + Sync {
&self, &self,
refresh_token: String, refresh_token: String,
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>; ) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse>; async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> Result<bool>; async fn is_admin_initialized(&self) -> Result<bool>;
async fn create_invitation(&self, email: String) -> Result<i32>; async fn create_invitation(&self, email: String) -> Result<i32>;
@ -293,7 +298,7 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_generate_jwt() { fn test_generate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false)); let claims = Claims::new(UserInfo::new("test".to_string(), false, "cde".to_owned()));
let token = generate_jwt(claims).unwrap(); let token = generate_jwt(claims).unwrap();
assert!(!token.is_empty()) assert!(!token.is_empty())
@ -301,12 +306,13 @@ mod tests {
#[test] #[test]
fn test_validate_jwt() { fn test_validate_jwt() {
let claims = Claims::new(UserInfo::new("test".to_string(), false)); let user = UserInfo::new("test".to_string(), false, "cde".to_owned());
let claims = Claims::new(user.clone());
let token = generate_jwt(claims).unwrap(); let token = generate_jwt(claims).unwrap();
let claims = validate_jwt(&token).unwrap(); let claims = validate_jwt(&token).unwrap();
assert_eq!( assert_eq!(
claims.user_info(), claims.user_info(),
&UserInfo::new("test".to_string(), false) &user,
); );
} }

View File

@ -71,13 +71,37 @@ pub struct Query;
#[graphql_object(context = Context)] #[graphql_object(context = Context)]
impl Query { impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> { async fn workers(ctx: &Context) -> Result<Vec<Worker>> {
ctx.locator.worker().list_workers().await if ctx.locator.auth().is_admin_initialized().await? {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let workers = ctx.locator.worker().list_workers().await;
return Ok(workers);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to read workers",
))
} else {
Ok(ctx.locator.worker().list_workers().await)
}
} }
async fn registration_token(ctx: &Context) -> Result<String> { async fn registration_token(ctx: &Context) -> Result<String> {
let token = ctx.locator.worker().read_registration_token().await?; if ctx.locator.auth().is_admin_initialized().await? {
Ok(token) if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let token = ctx.locator.worker().read_registration_token().await?;
return Ok(token);
}
}
Err(CoreError::Unauthorized(
"Only admin is able to read registeration_token",
))
} else {
let token = ctx.locator.worker().read_registration_token().await?;
Ok(token)
}
} }
async fn is_admin_initialized(ctx: &Context) -> Result<bool> { async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
@ -142,7 +166,7 @@ impl Mutation {
} }
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> { async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
Ok(ctx.locator.auth().verify_access_token(token).await?) Ok(ctx.locator.auth().verify_access_token(&token).await?)
} }
async fn refresh_token( async fn refresh_token(

View File

@ -220,8 +220,8 @@ impl AuthenticationService for DbConn {
Ok(resp) Ok(resp)
} }
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse> { async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse> {
let claims = validate_jwt(&access_token)?; let claims = validate_jwt(access_token)?;
let resp = VerifyTokenResponse::new(claims); let resp = VerifyTokenResponse::new(claims);
Ok(resp) Ok(resp)
} }

View File

@ -18,7 +18,7 @@ pub struct User {
pub is_admin: bool, pub is_admin: bool,
/// To authenticate IDE extensions / plugins to access code completion / chat api endpoints. /// To authenticate IDE extensions / plugins to access code completion / chat api endpoints.
auth_token: String, pub auth_token: String,
} }
impl User { impl User {
@ -54,7 +54,7 @@ impl DbConn {
let mut stmt = c.prepare( let mut stmt = c.prepare(
r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#, r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#,
)?; )?;
let id = stmt.insert((email, password_encrypted, is_admin, Uuid::new_v4().to_string()))?; let id = stmt.insert((email, password_encrypted, is_admin, generate_auth_token()))?;
Ok(id) Ok(id)
}) })
.await?; .await?;
@ -132,6 +132,11 @@ impl DbConn {
} }
} }
fn generate_auth_token() -> String {
let uuid = Uuid::new_v4().to_string().replace("-", "");
format!("auth_{}", uuid)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -52,7 +52,7 @@ impl ServerContext {
// Authorization is enabled // Authorization is enabled
&& self.db_conn.is_admin_initialized().await.unwrap_or(false) && self.db_conn.is_admin_initialized().await.unwrap_or(false)
{ {
let auth_token = { let token = {
let authorization = request let authorization = request
.headers() .headers()
.get("authorization") .get("authorization")
@ -71,8 +71,10 @@ impl ServerContext {
} }
}; };
if let Some(auth_token) = auth_token { if let Some(token) = token {
if !self.db_conn.verify_auth_token(auth_token).await { if !self.db_conn.verify_access_token(token).await.is_ok()
&& !self.db_conn.verify_auth_token(token).await
{
return false; return false;
} }
} else { } else {