feat(tabby-ui): properly set accessToken for api access if user is login (#989)

* feat(tabby-ui): properly set accessToken for api access if user is login

* [autofix.ci] apply automated fixes

* fix: use return data from useSWR (#991)

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: aliang <1098486429@qq.com>
r0.7
Meng Zhang 2023-12-09 00:15:26 +08:00 committed by GitHub
parent d060888b5c
commit 8c02c22373
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 80 additions and 30 deletions

View File

@ -1,9 +1,10 @@
import * as React from 'react' import * as React from 'react'
import { UseChatHelpers } from 'ai/react' import { UseChatHelpers } from 'ai/react'
import { debounce, has } from 'lodash-es' import { debounce, has, isEqual } from 'lodash-es'
import useSWR from 'swr' import useSWR from 'swr'
import { useEnterSubmit } from '@/lib/hooks/use-enter-submit' import { useEnterSubmit } from '@/lib/hooks/use-enter-submit'
import { useSession } from '@/lib/tabby/auth'
import fetcher from '@/lib/tabby/fetcher' import fetcher from '@/lib/tabby/fetcher'
import type { ISearchHit, SearchReponse } from '@/lib/types' import type { ISearchHit, SearchReponse } from '@/lib/types'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
@ -45,7 +46,6 @@ function PromptFormRenderer(
const [queryCompletionUrl, setQueryCompletionUrl] = React.useState< const [queryCompletionUrl, setQueryCompletionUrl] = React.useState<
string | null string | null
>(null) >(null)
const latestFetchKey = React.useRef('')
const inputRef = React.useRef<HTMLTextAreaElement>(null) const inputRef = React.useRef<HTMLTextAreaElement>(null)
// store the input selection for replacing inputValue // store the input selection for replacing inputValue
const prevInputSelectionEnd = React.useRef<number>() const prevInputSelectionEnd = React.useRef<number>()
@ -56,14 +56,20 @@ function PromptFormRenderer(
Record<string, ISearchHit> Record<string, ISearchHit>
>({}) >({})
useSWR<SearchReponse>(queryCompletionUrl, fetcher, { const { data } = useSession()
revalidateOnFocus: false, const { data: completionData } = useSWR<SearchReponse>(
dedupingInterval: 0, [queryCompletionUrl, data?.accessToken],
onSuccess: (data, key) => { fetcher,
if (key !== latestFetchKey.current) return {
setOptions(data?.hits ?? []) revalidateOnFocus: false,
dedupingInterval: 0,
errorRetryCount: 0
} }
}) )
React.useEffect(() => {
setOptions(completionData?.hits ?? [])
}, [completionData?.hits])
React.useImperativeHandle(ref, () => { React.useImperativeHandle(ref, () => {
return { return {
@ -102,7 +108,6 @@ function PromptFormRenderer(
if (queryName) { if (queryName) {
const query = encodeURIComponent(`name:${queryName} AND kind:function`) const query = encodeURIComponent(`name:${queryName} AND kind:function`)
const url = `/v1beta/search?q=${query}` const url = `/v1beta/search?q=${query}`
latestFetchKey.current = url
setQueryCompletionUrl(url) setQueryCompletionUrl(url)
} else { } else {
setOptions([]) setOptions([])

View File

@ -1,10 +1,11 @@
'use client' 'use client'
import { SWRResponse } from 'swr' import useSWR, { SWRResponse } from 'swr'
import useSWRImmutable from 'swr/immutable'
import fetcher from '@/lib/tabby/fetcher' import fetcher from '@/lib/tabby/fetcher'
import { useSession } from '../tabby/auth'
export interface HealthInfo { export interface HealthInfo {
device: 'metal' | 'cpu' | 'cuda' device: 'metal' | 'cpu' | 'cuda'
model?: string model?: string
@ -19,5 +20,6 @@ export interface HealthInfo {
} }
export function useHealth(): SWRResponse<HealthInfo> { export function useHealth(): SWRResponse<HealthInfo> {
return useSWRImmutable('/v1/health', fetcher) const { data } = useSession()
return useSWR(['/v1/health', data?.accessToken], fetcher)
} }

View File

@ -6,29 +6,43 @@ import {
type AIStreamCallbacksAndOptions type AIStreamCallbacksAndOptions
} from 'ai' } from 'ai'
import { useSession } from '../tabby/auth'
const serverUrl = process.env.NEXT_PUBLIC_TABBY_SERVER_URL || '' const serverUrl = process.env.NEXT_PUBLIC_TABBY_SERVER_URL || ''
export function usePatchFetch() { export function usePatchFetch() {
const { data } = useSession()
useEffect(() => { useEffect(() => {
const fetch = window.fetch if (!(window as any)._originFetch) {
;(window as any)._originFetch = window.fetch
}
const fetch = (window as any)._originFetch as typeof window.fetch
window.fetch = async function (url, options) { window.fetch = async function (url, options) {
if (url !== '/api/chat') { if (url !== '/api/chat') {
return fetch(url, options) return fetch(url, options)
} }
const headers: HeadersInit = {
'Content-Type': 'application/json'
}
if (data?.accessToken) {
headers['Authorization'] = `Bearer ${data?.accessToken}`
}
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, { const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
...options, ...options,
method: 'POST', method: 'POST',
headers: { headers
'Content-Type': 'application/json'
}
}) })
const stream = StreamAdapter(res, undefined) const stream = StreamAdapter(res, undefined)
return new StreamingTextResponse(stream) return new StreamingTextResponse(stream)
} }
}, []) }, [data?.accessToken])
} }
const utf8Decoder = new TextDecoder('utf-8') const utf8Decoder = new TextDecoder('utf-8')

View File

@ -210,6 +210,7 @@ function useSignOut(): () => Promise<void> {
interface User { interface User {
email: string email: string
isAdmin: boolean isAdmin: boolean
accessToken: string
} }
type Session = type Session =
@ -231,7 +232,8 @@ function useSession(): Session {
return { return {
data: { data: {
email: user.email, email: user.email,
isAdmin: user.is_admin isAdmin: user.is_admin,
accessToken: authState.data.accessToken
}, },
status: authState.status status: authState.status
} }

View File

@ -1,9 +1,14 @@
export default function fetcher(url: string): Promise<any> { export default function tokenFetcher([url, token]: Array<
if (process.env.NODE_ENV === 'production') { string | undefined
return fetch(url).then(x => x.json()) >): Promise<any> {
} else { const headers = new Headers()
return fetch(`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`).then(x => if (token) {
x.json() headers.append('authorization', `Bearer ${token}`)
)
} }
if (process.env.NODE_ENV !== 'production') {
url = `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL}${url}`
}
return fetch(url!, { headers }).then(x => x.json())
} }

View File

@ -3,6 +3,8 @@ import { GraphQLClient, Variables } from 'graphql-request'
import { GraphQLResponse } from 'graphql-request/build/esm/types' import { GraphQLResponse } from 'graphql-request/build/esm/types'
import useSWR, { SWRConfiguration, SWRResponse } from 'swr' import useSWR, { SWRConfiguration, SWRResponse } from 'swr'
import { useSession } from './auth'
export const gqlClient = new GraphQLClient( export const gqlClient = new GraphQLClient(
`${process.env.NEXT_PUBLIC_TABBY_SERVER_URL ?? ''}/graphql` `${process.env.NEXT_PUBLIC_TABBY_SERVER_URL ?? ''}/graphql`
) )
@ -26,10 +28,20 @@ export function useGraphQLForm<
onError?: (path: string, message: string) => void onError?: (path: string, message: string) => void
} }
) { ) {
const onSubmit = async (values: TVariables) => { const { data } = useSession()
const accessToken = data?.accessToken
const onSubmit = async (variables: TVariables) => {
let res let res
try { try {
res = await gqlClient.request(document, values) res = await gqlClient.request({
document,
variables,
requestHeaders: accessToken
? {
authorization: `Bearer ${accessToken}`
}
: undefined
})
} catch (err) { } catch (err) {
const { errors = [] } = (err as any).response as GraphQLResponse const { errors = [] } = (err as any).response as GraphQLResponse
for (const error of errors) { for (const error of errors) {
@ -61,9 +73,19 @@ export function useGraphQLQuery<
variables?: TVariables, variables?: TVariables,
swrConfiguration?: SWRConfiguration<TResult> swrConfiguration?: SWRConfiguration<TResult>
): SWRResponse<TResult> { ): SWRResponse<TResult> {
const { data } = useSession()
return useSWR( return useSWR(
[document, variables], [document, variables, data?.accessToken],
([document, variables]) => gqlClient.request(document, variables), ([document, variables, accessToken]) =>
gqlClient.request({
document,
variables,
requestHeaders: accessToken
? {
authorization: `Bearer ${accessToken}`
}
: undefined
}),
swrConfiguration swrConfiguration
) )
} }