| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975 |
- import logging
- import time
- from collections.abc import Mapping
- from typing import Any
-
- from flask_login import current_user
- from sqlalchemy.orm import Session
-
- from configs import dify_config
- from constants import HIDDEN_VALUE, UNKNOWN_VALUE
- from core.helper import encrypter
- from core.helper.name_generator import generate_incremental_name
- from core.helper.provider_cache import NoOpProviderCredentialCache
- from core.model_runtime.entities.provider_entities import FormType
- from core.plugin.impl.datasource import PluginDatasourceManager
- from core.plugin.impl.oauth import OAuthHandler
- from core.tools.entities.tool_entities import CredentialType
- from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
- from extensions.ext_database import db
- from extensions.ext_redis import redis_client
- from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
- from models.provider_ids import DatasourceProviderID
- from services.plugin.plugin_service import PluginService
-
- logger = logging.getLogger(__name__)
-
-
- class DatasourceProviderService:
- """
- Datasource provider service: manages datasource credentials and OAuth client parameters per workspace.
- """
-
def __init__(self) -> None:
    """Create the service together with the plugin datasource manager it delegates to."""
    self.provider_manager = PluginDatasourceManager()
-
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
    """
    Delete the tenant-level custom OAuth client params of a datasource provider.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id; its provider name and plugin id select the rows to delete
    """
    criteria = {
        "tenant_id": tenant_id,
        "provider": datasource_provider_id.provider_name,
        "plugin_id": datasource_provider_id.plugin_id,
    }
    with Session(db.engine) as session:
        tenant_params_query = session.query(DatasourceOauthTenantParamConfig).filter_by(**criteria)
        tenant_params_query.delete()
        session.commit()
-
def decrypt_datasource_provider_credentials(
    self,
    tenant_id: str,
    datasource_provider: DatasourceProvider,
    plugin_id: str,
    provider: str,
) -> dict[str, Any]:
    """
    Return a copy of the stored credentials with every secret field decrypted.

    :param tenant_id: workspace id, passed to the token decrypter
    :param datasource_provider: record whose ``encrypted_credentials`` are read
    :param plugin_id: plugin id, combined with ``provider`` to look up the credential schema
    :param provider: provider name
    :return: new dict; fields not declared as secret inputs are copied unchanged
    """
    stored_credentials = datasource_provider.encrypted_credentials
    secret_keys = self.extract_secret_variables(
        tenant_id=tenant_id,
        provider_id=f"{plugin_id}/{provider}",
        credential_type=CredentialType.of(datasource_provider.auth_type),
    )
    return {
        key: encrypter.decrypt_token(tenant_id, value) if key in secret_keys else value
        for key, value in stored_credentials.items()
    }
-
def encrypt_datasource_provider_credentials(
    self,
    tenant_id: str,
    provider: str,
    plugin_id: str,
    raw_credentials: Mapping[str, Any],
    datasource_provider: DatasourceProvider,
) -> dict[str, Any]:
    """
    Return a dict copy of ``raw_credentials`` with every secret field encrypted.

    :param tenant_id: workspace id, passed to the token encrypter
    :param provider: provider name
    :param plugin_id: plugin id, combined with ``provider`` to look up the credential schema
    :param raw_credentials: plain credentials; the mapping itself is not modified
    :param datasource_provider: record whose ``auth_type`` selects the credential schema
    :return: new dict; fields not declared as secret inputs are copied unchanged
    """
    secret_keys = self.extract_secret_variables(
        tenant_id=tenant_id,
        provider_id=f"{plugin_id}/{provider}",
        credential_type=CredentialType.of(datasource_provider.auth_type),
    )
    return {
        key: encrypter.encrypt_token(tenant_id, value) if key in secret_keys else value
        for key, value in dict(raw_credentials).items()
    }
-
def get_datasource_credentials(
    self,
    tenant_id: str,
    provider: str,
    plugin_id: str,
    credential_id: str | None = None,
) -> dict[str, Any]:
    """
    Return the decrypted credentials of a single datasource authorization.

    When ``credential_id`` is given the record is looked up by id within the tenant;
    otherwise the first record of the provider is used, ordered by ``is_default``
    descending and then ``created_at`` ascending. Credentials whose ``expires_at`` is
    not -1 and lies less than 60 seconds ahead are refreshed through the OAuth handler
    and persisted before being returned.

    :param tenant_id: workspace id
    :param provider: provider name
    :param plugin_id: plugin id
    :param credential_id: optional id of a specific authorization
    :return: decrypted credentials, or an empty dict when no record matches
    """
    with Session(db.engine) as session:
        base_query = session.query(DatasourceProvider)
        if credential_id:
            record = base_query.filter_by(tenant_id=tenant_id, id=credential_id).first()
        else:
            record = (
                base_query.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
                .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
                .first()
            )
        if not record:
            return {}

        # expires_at == -1 never triggers a refresh; anything else is refreshed 60s before expiry
        if record.expires_at != -1 and (record.expires_at - 60) < int(time.time()):
            plain_credentials = self.decrypt_datasource_provider_credentials(
                tenant_id=tenant_id,
                datasource_provider=record,
                plugin_id=plugin_id,
                provider=provider,
            )
            datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
            callback_uri = (
                f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
                f"{datasource_provider_id}/datasource/callback"
            )
            oauth_client = self.get_oauth_client(tenant_id, datasource_provider_id)
            refreshed = OAuthHandler().refresh_credentials(
                tenant_id=tenant_id,
                user_id=current_user.id,
                plugin_id=datasource_provider_id.plugin_id,
                provider=datasource_provider_id.provider_name,
                redirect_uri=callback_uri,
                system_credentials=oauth_client or {},
                credentials=plain_credentials,
            )
            record.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                tenant_id=tenant_id,
                raw_credentials=refreshed.credentials,
                provider=provider,
                plugin_id=plugin_id,
                datasource_provider=record,
            )
            record.expires_at = refreshed.expires_at
            session.commit()

        return self.decrypt_datasource_provider_credentials(
            tenant_id=tenant_id,
            datasource_provider=record,
            plugin_id=plugin_id,
            provider=provider,
        )
-
def get_all_datasource_credentials_by_provider(
    self,
    tenant_id: str,
    provider: str,
    plugin_id: str,
) -> list[dict[str, Any]]:
    """
    Return the decrypted credentials of every authorization of a provider.

    Records are ordered by ``is_default`` descending, then ``created_at`` ascending.
    A record is refreshed through the OAuth handler only when its ``expires_at`` is not
    -1 and lies less than 60 seconds ahead, the same rule ``get_datasource_credentials``
    applies. Records that do not need a refresh are returned as stored, so the OAuth
    client is never resolved for them.

    :param tenant_id: workspace id
    :param provider: provider name
    :param plugin_id: plugin id
    :return: list of decrypted credential dicts, empty when the provider has no records
    :raises ValueError: propagated from ``get_oauth_client`` when a refresh is needed but no
        OAuth client params are configured
    """
    with Session(db.engine) as session:
        datasource_providers = (
            session.query(DatasourceProvider)
            .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
            .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
            .all()
        )
        if not datasource_providers:
            return []

        # loop invariants: identical for every record of this provider
        datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
        redirect_uri = (
            f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
            f"{datasource_provider_id}/datasource/callback"
        )
        now = int(time.time())

        real_credentials_list = []
        for datasource_provider in datasource_providers:
            real_credentials = self.decrypt_datasource_provider_credentials(
                tenant_id=tenant_id,
                datasource_provider=datasource_provider,
                plugin_id=plugin_id,
                provider=provider,
            )
            needs_refresh = datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < now
            if needs_refresh:
                # resolved lazily: get_oauth_client raises when no client params are configured
                system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
                refreshed_credentials = OAuthHandler().refresh_credentials(
                    tenant_id=tenant_id,
                    user_id=current_user.id,
                    plugin_id=datasource_provider_id.plugin_id,
                    provider=datasource_provider_id.provider_name,
                    redirect_uri=redirect_uri,
                    system_credentials=system_credentials or {},
                    credentials=real_credentials,
                )
                datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
                    tenant_id=tenant_id,
                    raw_credentials=refreshed_credentials.credentials,
                    provider=provider,
                    plugin_id=plugin_id,
                    datasource_provider=datasource_provider,
                )
                datasource_provider.expires_at = refreshed_credentials.expires_at
                real_credentials = self.decrypt_datasource_provider_credentials(
                    tenant_id=tenant_id,
                    datasource_provider=datasource_provider,
                    plugin_id=plugin_id,
                    provider=provider,
                )
            real_credentials_list.append(real_credentials)
        session.commit()

        return real_credentials_list
-
def update_datasource_provider_name(
    self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
):
    """
    Rename one authorization of a datasource provider.

    Does nothing when the record already carries ``name``.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id the authorization belongs to
    :param name: new authorization name
    :param credential_id: id of the authorization to rename
    :raises ValueError: when the authorization is not found, or another one of the same
        provider already uses ``name``
    """
    provider_scope = {
        "tenant_id": tenant_id,
        "provider": datasource_provider_id.provider_name,
        "plugin_id": datasource_provider_id.plugin_id,
    }
    with Session(db.engine) as session:
        record = session.query(DatasourceProvider).filter_by(id=credential_id, **provider_scope).first()
        if record is None:
            raise ValueError("provider not found")

        if record.name == name:
            return

        # names must stay unique within the provider
        same_name_count = session.query(DatasourceProvider).filter_by(name=name, **provider_scope).count()
        if same_name_count > 0:
            raise ValueError("Authorization name is already exists")

        record.name = name
        session.commit()
-
def set_default_datasource_provider(
    self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
):
    """
    Mark one authorization as the default of its provider.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id the authorization belongs to
    :param credential_id: id of the authorization to make default
    :return: ``{"result": "success"}``
    :raises ValueError: when the authorization is not found
    """
    with Session(db.engine) as session:
        new_default = (
            session.query(DatasourceProvider)
            .filter_by(
                tenant_id=tenant_id,
                id=credential_id,
                provider=datasource_provider_id.provider_name,
                plugin_id=datasource_provider_id.plugin_id,
            )
            .first()
        )
        if new_default is None:
            raise ValueError("provider not found")

        # unset whichever records of the same provider are currently flagged as default
        current_defaults = session.query(DatasourceProvider).filter_by(
            tenant_id=tenant_id,
            provider=new_default.provider,
            plugin_id=new_default.plugin_id,
            is_default=True,
        )
        current_defaults.update({"is_default": False})

        new_default.is_default = True
        session.commit()
        return {"result": "success"}
-
def setup_oauth_custom_client_params(
    self,
    tenant_id: str,
    datasource_provider_id: DatasourceProviderID,
    client_params: dict | None,
    enabled: bool | None,
):
    """
    Create or update the tenant-level custom OAuth client params of a provider.

    A missing config row is created (empty params, disabled) before the update. Values
    equal to ``HIDDEN_VALUE`` in ``client_params`` are replaced by the previously stored
    value for that key, or ``UNKNOWN_VALUE`` when there is none.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id
    :param client_params: new client params; ``None`` leaves the stored params untouched
    :param enabled: new enabled flag; ``None`` leaves the flag untouched
    """
    if client_params is None and enabled is None:
        return

    row_scope = {
        "tenant_id": tenant_id,
        "provider": datasource_provider_id.provider_name,
        "plugin_id": datasource_provider_id.plugin_id,
    }
    with Session(db.engine) as session:
        config_row = session.query(DatasourceOauthTenantParamConfig).filter_by(**row_scope).first()
        if not config_row:
            config_row = DatasourceOauthTenantParamConfig(client_params={}, enabled=False, **row_scope)
            session.add(config_row)

        if client_params is not None:
            oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
            if config_row:
                previous_params = oauth_encrypter.decrypt(config_row.client_params)
            else:
                previous_params = {}
            merged_params: dict = {}
            for key, value in client_params.items():
                # a hidden placeholder means "keep what is stored"
                merged_params[key] = value if value != HIDDEN_VALUE else previous_params.get(key, UNKNOWN_VALUE)
            config_row.client_params = oauth_encrypter.encrypt(merged_params)

        if enabled is not None:
            config_row.enabled = enabled
        session.commit()
-
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
    """
    Check whether system-level OAuth params exist for a provider.

    :param datasource_provider_id: provider id; its provider name and plugin id select the row
    :return: True when a ``DatasourceOauthParamConfig`` row exists for the provider
    """
    # Enter the Session itself so it is closed on exit; entering only `no_autoflush`
    # restores the autoflush flag but leaves the session (and its connection) open.
    with Session(db.engine) as session, session.no_autoflush:
        return (
            session.query(DatasourceOauthParamConfig)
            .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
            .first()
            is not None
        )
