support authentication
parent
2b8a07baa0
commit
15c4d97437
|
|
@ -1,6 +1,6 @@
|
||||||
import * as React from 'react'
|
import * as React from 'react'
|
||||||
import { UseChatHelpers } from 'ai/react'
|
import { UseChatHelpers } from 'ai/react'
|
||||||
import { debounce, has } from 'lodash-es'
|
import { debounce, has, isEqual } from 'lodash-es'
|
||||||
import useSWR from 'swr'
|
import useSWR from 'swr'
|
||||||
|
|
||||||
import { useEnterSubmit } from '@/lib/hooks/use-enter-submit'
|
import { useEnterSubmit } from '@/lib/hooks/use-enter-submit'
|
||||||
|
|
@ -26,6 +26,7 @@ import {
|
||||||
TooltipContent,
|
TooltipContent,
|
||||||
TooltipTrigger
|
TooltipTrigger
|
||||||
} from '@/components/ui/tooltip'
|
} from '@/components/ui/tooltip'
|
||||||
|
import { useSession } from '@/lib/tabby/auth'
|
||||||
|
|
||||||
export interface PromptProps
|
export interface PromptProps
|
||||||
extends Pick<UseChatHelpers, 'input' | 'setInput'> {
|
extends Pick<UseChatHelpers, 'input' | 'setInput'> {
|
||||||
|
|
@ -45,7 +46,6 @@ function PromptFormRenderer(
|
||||||
const [queryCompletionUrl, setQueryCompletionUrl] = React.useState<
|
const [queryCompletionUrl, setQueryCompletionUrl] = React.useState<
|
||||||
string | null
|
string | null
|
||||||
>(null)
|
>(null)
|
||||||
const latestFetchKey = React.useRef('')
|
|
||||||
const inputRef = React.useRef<HTMLTextAreaElement>(null)
|
const inputRef = React.useRef<HTMLTextAreaElement>(null)
|
||||||
// store the input selection for replacing inputValue
|
// store the input selection for replacing inputValue
|
||||||
const prevInputSelectionEnd = React.useRef<number>()
|
const prevInputSelectionEnd = React.useRef<number>()
|
||||||
|
|
@ -56,11 +56,11 @@ function PromptFormRenderer(
|
||||||
Record<string, ISearchHit>
|
Record<string, ISearchHit>
|
||||||
>({})
|
>({})
|
||||||
|
|
||||||
useSWR<SearchReponse>(queryCompletionUrl, fetcher, {
|
const { data } = useSession();
|
||||||
|
useSWR<SearchReponse>([queryCompletionUrl, data?.accessToken], fetcher, {
|
||||||
revalidateOnFocus: false,
|
revalidateOnFocus: false,
|
||||||
dedupingInterval: 0,
|
dedupingInterval: 0,
|
||||||
onSuccess: (data, key) => {
|
onSuccess: (data) => {
|
||||||
if (key !== latestFetchKey.current) return
|
|
||||||
setOptions(data?.hits ?? [])
|
setOptions(data?.hits ?? [])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -102,7 +102,6 @@ function PromptFormRenderer(
|
||||||
if (queryName) {
|
if (queryName) {
|
||||||
const query = encodeURIComponent(`name:${queryName} AND kind:function`)
|
const query = encodeURIComponent(`name:${queryName} AND kind:function`)
|
||||||
const url = `/v1beta/search?q=${query}`
|
const url = `/v1beta/search?q=${query}`
|
||||||
latestFetchKey.current = url
|
|
||||||
setQueryCompletionUrl(url)
|
setQueryCompletionUrl(url)
|
||||||
} else {
|
} else {
|
||||||
setOptions([])
|
setOptions([])
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { SWRResponse } from 'swr'
|
import useSWR, { SWRResponse } from 'swr'
|
||||||
import useSWRImmutable from 'swr/immutable'
|
|
||||||
|
|
||||||
import fetcher from '@/lib/tabby/fetcher'
|
import fetcher from '@/lib/tabby/fetcher'
|
||||||
|
import { useSession } from '../tabby/auth'
|
||||||
|
|
||||||
export interface HealthInfo {
|
export interface HealthInfo {
|
||||||
device: 'metal' | 'cpu' | 'cuda'
|
device: 'metal' | 'cpu' | 'cuda'
|
||||||
|
|
@ -19,5 +19,6 @@ export interface HealthInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useHealth(): SWRResponse<HealthInfo> {
|
export function useHealth(): SWRResponse<HealthInfo> {
|
||||||
return useSWRImmutable('/v1/health', fetcher)
|
const { data } = useSession()
|
||||||
|
return useSWR(['/v1/health', data?.accessToken], fetcher)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,30 +5,43 @@ import {
|
||||||
StreamingTextResponse,
|
StreamingTextResponse,
|
||||||
type AIStreamCallbacksAndOptions
|
type AIStreamCallbacksAndOptions
|
||||||
} from 'ai'
|
} from 'ai'
|
||||||
|
import { useSession } from '../tabby/auth'
|
||||||
|
|
||||||
const serverUrl = process.env.NEXT_PUBLIC_TABBY_SERVER_URL || ''
|
const serverUrl = process.env.NEXT_PUBLIC_TABBY_SERVER_URL || ''
|
||||||
|
|
||||||
export function usePatchFetch() {
|
export function usePatchFetch() {
|
||||||
|
const { data } = useSession()
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetch = window.fetch
|
if (!(window as any)._originFetch) {
|
||||||
|
(window as any)._originFetch = window.fetch;
|
||||||
|
}
|
||||||
|
|
||||||
|
const fetch = (window as any)._originFetch as (typeof window.fetch);
|
||||||
|
|
||||||
window.fetch = async function (url, options) {
|
window.fetch = async function (url, options) {
|
||||||
if (url !== '/api/chat') {
|
if (url !== '/api/chat') {
|
||||||
return fetch(url, options)
|
return fetch(url, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const headers: HeadersInit = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data?.accessToken) {
|
||||||
|
headers["Authorization"] = `Bearer ${data?.accessToken}`;
|
||||||
|
}
|
||||||
|
|
||||||
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
|
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
|
||||||
...options,
|
...options,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const stream = StreamAdapter(res, undefined)
|
const stream = StreamAdapter(res, undefined)
|
||||||
return new StreamingTextResponse(stream)
|
return new StreamingTextResponse(stream)
|
||||||
}
|
}
|
||||||
}, [])
|
}, [data?.accessToken])
|
||||||
}
|
}
|
||||||
|
|
||||||
const utf8Decoder = new TextDecoder('utf-8')
|
const utf8Decoder = new TextDecoder('utf-8')
|
||||||
|
|
|
||||||
|
|
@ -210,6 +210,7 @@ function useSignOut(): () => Promise<void> {
|
||||||
interface User {
|
interface User {
|
||||||
email: string
|
email: string
|
||||||
isAdmin: boolean
|
isAdmin: boolean
|
||||||
|
accessToken: string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session =
|
type Session =
|
||||||
|
|
@ -231,7 +232,8 @@ function useSession(): Session {
|
||||||
return {
|
return {
|
||||||
data: {
|
data: {
|
||||||
email: user.email,
|
email: user.email,
|
||||||
isAdmin: user.is_admin
|
isAdmin: user.is_admin,
|
||||||
|
accessToken: authState.data.accessToken
|
||||||
},
|
},
|
||||||
status: authState.status
|
status: authState.status
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
export default function fetcher(url: string): Promise<any> {
|
export default function tokenFetcher([url, token]: Array<string | undefined>): Promise<any> {
|
||||||
if (process.env.NODE_ENV === 'production') {
|
const headers = new Headers();
|
||||||
return fetch(url).then(x => x.json())
|
if (token) {
|
||||||
} else {
|
headers.append("authorization", `Bearer ${token}`)
|
||||||
return fetch(`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`).then(x =>
|
|
||||||
x.json()
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (process.env.NODE_ENV !== 'production') {
|
||||||
|
url = `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return fetch(url!, { headers }).then(x => x.json())
|
||||||
}
|
}
|
||||||
|
|
@ -2,6 +2,7 @@ import { TypedDocumentNode } from '@graphql-typed-document-node/core'
|
||||||
import { GraphQLClient, Variables } from 'graphql-request'
|
import { GraphQLClient, Variables } from 'graphql-request'
|
||||||
import { GraphQLResponse } from 'graphql-request/build/esm/types'
|
import { GraphQLResponse } from 'graphql-request/build/esm/types'
|
||||||
import useSWR, { SWRConfiguration, SWRResponse } from 'swr'
|
import useSWR, { SWRConfiguration, SWRResponse } from 'swr'
|
||||||
|
import { useSession } from './auth'
|
||||||
|
|
||||||
export const gqlClient = new GraphQLClient(
|
export const gqlClient = new GraphQLClient(
|
||||||
`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL ?? ''}/graphql`
|
`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL ?? ''}/graphql`
|
||||||
|
|
@ -26,10 +27,18 @@ export function useGraphQLForm<
|
||||||
onError?: (path: string, message: string) => void
|
onError?: (path: string, message: string) => void
|
||||||
}
|
}
|
||||||
) {
|
) {
|
||||||
const onSubmit = async (values: TVariables) => {
|
const { data } = useSession();
|
||||||
|
const accessToken = data?.accessToken;
|
||||||
|
const onSubmit = async (variables: TVariables) => {
|
||||||
let res
|
let res
|
||||||
try {
|
try {
|
||||||
res = await gqlClient.request(document, values)
|
res = await gqlClient.request({
|
||||||
|
document,
|
||||||
|
variables,
|
||||||
|
requestHeaders: accessToken ? {
|
||||||
|
"authorization": `Bearer ${accessToken}`
|
||||||
|
} : undefined
|
||||||
|
})
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
const { errors = [] } = (err as any).response as GraphQLResponse
|
const { errors = [] } = (err as any).response as GraphQLResponse
|
||||||
for (const error of errors) {
|
for (const error of errors) {
|
||||||
|
|
@ -61,9 +70,16 @@ export function useGraphQLQuery<
|
||||||
variables?: TVariables,
|
variables?: TVariables,
|
||||||
swrConfiguration?: SWRConfiguration<TResult>
|
swrConfiguration?: SWRConfiguration<TResult>
|
||||||
): SWRResponse<TResult> {
|
): SWRResponse<TResult> {
|
||||||
|
const { data } = useSession();
|
||||||
return useSWR(
|
return useSWR(
|
||||||
[document, variables],
|
[document, variables, data?.accessToken],
|
||||||
([document, variables]) => gqlClient.request(document, variables),
|
([document, variables, accessToken]) => gqlClient.request({
|
||||||
|
document,
|
||||||
|
variables,
|
||||||
|
requestHeaders: accessToken ? {
|
||||||
|
"authorization": `Bearer ${accessToken}`
|
||||||
|
} : undefined
|
||||||
|
}),
|
||||||
swrConfiguration
|
swrConfiguration
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,12 @@ fn jwt_token_secret() -> String {
|
||||||
let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") {
|
let jwt_secret = match std::env::var("TABBY_WEBSERVER_JWT_TOKEN_SECRET") {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
warn!(
|
eprintln!("
|
||||||
r"TABBY_WEBSERVER_JWT_TOKEN_SECRET is not set. Tabby generates a one-time (non-persisted) JWT token solely for testing purposes."
|
\x1b[93;1mJWT secret is not set\x1b[0m
|
||||||
|
|
||||||
|
Tabby server will generate a one-time (non-persisted) JWT secret for the current process.
|
||||||
|
Please set the \x1b[94mTABBY_WEBSERVER_JWT_TOKEN_SECRET\x1b[0m environment variable for production usage.
|
||||||
|
"
|
||||||
);
|
);
|
||||||
Uuid::new_v4().to_string()
|
Uuid::new_v4().to_string()
|
||||||
}
|
}
|
||||||
|
|
@ -51,6 +55,7 @@ fn jwt_token_secret() -> String {
|
||||||
|
|
||||||
if uuid::Uuid::parse_str(&jwt_secret).is_err() {
|
if uuid::Uuid::parse_str(&jwt_secret).is_err() {
|
||||||
warn!("JWT token secret needs to be in standard uuid format to ensure its security, you might generate one at https://www.uuidgenerator.net");
|
warn!("JWT token secret needs to be in standard uuid format to ensure its security, you might generate one at https://www.uuidgenerator.net");
|
||||||
|
std::process::exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwt_secret
|
jwt_secret
|
||||||
|
|
@ -280,7 +285,7 @@ pub trait AuthenticationService: Send + Sync {
|
||||||
&self,
|
&self,
|
||||||
refresh_token: String,
|
refresh_token: String,
|
||||||
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
|
) -> std::result::Result<RefreshTokenResponse, RefreshTokenError>;
|
||||||
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse>;
|
async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse>;
|
||||||
async fn is_admin_initialized(&self) -> Result<bool>;
|
async fn is_admin_initialized(&self) -> Result<bool>;
|
||||||
|
|
||||||
async fn create_invitation(&self, email: String) -> Result<i32>;
|
async fn create_invitation(&self, email: String) -> Result<i32>;
|
||||||
|
|
@ -293,7 +298,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
#[test]
|
#[test]
|
||||||
fn test_generate_jwt() {
|
fn test_generate_jwt() {
|
||||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
let claims = Claims::new(UserInfo::new("test".to_string(), false, "cde".to_owned()));
|
||||||
let token = generate_jwt(claims).unwrap();
|
let token = generate_jwt(claims).unwrap();
|
||||||
|
|
||||||
assert!(!token.is_empty())
|
assert!(!token.is_empty())
|
||||||
|
|
@ -301,12 +306,13 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_jwt() {
|
fn test_validate_jwt() {
|
||||||
let claims = Claims::new(UserInfo::new("test".to_string(), false));
|
let user = UserInfo::new("test".to_string(), false, "cde".to_owned());
|
||||||
|
let claims = Claims::new(user.clone());
|
||||||
let token = generate_jwt(claims).unwrap();
|
let token = generate_jwt(claims).unwrap();
|
||||||
let claims = validate_jwt(&token).unwrap();
|
let claims = validate_jwt(&token).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
claims.user_info(),
|
claims.user_info(),
|
||||||
&UserInfo::new("test".to_string(), false)
|
&user,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,14 +71,38 @@ pub struct Query;
|
||||||
|
|
||||||
#[graphql_object(context = Context)]
|
#[graphql_object(context = Context)]
|
||||||
impl Query {
|
impl Query {
|
||||||
async fn workers(ctx: &Context) -> Vec<Worker> {
|
async fn workers(ctx: &Context) -> Result<Vec<Worker>> {
|
||||||
ctx.locator.worker().list_workers().await
|
if ctx.locator.auth().is_admin_initialized().await? {
|
||||||
|
if let Some(claims) = &ctx.claims {
|
||||||
|
if claims.user_info().is_admin() {
|
||||||
|
let workers = ctx.locator.worker().list_workers().await;
|
||||||
|
return Ok(workers);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(CoreError::Unauthorized(
|
||||||
|
"Only admin is able to read workers",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Ok(ctx.locator.worker().list_workers().await)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn registration_token(ctx: &Context) -> Result<String> {
|
async fn registration_token(ctx: &Context) -> Result<String> {
|
||||||
|
if ctx.locator.auth().is_admin_initialized().await? {
|
||||||
|
if let Some(claims) = &ctx.claims {
|
||||||
|
if claims.user_info().is_admin() {
|
||||||
|
let token = ctx.locator.worker().read_registration_token().await?;
|
||||||
|
return Ok(token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(CoreError::Unauthorized(
|
||||||
|
"Only admin is able to read registeration_token",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
let token = ctx.locator.worker().read_registration_token().await?;
|
let token = ctx.locator.worker().read_registration_token().await?;
|
||||||
Ok(token)
|
Ok(token)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
|
async fn is_admin_initialized(ctx: &Context) -> Result<bool> {
|
||||||
Ok(ctx.locator.auth().is_admin_initialized().await?)
|
Ok(ctx.locator.auth().is_admin_initialized().await?)
|
||||||
|
|
@ -142,7 +166,7 @@ impl Mutation {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
|
async fn verify_token(ctx: &Context, token: String) -> Result<VerifyTokenResponse> {
|
||||||
Ok(ctx.locator.auth().verify_access_token(token).await?)
|
Ok(ctx.locator.auth().verify_access_token(&token).await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn refresh_token(
|
async fn refresh_token(
|
||||||
|
|
|
||||||
|
|
@ -220,8 +220,8 @@ impl AuthenticationService for DbConn {
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn verify_access_token(&self, access_token: String) -> Result<VerifyTokenResponse> {
|
async fn verify_access_token(&self, access_token: &str) -> Result<VerifyTokenResponse> {
|
||||||
let claims = validate_jwt(&access_token)?;
|
let claims = validate_jwt(access_token)?;
|
||||||
let resp = VerifyTokenResponse::new(claims);
|
let resp = VerifyTokenResponse::new(claims);
|
||||||
Ok(resp)
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ pub struct User {
|
||||||
pub is_admin: bool,
|
pub is_admin: bool,
|
||||||
|
|
||||||
/// To authenticate IDE extensions / plugins to access code completion / chat api endpoints.
|
/// To authenticate IDE extensions / plugins to access code completion / chat api endpoints.
|
||||||
auth_token: String,
|
pub auth_token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl User {
|
impl User {
|
||||||
|
|
@ -54,7 +54,7 @@ impl DbConn {
|
||||||
let mut stmt = c.prepare(
|
let mut stmt = c.prepare(
|
||||||
r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#,
|
r#"INSERT INTO users (email, password_encrypted, is_admin, auth_token) VALUES (?, ?, ?, ?)"#,
|
||||||
)?;
|
)?;
|
||||||
let id = stmt.insert((email, password_encrypted, is_admin, Uuid::new_v4().to_string()))?;
|
let id = stmt.insert((email, password_encrypted, is_admin, generate_auth_token()))?;
|
||||||
Ok(id)
|
Ok(id)
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -132,6 +132,11 @@ impl DbConn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn generate_auth_token() -> String {
|
||||||
|
let uuid = Uuid::new_v4().to_string().replace("-", "");
|
||||||
|
format!("auth_{}", uuid)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ impl ServerContext {
|
||||||
// Authorization is enabled
|
// Authorization is enabled
|
||||||
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
|
&& self.db_conn.is_admin_initialized().await.unwrap_or(false)
|
||||||
{
|
{
|
||||||
let auth_token = {
|
let token = {
|
||||||
let authorization = request
|
let authorization = request
|
||||||
.headers()
|
.headers()
|
||||||
.get("authorization")
|
.get("authorization")
|
||||||
|
|
@ -71,8 +71,10 @@ impl ServerContext {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(auth_token) = auth_token {
|
if let Some(token) = token {
|
||||||
if !self.db_conn.verify_auth_token(auth_token).await {
|
if !self.db_conn.verify_access_token(token).await.is_ok()
|
||||||
|
&& !self.db_conn.verify_auth_token(token).await
|
||||||
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue