You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_provider_service.py 7.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import logging
  2. from flask_login import current_user
  3. from constants import HIDDEN_VALUE
  4. from core.helper import encrypter
  5. from core.model_runtime.entities.provider_entities import FormType
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.plugin.impl.datasource import PluginDatasourceManager
  8. from extensions.ext_database import db
  9. from models.oauth import DatasourceProvider
  10. logger = logging.getLogger(__name__)
  11. class DatasourceProviderService:
  12. """
  13. Model Provider Service
  14. """
  15. def __init__(self) -> None:
  16. self.provider_manager = PluginDatasourceManager()
  17. def datasource_provider_credentials_validate(
  18. self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
  19. ) -> None:
  20. """
  21. validate datasource provider credentials.
  22. :param tenant_id:
  23. :param provider:
  24. :param credentials:
  25. """
  26. credential_valid = self.provider_manager.validate_provider_credentials(
  27. tenant_id=tenant_id, user_id=current_user.id, provider=provider, plugin_id=plugin_id, credentials=credentials
  28. )
  29. if credential_valid:
  30. # Get all provider configurations of the current workspace
  31. datasource_provider = (
  32. db.session.query(DatasourceProvider)
  33. .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key")
  34. .first()
  35. )
  36. provider_credential_secret_variables = self.extract_secret_variables(
  37. tenant_id=tenant_id,
  38. provider_id=f"{plugin_id}/{provider}"
  39. )
  40. for key, value in credentials.items():
  41. if key in provider_credential_secret_variables:
  42. # if send [__HIDDEN__] in secret input, it will be same as original value
  43. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  44. datasource_provider = DatasourceProvider(
  45. tenant_id=tenant_id,
  46. provider=provider,
  47. plugin_id=plugin_id,
  48. auth_type="api_key",
  49. encrypted_credentials=credentials,
  50. )
  51. db.session.add(datasource_provider)
  52. db.session.commit()
  53. else:
  54. raise CredentialsValidateFailedError()
  55. def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
  56. """
  57. Extract secret input form variables.
  58. :param credential_form_schemas:
  59. :return:
  60. """
  61. datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id,
  62. provider_id=provider_id
  63. )
  64. credential_form_schemas = datasource_provider.declaration.credentials_schema
  65. secret_input_form_variables = []
  66. for credential_form_schema in credential_form_schemas:
  67. if credential_form_schema.type == FormType.SECRET_INPUT:
  68. secret_input_form_variables.append(credential_form_schema.name)
  69. return secret_input_form_variables
  70. def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
  71. """
  72. get datasource credentials.
  73. :param tenant_id: workspace id
  74. :param provider_id: provider id
  75. :return:
  76. """
  77. # Get all provider configurations of the current workspace
  78. datasource_providers: list[DatasourceProvider] = (
  79. db.session.query(DatasourceProvider)
  80. .filter(
  81. DatasourceProvider.tenant_id == tenant_id,
  82. DatasourceProvider.provider == provider,
  83. DatasourceProvider.plugin_id == plugin_id,
  84. )
  85. .all()
  86. )
  87. if not datasource_providers:
  88. return []
  89. copy_credentials_list = []
  90. for datasource_provider in datasource_providers:
  91. encrypted_credentials = datasource_provider.encrypted_credentials
  92. # Get provider credential secret variables
  93. credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
  94. # Obfuscate provider credentials
  95. copy_credentials = encrypted_credentials.copy()
  96. for key, value in copy_credentials.items():
  97. if key in credential_secret_variables:
  98. copy_credentials[key] = encrypter.obfuscated_token(value)
  99. copy_credentials_list.append(
  100. {
  101. "credentials": copy_credentials,
  102. "type": datasource_provider.auth_type,
  103. }
  104. )
  105. return copy_credentials_list
  106. def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
  107. """
  108. update datasource credentials.
  109. """
  110. credential_valid = self.provider_manager.validate_provider_credentials(
  111. tenant_id=tenant_id, user_id=current_user.id, provider=provider,plugin_id=plugin_id, credentials=credentials
  112. )
  113. if credential_valid:
  114. # Get all provider configurations of the current workspace
  115. datasource_provider = (
  116. db.session.query(DatasourceProvider)
  117. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  118. .first()
  119. )
  120. provider_credential_secret_variables = self.extract_secret_variables(
  121. tenant_id=tenant_id,
  122. provider_id=f"{plugin_id}/{provider}"
  123. )
  124. if not datasource_provider:
  125. raise ValueError("Datasource provider not found")
  126. else:
  127. original_credentials = datasource_provider.encrypted_credentials
  128. for key, value in credentials.items():
  129. if key in provider_credential_secret_variables:
  130. # if send [__HIDDEN__] in secret input, it will be same as original value
  131. if value == HIDDEN_VALUE and key in original_credentials:
  132. original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
  133. credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
  134. else:
  135. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  136. datasource_provider.encrypted_credentials = credentials
  137. db.session.commit()
  138. else:
  139. raise CredentialsValidateFailedError()
  140. def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
  141. """
  142. remove datasource credentials.
  143. :param tenant_id: workspace id
  144. :param provider: provider name
  145. :param plugin_id: plugin id
  146. :return:
  147. """
  148. datasource_provider = (
  149. db.session.query(DatasourceProvider)
  150. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  151. .first()
  152. )
  153. if datasource_provider:
  154. db.session.delete(datasource_provider)
  155. db.session.commit()