-
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
    """
    Check whether the tenant has enabled custom OAuth client params for a provider.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id
    :return: True when at least one enabled tenant config row exists
    """
    enabled_rows = (
        db.session.query(DatasourceOauthTenantParamConfig)
        .filter_by(
            tenant_id=tenant_id,
            provider=datasource_provider_id.provider_name,
            plugin_id=datasource_provider_id.plugin_id,
            enabled=True,
        )
        .count()
    )
    return enabled_rows > 0
-
def get_tenant_oauth_client(
    self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
) -> dict[str, Any] | None:
    """
    Return the tenant-level custom OAuth client params of a provider.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id
    :param mask: when True the decrypted params are passed through the encrypter's masking
    :return: decrypted (optionally masked) params, or None when the tenant has no config row
    """
    config_row = (
        db.session.query(DatasourceOauthTenantParamConfig)
        .filter_by(
            tenant_id=tenant_id,
            provider=datasource_provider_id.provider_name,
            plugin_id=datasource_provider_id.plugin_id,
        )
        .first()
    )
    if not config_row:
        return None

    oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
    plain_params = oauth_encrypter.decrypt(config_row.client_params)
    if mask:
        return oauth_encrypter.mask_tool_credentials(plain_params)
    return plain_params
-
def get_oauth_encrypter(
    self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
    """
    Build the encrypter for a provider's OAuth client params.

    The encrypter is configured from the provider's OAuth client schema and uses a no-op cache.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id
    :return: whatever ``create_provider_encrypter`` returns (annotated as encrypter and cache)
    :raises ValueError: when the provider declares no OAuth schema
    """
    provider_entity = self.provider_manager.fetch_datasource_provider(
        tenant_id=tenant_id, provider_id=str(datasource_provider_id)
    )
    oauth_schema = provider_entity.declaration.oauth_schema
    if not oauth_schema:
        raise ValueError("Datasource provider oauth schema not found")

    basic_configs = [field.to_basic_provider_config() for field in oauth_schema.client_schema]
    return create_provider_encrypter(
        tenant_id=tenant_id,
        config=basic_configs,
        cache=NoOpProviderCredentialCache(),
    )
-
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
    """
    Resolve the OAuth client params to use for a provider.

    Enabled tenant-level params take precedence. Otherwise, if the plugin is verified,
    the system-level params are used.

    :param tenant_id: workspace id
    :param datasource_provider_id: provider id
    :return: decrypted tenant params, or the stored system credentials
    :raises ValueError: when neither tenant nor usable system params are configured
    """
    provider = datasource_provider_id.provider_name
    plugin_id = datasource_provider_id.plugin_id
    # Enter the Session itself so it is closed on every exit path; entering only
    # `no_autoflush` restores the autoflush flag but leaves the session open.
    with Session(db.engine) as session, session.no_autoflush:
        # get tenant oauth client params
        tenant_oauth_client_params = (
            session.query(DatasourceOauthTenantParamConfig)
            .filter_by(
                tenant_id=tenant_id,
                provider=provider,
                plugin_id=plugin_id,
                enabled=True,
            )
            .first()
        )
        if tenant_oauth_client_params:
            oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
            return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)

        provider_controller = self.provider_manager.fetch_datasource_provider(
            tenant_id=tenant_id, provider_id=str(datasource_provider_id)
        )
        is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
        if is_verified:
            # fallback to system oauth client params
            oauth_client_params = (
                session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
            )
            if oauth_client_params:
                return oauth_client_params.system_credentials

        raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
