Parcourir la source

Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2

tags/2.0.0-beta.1
twwu il y a 3 mois
Parent
révision
1b6a925b34

+ 62
- 15
api/services/datasource_provider_service.py Voir le fichier

@@ -54,7 +54,21 @@ class DatasourceProviderService:
)
if not datasource_provider:
return {}
return datasource_provider.encrypted_credentials

encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)

# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return copy_credentials

def get_real_credential_by_id(
self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
@@ -367,21 +381,52 @@ class DatasourceProviderService:
update datasource oauth provider
"""
with Session(db.engine) as session:
target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
if target_provider is None:
raise ValueError("provider not found")
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
)
if target_provider is None:
raise ValueError("provider not found")

provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
db_provider_name = name
if not db_provider_name:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
)
.count()
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)

target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
target_provider.name = name or target_provider.name
session.commit()
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)

target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
session.commit()

def add_datasource_oauth_provider(
self,
@@ -672,7 +717,9 @@ class DatasourceProviderService:
credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
)
datasource_credentials.append(
{
"provider": datasource.provider,

+ 22
- 0
web/app/components/header/account-setting/data-source-page-new/card.tsx Voir le fichier

@@ -1,6 +1,7 @@
import {
memo,
useCallback,
useRef,
} from 'react'
import { useTranslation } from 'react-i18next'
import Item from './item'
@@ -17,6 +18,8 @@ import {
} from '@/app/components/plugins/plugin-auth'
import { useDataSourceAuthUpdate } from './hooks'
import Confirm from '@/app/components/base/confirm'
import { useGetDataSourceOAuthUrl } from '@/service/use-datasource'
import { openOAuthPopup } from '@/hooks/use-oauth'

type CardProps = {
item: DataSourceAuth
@@ -55,6 +58,20 @@ const Card = ({
closeConfirm,
pendingOperationCredentialId,
} = usePluginAuthAction(pluginPayload, handleAuthUpdate)
const changeCredentialIdRef = useRef<string | undefined>(undefined)
const {
mutateAsync: getPluginOAuthUrl,
} = useGetDataSourceOAuthUrl(pluginPayload.provider)
const handleOAuth = useCallback(async () => {
const { authorization_url } = await getPluginOAuthUrl(changeCredentialIdRef.current)

if (authorization_url) {
openOAuthPopup(
authorization_url,
handleAuthUpdate,
)
}
}, [getPluginOAuthUrl, handleAuthUpdate])
const handleAction = useCallback((
action: string,
credentialItem: DataSourceCredential,
@@ -78,6 +95,11 @@ const Card = ({

if (action === 'rename')
handleRename(renamePayload as any)

if (action === 'change') {
changeCredentialIdRef.current = credentialItem.id
handleOAuth()
}
}, [
openConfirm,
handleEdit,

+ 19
- 1
web/service/use-datasource.ts Voir le fichier

@@ -1,4 +1,7 @@
import { useQuery } from '@tanstack/react-query'
import {
useMutation,
useQuery,
} from '@tanstack/react-query'
import { get } from './base'
import { useInvalid } from './use-base'
import type { DataSourceAuth } from '@/app/components/header/account-setting/data-source-page-new/types'
@@ -31,3 +34,18 @@ export const useInvalidDefaultDataSourceListAuth = (
) => {
return useInvalid([NAME_SPACE, 'default-list'])
}
export const useGetDataSourceOAuthUrl = (
provider: string,
) => {
return useMutation({
mutationKey: [NAME_SPACE, 'oauth-url', provider],
mutationFn: (credentialId?: string) => {
return get<
{
authorization_url: string
state: string
context_id: string
}>(`/oauth/plugin/${provider}/datasource/get-authorization-url?credential_id=${credentialId}`)
},
})
}

Chargement…
Annuler
Enregistrer