選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

datasource_provider_service.py 7.6KB

5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
5ヶ月前
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import logging
  2. from flask_login import current_user
  3. from constants import HIDDEN_VALUE
  4. from core.helper import encrypter
  5. from core.model_runtime.entities.provider_entities import FormType
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.plugin.impl.datasource import PluginDatasourceManager
  8. from extensions.ext_database import db
  9. from models.oauth import DatasourceProvider
  10. logger = logging.getLogger(__name__)
  11. class DatasourceProviderService:
  12. """
  13. Model Provider Service
  14. """
  15. def __init__(self) -> None:
  16. self.provider_manager = PluginDatasourceManager()
  17. def datasource_provider_credentials_validate(
  18. self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
  19. ) -> None:
  20. """
  21. validate datasource provider credentials.
  22. :param tenant_id:
  23. :param provider:
  24. :param credentials:
  25. """
  26. credential_valid = self.provider_manager.validate_provider_credentials(
  27. tenant_id=tenant_id, user_id=current_user.id, provider=provider, plugin_id=plugin_id, credentials=credentials
  28. )
  29. if credential_valid:
  30. # Get all provider configurations of the current workspace
  31. datasource_provider = (
  32. db.session.query(DatasourceProvider)
  33. .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key")
  34. .first()
  35. )
  36. provider_credential_secret_variables = self.extract_secret_variables(
  37. tenant_id=tenant_id,
  38. provider_id=f"{plugin_id}/{provider}"
  39. )
  40. for key, value in credentials.items():
  41. if key in provider_credential_secret_variables:
  42. # if send [__HIDDEN__] in secret input, it will be same as original value
  43. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  44. datasource_provider = DatasourceProvider(
  45. tenant_id=tenant_id,
  46. provider=provider,
  47. plugin_id=plugin_id,
  48. auth_type="api_key",
  49. encrypted_credentials=credentials,
  50. )
  51. db.session.add(datasource_provider)
  52. db.session.commit()
  53. else:
  54. raise CredentialsValidateFailedError()
  55. def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
  56. """
  57. Extract secret input form variables.
  58. :param credential_form_schemas:
  59. :return:
  60. """
  61. datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id,
  62. provider_id=provider_id
  63. )
  64. credential_form_schemas = datasource_provider.declaration.credentials_schema
  65. secret_input_form_variables = []
  66. for credential_form_schema in credential_form_schemas:
  67. if credential_form_schema.type == FormType.SECRET_INPUT:
  68. secret_input_form_variables.append(credential_form_schema.name)
  69. return secret_input_form_variables
  70. def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
  71. """
  72. get datasource credentials.
  73. :param tenant_id: workspace id
  74. :param provider_id: provider id
  75. :return:
  76. """
  77. # Get all provider configurations of the current workspace
  78. datasource_providers: list[DatasourceProvider] = (
  79. db.session.query(DatasourceProvider)
  80. .filter(
  81. DatasourceProvider.tenant_id == tenant_id,
  82. DatasourceProvider.provider == provider,
  83. DatasourceProvider.plugin_id == plugin_id,
  84. )
  85. .all()
  86. )
  87. if not datasource_providers:
  88. return []
  89. copy_credentials_list = []
  90. for datasource_provider in datasource_providers:
  91. encrypted_credentials = datasource_provider.encrypted_credentials
  92. # Get provider credential secret variables
  93. credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
  94. # Obfuscate provider credentials
  95. copy_credentials = encrypted_credentials.copy()
  96. for key, value in copy_credentials.items():
  97. if key in credential_secret_variables:
  98. copy_credentials[key] = encrypter.obfuscated_token(value)
  99. copy_credentials_list.append(
  100. {
  101. "credentials": copy_credentials,
  102. "type": datasource_provider.auth_type,
  103. }
  104. )
  105. return copy_credentials_list
  106. def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
  107. """
  108. update datasource credentials.
  109. """
  110. credential_valid = self.provider_manager.validate_provider_credentials(
  111. tenant_id=tenant_id, user_id=current_user.id, provider=provider,plugin_id=plugin_id, credentials=credentials
  112. )
  113. if credential_valid:
  114. # Get all provider configurations of the current workspace
  115. datasource_provider = (
  116. db.session.query(DatasourceProvider)
  117. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  118. .first()
  119. )
  120. provider_credential_secret_variables = self.extract_secret_variables(
  121. tenant_id=tenant_id,
  122. provider_id=f"{plugin_id}/{provider}"
  123. )
  124. if not datasource_provider:
  125. raise ValueError("Datasource provider not found")
  126. else:
  127. original_credentials = datasource_provider.encrypted_credentials
  128. for key, value in credentials.items():
  129. if key in provider_credential_secret_variables:
  130. # if send [__HIDDEN__] in secret input, it will be same as original value
  131. if value == HIDDEN_VALUE and key in original_credentials:
  132. original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
  133. credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
  134. else:
  135. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  136. datasource_provider.encrypted_credentials = credentials
  137. db.session.commit()
  138. else:
  139. raise CredentialsValidateFailedError()
  140. def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
  141. """
  142. remove datasource credentials.
  143. :param tenant_id: workspace id
  144. :param provider: provider name
  145. :param plugin_id: plugin id
  146. :return:
  147. """
  148. datasource_provider = (
  149. db.session.query(DatasourceProvider)
  150. .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
  151. .first()
  152. )
  153. if datasource_provider:
  154. db.session.delete(datasource_provider)
  155. db.session.commit()