-
@staticmethod
def generate_next_datasource_provider_name(
    session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
    """
    Generate the next free authorization name for a provider.

    :param session: session used to read the provider's existing authorizations
    :param tenant_id: workspace id
    :param provider_id: provider id
    :param credential_type: credential type whose display name is the base of the new name
    :return: result of ``generate_incremental_name`` over the existing names
    """
    existing_rows = (
        session.query(DatasourceProvider)
        .filter_by(
            tenant_id=tenant_id,
            provider=provider_id.provider_name,
            plugin_id=provider_id.plugin_id,
        )
        .all()
    )
    taken_names = [row.name for row in existing_rows]
    base_name = f"{credential_type.get_name()}"
    return generate_incremental_name(taken_names, base_name)
-
def reauthorize_datasource_oauth_provider(
    self,
    name: str | None,
    tenant_id: str,
    provider_id: DatasourceProviderID,
    avatar_url: str | None,
    expire_at: int,
    credentials: dict,
    credential_id: str,
) -> None:
    """
    Replace the credentials of an existing OAuth authorization.

    Runs under the per-tenant/provider creation lock. When ``name`` is given and differs
    from the current name it is applied to the record; if another authorization of the
    same provider already uses it, an incremental variant is generated instead. Secret
    fields of ``credentials`` are encrypted in place before being stored.

    :param name: optional new authorization name; falsy keeps the current name
    :param tenant_id: workspace id
    :param provider_id: provider id
    :param avatar_url: new avatar url; falsy keeps the current one
    :param expire_at: new expiry value stored in ``expires_at``
    :param credentials: plain OAuth credentials; secret values are overwritten with ciphertext
    :param credential_id: id of the authorization to update
    :raises ValueError: when the authorization is not found
    """
    with Session(db.engine) as session:
        lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
        with redis_client.lock(lock, timeout=20):
            target_provider = (
                session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
            )
            if target_provider is None:
                raise ValueError("provider not found")

            db_provider_name = name or target_provider.name
            # Only a real rename can collide. Checking only then also stops the record's
            # own current name from being counted as a conflict.
            if name and name != target_provider.name:
                # Checked across all auth types, matching the name list used below for
                # generating the incremental name.
                name_conflict = (
                    session.query(DatasourceProvider)
                    .filter_by(
                        tenant_id=tenant_id,
                        name=db_provider_name,
                        provider=provider_id.provider_name,
                        plugin_id=provider_id.plugin_id,
                    )
                    .count()
                )
                if name_conflict > 0:
                    db_provider_name = generate_incremental_name(
                        [
                            provider.name
                            for provider in session.query(DatasourceProvider).filter_by(
                                tenant_id=tenant_id,
                                provider=provider_id.provider_name,
                                plugin_id=provider_id.plugin_id,
                            )
                        ],
                        db_provider_name,
                    )

            provider_credential_secret_variables = self.extract_secret_variables(
                tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
            )
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    credentials[key] = encrypter.encrypt_token(tenant_id, value)

            # persist the resolved name; previously it was computed and then discarded
            target_provider.name = db_provider_name
            target_provider.expires_at = expire_at
            target_provider.encrypted_credentials = credentials
            target_provider.avatar_url = avatar_url or target_provider.avatar_url
            session.commit()
-
def add_datasource_oauth_provider(
    self,
    name: str | None,
    tenant_id: str,
    provider_id: DatasourceProviderID,
    avatar_url: str | None,
    expire_at: int,
    credentials: dict,
) -> None:
    """
    Store a new OAuth authorization for a datasource provider.

    Runs under the per-tenant/provider creation lock. Without ``name`` the next free
    default name is generated; a ``name`` already used by an OAuth authorization of the
    provider is replaced by an incremental variant. Secret fields of ``credentials`` are
    encrypted in place before being stored.

    :param name: optional authorization name
    :param tenant_id: workspace id
    :param provider_id: provider id
    :param avatar_url: avatar url; falsy stores ``"default"``
    :param expire_at: expiry value stored in ``expires_at``
    :param credentials: plain OAuth credentials; secret values are overwritten with ciphertext
    """
    credential_type = CredentialType.OAUTH2
    with Session(db.engine) as session:
        lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
        with redis_client.lock(lock, timeout=60):
            if name:
                resolved_name = name
                clash_count = (
                    session.query(DatasourceProvider)
                    .filter_by(
                        tenant_id=tenant_id,
                        name=resolved_name,
                        provider=provider_id.provider_name,
                        plugin_id=provider_id.plugin_id,
                        auth_type=credential_type.value,
                    )
                    .count()
                )
                if clash_count > 0:
                    sibling_rows = session.query(DatasourceProvider).filter_by(
                        tenant_id=tenant_id,
                        provider=provider_id.provider_name,
                        plugin_id=provider_id.plugin_id,
                    )
                    resolved_name = generate_incremental_name(
                        [row.name for row in sibling_rows],
                        resolved_name,
                    )
            else:
                resolved_name = self.generate_next_datasource_provider_name(
                    session=session,
                    tenant_id=tenant_id,
                    provider_id=provider_id,
                    credential_type=credential_type,
                )

            secret_keys = self.extract_secret_variables(
                tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
            )
            # secrets are replaced in the caller-supplied dict itself
            for key in [candidate for candidate in credentials if candidate in secret_keys]:
                credentials[key] = encrypter.encrypt_token(tenant_id, credentials[key])

            new_record = DatasourceProvider(
                tenant_id=tenant_id,
                name=resolved_name,
                provider=provider_id.provider_name,
                plugin_id=provider_id.plugin_id,
                auth_type=credential_type.value,
                encrypted_credentials=credentials,
                avatar_url=avatar_url or "default",
                expires_at=expire_at,
            )
            session.add(new_record)
            session.commit()
-
def add_datasource_api_key_provider(
    self,
    name: str | None,
    tenant_id: str,
    provider_id: DatasourceProviderID,
    credentials: dict,
) -> None:
    """
    Validate and store a new API-key authorization for a datasource provider.

    Runs under the per-tenant/provider creation lock. Secret fields of ``credentials``
    are encrypted in place before being stored.

    :param name: optional authorization name; when falsy the next free default name is generated
    :param tenant_id: workspace id
    :param provider_id: provider id
    :param credentials: plain credentials; secret values are overwritten with ciphertext
    :raises ValueError: when the name is already used within the provider, or the
        credentials fail validation (the original error is chained as the cause)
    """
    provider_name = provider_id.provider_name
    plugin_id = provider_id.plugin_id
    with Session(db.engine) as session:
        # `.value` keeps the key format identical to the OAuth create/reauthorize locks
        lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY.value}"
        with redis_client.lock(lock, timeout=20):
            db_provider_name = name or self.generate_next_datasource_provider_name(
                session=session,
                tenant_id=tenant_id,
                provider_id=provider_id,
                credential_type=CredentialType.API_KEY,
            )

            # check name is exist
            if (
                session.query(DatasourceProvider)
                .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
                .count()
                > 0
            ):
                raise ValueError("Authorization name is already exists")

            try:
                self.provider_manager.validate_provider_credentials(
                    tenant_id=tenant_id,
                    user_id=current_user.id,
                    provider=provider_name,
                    plugin_id=plugin_id,
                    credentials=credentials,
                )
            except Exception as e:
                # callers handle ValueError; chain the original error so it is not lost
                raise ValueError(f"Failed to validate credentials: {str(e)}") from e

            provider_credential_secret_variables = self.extract_secret_variables(
                tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
            )
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    credentials[key] = encrypter.encrypt_token(tenant_id, value)
            datasource_provider = DatasourceProvider(
                tenant_id=tenant_id,
                name=db_provider_name,
                provider=provider_name,
                plugin_id=plugin_id,
                auth_type=CredentialType.API_KEY.value,
                encrypted_credentials=credentials,
            )
            session.add(datasource_provider)
            session.commit()
