You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_provider_service.py 8.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import logging
  2. from flask_login import current_user
  3. from constants import HIDDEN_VALUE
  4. from core.helper import encrypter
  5. from core.model_runtime.entities.provider_entities import FormType
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.plugin.impl.datasource import PluginDatasourceManager
  8. from extensions.ext_database import db
  9. from models.oauth import DatasourceProvider
  10. logger = logging.getLogger(__name__)
  class DatasourceProviderService:
      """
      Manage datasource provider credentials for a workspace (tenant).

      Credentials are validated through the datasource plugin manager and stored on
      ``DatasourceProvider`` rows. Fields that the plugin's credentials schema declares
      as secret inputs are stored encrypted via ``encrypter.encrypt_token(tenant_id, ...)``.
      """

      def __init__(self) -> None:
          # Plugin client used to validate credentials and to fetch provider declarations.
          self.provider_manager = PluginDatasourceManager()

      def datasource_provider_credentials_validate(
          self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
      ) -> None:
          """
          Validate datasource provider credentials and store them as a new api_key provider.

          Secret-input fields are encrypted before being persisted; the caller's
          ``credentials`` dict is left unmodified.

          :param tenant_id: workspace id
          :param provider: provider name
          :param plugin_id: plugin id
          :param credentials: plaintext credentials submitted by the user
          :raises CredentialsValidateFailedError: if the plugin rejects the credentials
          """
          credential_valid = self.provider_manager.validate_provider_credentials(
              tenant_id=tenant_id,
              user_id=current_user.id,
              provider=provider,
              plugin_id=plugin_id,
              credentials=credentials,
          )
          if not credential_valid:
              raise CredentialsValidateFailedError()

          secret_variables = self._get_secret_variables(tenant_id, provider, plugin_id)
          # NOTE(review): existing api_key rows for this provider are not looked up, so every
          # successful call inserts a new DatasourceProvider row — confirm this is intended.
          datasource_provider = DatasourceProvider(
              tenant_id=tenant_id,
              provider=provider,
              plugin_id=plugin_id,
              auth_type="api_key",
              encrypted_credentials=self._transform_secret_fields(
                  credentials, secret_variables, lambda value: encrypter.encrypt_token(tenant_id, value)
              ),
          )
          db.session.add(datasource_provider)
          db.session.commit()

      def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
          """
          Return the names of the secret-input fields in a provider's credentials schema.

          :param tenant_id: workspace id
          :param provider_id: provider id; callers in this class pass ``f"{plugin_id}/{provider}"``
          :return: names of credential fields whose form type is ``FormType.SECRET_INPUT``
          """
          declaration = self.provider_manager.fetch_datasource_provider(
              tenant_id=tenant_id, provider_id=provider_id
          ).declaration
          return [
              form_schema.name
              for form_schema in declaration.credentials_schema
              if form_schema.type == FormType.SECRET_INPUT
          ]

      def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
          """
          List every stored credential set of a provider with secret fields obfuscated.

          :param tenant_id: workspace id
          :param provider: provider name
          :param plugin_id: plugin id
          :return: one ``{"credentials": ..., "type": auth_type}`` entry per stored row;
              secret-input values are passed through ``encrypter.obfuscated_token``.
              Empty list when no rows exist.
          """
          # NOTE(review): secret fields are stored encrypted (see the encrypt_token calls on
          # save), so this obfuscates the ciphertext rather than the plaintext secret —
          # confirm this is the intended display value.
          return self._list_credentials(tenant_id, provider, plugin_id, encrypter.obfuscated_token)

      def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
          """
          List every stored credential set of a provider with secret fields decrypted.

          :param tenant_id: workspace id
          :param provider: provider name
          :param plugin_id: plugin id
          :return: one ``{"credentials": ..., "type": auth_type}`` entry per stored row;
              secret-input values are decrypted with ``encrypter.decrypt_token``.
              Empty list when no rows exist.
          """
          return self._list_credentials(
              tenant_id, provider, plugin_id, lambda value: encrypter.decrypt_token(tenant_id, value)
          )

      def update_datasource_credentials(
          self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
      ) -> None:
          """
          Validate and save new credentials on an existing datasource provider row.

          A secret field submitted as ``HIDDEN_VALUE`` means "keep the stored secret": the
          stored value is decrypted, used for validation, and re-encrypted on save.

          :param tenant_id: workspace id
          :param auth_id: id of the DatasourceProvider row to update
          :param provider: provider name
          :param plugin_id: plugin id
          :param credentials: submitted credentials; secret fields may be ``HIDDEN_VALUE``.
              The dict is not modified.
          :raises ValueError: if no matching DatasourceProvider row exists
          :raises CredentialsValidateFailedError: if the plugin rejects the credentials
          """
          datasource_provider = self._get_provider_by_auth_id(tenant_id, auth_id, provider, plugin_id)
          if not datasource_provider:
              raise ValueError("Datasource provider not found")

          secret_variables = self._get_secret_variables(tenant_id, provider, plugin_id)
          original_credentials = datasource_provider.encrypted_credentials or {}

          # Replace HIDDEN_VALUE placeholders with the stored plaintext so the plugin validates,
          # and we persist, the secret that will actually be used. The stored value is already
          # ciphertext, so it must be decrypted rather than encrypted again.
          real_credentials = dict(credentials)
          for key, value in credentials.items():
              if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                  real_credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])

          credential_valid = self.provider_manager.validate_provider_credentials(
              tenant_id=tenant_id,
              user_id=current_user.id,
              provider=provider,
              plugin_id=plugin_id,
              credentials=real_credentials,
          )
          if not credential_valid:
              raise CredentialsValidateFailedError()

          # Assign a new dict so the ORM sees the column change.
          datasource_provider.encrypted_credentials = self._transform_secret_fields(
              real_credentials, secret_variables, lambda value: encrypter.encrypt_token(tenant_id, value)
          )
          db.session.commit()

      def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
          """
          Delete a stored credential row; does nothing if no matching row exists.

          :param tenant_id: workspace id
          :param auth_id: id of the DatasourceProvider row to delete
          :param provider: provider name
          :param plugin_id: plugin id
          """
          datasource_provider = self._get_provider_by_auth_id(tenant_id, auth_id, provider, plugin_id)
          if datasource_provider is None:
              return
          db.session.delete(datasource_provider)
          db.session.commit()

      @staticmethod
      def _transform_secret_fields(credentials: dict, secret_variables: list[str], transform) -> dict:
          """
          Return a copy of ``credentials`` with ``transform`` applied to each secret field.

          :param credentials: credential mapping; it is not modified
          :param secret_variables: names of the secret-input fields
          :param transform: one-argument callable applied to each secret field's value
          :return: new dict with the same keys; non-secret values are passed through unchanged
          """
          return {
              key: transform(value) if key in secret_variables else value
              for key, value in credentials.items()
          }

      def _get_secret_variables(self, tenant_id: str, provider: str, plugin_id: str) -> list[str]:
          """Return the secret-input field names of ``plugin_id``/``provider``."""
          return self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")

      def _list_credentials(self, tenant_id: str, provider: str, plugin_id: str, transform) -> list[dict]:
          """
          Load all stored rows of a provider and map their secret fields through ``transform``.

          The secret-field schema is the same for every row, so it is fetched once, and only
          when at least one row exists.

          :return: one ``{"credentials": ..., "type": auth_type}`` entry per stored row
          """
          datasource_providers = (
              db.session.query(DatasourceProvider)
              .filter(
                  DatasourceProvider.tenant_id == tenant_id,
                  DatasourceProvider.provider == provider,
                  DatasourceProvider.plugin_id == plugin_id,
              )
              .all()
          )
          if not datasource_providers:
              return []

          secret_variables = self._get_secret_variables(tenant_id, provider, plugin_id)
          return [
              {
                  "credentials": self._transform_secret_fields(
                      datasource_provider.encrypted_credentials, secret_variables, transform
                  ),
                  "type": datasource_provider.auth_type,
              }
              for datasource_provider in datasource_providers
          ]

      def _get_provider_by_auth_id(
          self, tenant_id: str, auth_id: str, provider: str, plugin_id: str
      ) -> "DatasourceProvider | None":
          """Return the tenant's DatasourceProvider row with id ``auth_id`` for this provider, or None."""
          return (
              db.session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
              .first()
          )