-
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
    """
    List the names of the secret-input fields in a provider's credential schema.

    :param tenant_id: workspace id
    :param provider_id: provider id string passed to the plugin datasource manager
    :param credential_type: API_KEY reads the provider's credentials schema, OAUTH2 the
        credentials schema of its OAuth schema
    :return: names of all schema fields whose type is ``FormType.SECRET_INPUT``
    :raises ValueError: when OAUTH2 is requested but no OAuth schema is declared, or the
        credential type is neither API_KEY nor OAUTH2
    """
    provider_entity = self.provider_manager.fetch_datasource_provider(
        tenant_id=tenant_id, provider_id=provider_id
    )
    if credential_type == CredentialType.API_KEY:
        form_schemas = list(provider_entity.declaration.credentials_schema)
    elif credential_type == CredentialType.OAUTH2:
        if not provider_entity.declaration.oauth_schema:
            raise ValueError("Datasource provider oauth schema not found")
        form_schemas = list(provider_entity.declaration.oauth_schema.credentials_schema)
    else:
        raise ValueError(f"Invalid credential type: {credential_type}")

    return [
        form_schema.name
        for form_schema in form_schemas
        if form_schema.type.value == FormType.SECRET_INPUT.value
    ]
-
def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
    """
    List a provider's authorizations with secret fields obfuscated.

    The authorization reported as default is the first one ordered by ``is_default``
    descending, then ``created_at`` ascending.

    :param tenant_id: workspace id
    :param provider: provider name
    :param plugin_id: plugin id
    :return: one dict per authorization with keys ``credential``, ``type``, ``name``,
        ``avatar_url``, ``id`` and ``is_default`` (always a bool); empty list when none exist
    """
    # Get all provider configurations of the current workspace
    datasource_providers: list[DatasourceProvider] = (
        db.session.query(DatasourceProvider)
        .where(
            DatasourceProvider.tenant_id == tenant_id,
            DatasourceProvider.provider == provider,
            DatasourceProvider.plugin_id == plugin_id,
        )
        .all()
    )
    if not datasource_providers:
        return []

    default_provider = (
        db.session.query(DatasourceProvider.id)
        .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
        .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
        .first()
    )
    default_provider_id = default_provider.id if default_provider else None

    # The secret-field list depends only on the auth type, so look it up once per type
    # instead of once per authorization.
    secret_variables_by_auth_type: dict[Any, list[str]] = {}
    copy_credentials_list = []
    for datasource_provider in datasource_providers:
        auth_type = datasource_provider.auth_type
        if auth_type not in secret_variables_by_auth_type:
            secret_variables_by_auth_type[auth_type] = self.extract_secret_variables(
                tenant_id=tenant_id,
                provider_id=f"{plugin_id}/{provider}",
                credential_type=CredentialType.of(auth_type),
            )
        credential_secret_variables = secret_variables_by_auth_type[auth_type]

        # Obfuscate provider credentials
        copy_credentials = {
            key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
            for key, value in datasource_provider.encrypted_credentials.items()
        }
        copy_credentials_list.append(
            {
                "credential": copy_credentials,
                "type": auth_type,
                "name": datasource_provider.name,
                "avatar_url": datasource_provider.avatar_url,
                "id": datasource_provider.id,
                # bool() so the field is never None when no default id is available
                "is_default": bool(default_provider_id and datasource_provider.id == default_provider_id),
            }
        )

    return copy_credentials_list
-
def _build_datasource_credential_entry(self, tenant_id: str, datasource) -> dict:
    """Build the credential overview dict of one installed datasource provider."""
    declaration = datasource.declaration
    identity = declaration.identity
    datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
    credentials = self.list_datasource_credentials(
        tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
    )
    redirect_uri = (
        f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
    )

    # OAuth details are only reported for providers that declare an OAuth schema
    oauth_schema = None
    if declaration.oauth_schema:
        oauth_schema = {
            "client_schema": [client_schema.model_dump() for client_schema in declaration.oauth_schema.client_schema],
            "credentials_schema": [
                credential_schema.model_dump() for credential_schema in declaration.oauth_schema.credentials_schema
            ],
            "oauth_custom_client_params": self.get_tenant_oauth_client(
                tenant_id, datasource_provider_id, mask=True
            ),
            "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
                tenant_id, datasource_provider_id
            ),
            "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
            "redirect_uri": redirect_uri,
        }

    return {
        "provider": datasource.provider,
        "plugin_id": datasource.plugin_id,
        "plugin_unique_identifier": datasource.plugin_unique_identifier,
        "icon": identity.icon,
        "name": identity.name.split("/")[-1],
        "label": identity.label.model_dump(),
        "description": identity.description.model_dump(),
        "author": identity.author,
        "credentials_list": credentials,
        "credential_schema": [credential.model_dump() for credential in declaration.credentials_schema],
        "oauth_schema": oauth_schema,
    }

def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
    """
    Return the credential overview of every installed datasource provider.

    :param tenant_id: workspace id
    :return: one dict per installed provider with its identity fields, obfuscated
        authorizations (``credentials_list``), credential schema and, when the provider
        declares one, its OAuth schema details (otherwise ``oauth_schema`` is None)
    """
    # get all plugin providers
    manager = PluginDatasourceManager()
    datasources = manager.fetch_installed_datasource_providers(tenant_id)
    return [self._build_datasource_credential_entry(tenant_id, datasource) for datasource in datasources]

def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
    """
    Return the credential overview of the hard-coded subset of datasource providers.

    Same entries as ``get_all_datasource_credentials``, restricted to installed providers
    whose plugin id is one of the firecrawl, notion or jina datasource plugins.

    :param tenant_id: workspace id
    :return: one dict per matching installed provider
    """
    hard_coded_plugin_ids = (
        "langgenius/firecrawl_datasource",
        "langgenius/notion_datasource",
        "langgenius/jina_datasource",
    )
    # get all plugin providers
    manager = PluginDatasourceManager()
    datasources = manager.fetch_installed_datasource_providers(tenant_id)
    return [
        self._build_datasource_credential_entry(tenant_id, datasource)
        for datasource in datasources
        if datasource.plugin_id in hard_coded_plugin_ids
    ]
-
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
    """
    List a provider's authorizations with secret fields decrypted.

    :param tenant_id: workspace id
    :param provider: provider name
    :param plugin_id: plugin id
    :return: one dict per authorization with keys ``credentials`` and ``type``; empty
        list when none exist
    """
    provider_rows: list[DatasourceProvider] = (
        db.session.query(DatasourceProvider)
        .where(
            DatasourceProvider.tenant_id == tenant_id,
            DatasourceProvider.provider == provider,
            DatasourceProvider.plugin_id == plugin_id,
        )
        .all()
    )
    if not provider_rows:
        return []

    results = []
    for row in provider_rows:
        stored_credentials = row.encrypted_credentials
        secret_keys = self.extract_secret_variables(
            tenant_id=tenant_id,
            provider_id=f"{plugin_id}/{provider}",
            credential_type=CredentialType.of(row.auth_type),
        )
        # decrypt the secret fields, copy everything else as stored
        plain_credentials = {
            key: encrypter.decrypt_token(tenant_id, value) if key in secret_keys else value
            for key, value in stored_credentials.items()
        }
        results.append({"credentials": plain_credentials, "type": row.auth_type})

    return results
-
def update_datasource_credentials(
    self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
) -> None:
    """
    Update the name and/or credentials of one authorization.

    Values equal to ``HIDDEN_VALUE`` in ``credentials`` are replaced by the currently
    stored (decrypted) value for that key, or ``UNKNOWN_VALUE`` when there is none. The
    merged credentials are validated before secret fields are encrypted and stored.

    :param tenant_id: workspace id
    :param auth_id: id of the authorization to update
    :param provider: provider name
    :param plugin_id: plugin id
    :param credentials: new credentials; falsy leaves the stored credentials untouched
    :param name: new name; falsy or unchanged leaves the name untouched
    :raises ValueError: when the authorization is not found, the new name is already used
        within the provider, or the credentials fail validation (the original error is
        chained as the cause)
    """
    with Session(db.engine) as session:
        datasource_provider = (
            session.query(DatasourceProvider)
            .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
            .first()
        )
        if not datasource_provider:
            raise ValueError("Datasource provider not found")

        # update name
        if name and name != datasource_provider.name:
            if (
                session.query(DatasourceProvider)
                .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
                .count()
                > 0
            ):
                raise ValueError("Authorization name is already exists")
            datasource_provider.name = name

        # update credentials
        if credentials:
            secret_variables = self.extract_secret_variables(
                tenant_id=tenant_id,
                provider_id=f"{plugin_id}/{provider}",
                credential_type=CredentialType.of(datasource_provider.auth_type),
            )
            original_credentials = {
                key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
                for key, value in datasource_provider.encrypted_credentials.items()
            }
            # a hidden placeholder means "keep what is stored"
            new_credentials = {
                key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                for key, value in credentials.items()
            }
            try:
                self.provider_manager.validate_provider_credentials(
                    tenant_id=tenant_id,
                    user_id=current_user.id,
                    provider=provider,
                    plugin_id=plugin_id,
                    credentials=new_credentials,
                )
            except Exception as e:
                # callers handle ValueError; chain the original error so it is not lost
                raise ValueError(f"Failed to validate credentials: {str(e)}") from e

            datasource_provider.encrypted_credentials = {
                key: encrypter.encrypt_token(tenant_id, value) if key in secret_variables else value
                for key, value in new_credentials.items()
            }
        session.commit()
-
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
    """
    Delete one authorization of a datasource provider; does nothing when it is not found.

    :param tenant_id: workspace id
    :param auth_id: id of the authorization to delete
    :param provider: provider name
    :param plugin_id: plugin id
    """
    lookup = db.session.query(DatasourceProvider).filter_by(
        tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id
    )
    record = lookup.first()
    if not record:
        return
    db.session.delete(record)
    db.session.commit()
|