You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_provider_service.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import logging
  2. from flask_login import current_user
  3. from constants import HIDDEN_VALUE
  4. from core.helper import encrypter
  5. from core.model_runtime.entities.provider_entities import FormType
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.plugin.impl.datasource import PluginDatasourceManager
  8. from extensions.ext_database import db
  9. from models.oauth import DatasourceProvider
  10. logger = logging.getLogger(__name__)
  class DatasourceProviderService:
      """
      Datasource Provider Service

      Manages the credentials a workspace (tenant) stores for plugin-provided
      datasource providers: validation, creation, listing, update and removal.
      Secret fields (as declared by the provider's credentials schema) are
      encrypted before they are persisted.
      """

      def __init__(self) -> None:
          self.provider_manager = PluginDatasourceManager()

      @staticmethod
      def _map_secret_fields(credentials: dict, secret_keys, transform) -> dict:
          """
          Build a new credentials dict with ``transform`` applied to the secret fields.

          :param credentials: credential mapping to read; it is never modified
          :param secret_keys: iterable of field names that hold secret values
          :param transform: callable ``(key, value) -> new_value`` applied to secret fields only
          :return: a new dict with the same keys; non-secret values are passed through unchanged
          """
          secret_names = set(secret_keys)
          return {
              key: transform(key, value) if key in secret_names else value
              for key, value in credentials.items()
          }

      def datasource_provider_credentials_validate(
          self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str
      ) -> None:
          """
          Validate datasource provider credentials and store them as a new ``api_key`` authorization.

          :param tenant_id: workspace id
          :param provider: provider name
          :param plugin_id: plugin id
          :param credentials: submitted credential values keyed by form field name
          :param name: authorization name; must not already be used in the workspace
          :raises ValueError: if an authorization with the same name already exists
          :raises CredentialsValidateFailedError: if credential validation does not succeed
          """
          # The uniqueness check is scoped to the tenant and name only (not to the provider).
          name_taken = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first()
          if name_taken:
              raise ValueError("Authorization name is already exists")

          credential_valid = self.provider_manager.validate_provider_credentials(
              tenant_id=tenant_id,
              user_id=current_user.id,
              provider=provider,
              plugin_id=plugin_id,
              credentials=credentials,
          )
          if not credential_valid:
              raise CredentialsValidateFailedError()

          secret_variables = self.extract_secret_variables(
              tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
          )
          # Encrypt into a fresh dict so the caller's ``credentials`` is left untouched.
          encrypted_credentials = self._map_secret_fields(
              credentials,
              secret_variables,
              lambda _key, value: encrypter.encrypt_token(tenant_id, value),
          )
          datasource_provider = DatasourceProvider(
              tenant_id=tenant_id,
              name=name,
              provider=provider,
              plugin_id=plugin_id,
              auth_type="api_key",
              encrypted_credentials=encrypted_credentials,
          )
          db.session.add(datasource_provider)
          db.session.commit()

      def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
          """
          Extract the names of the secret-input fields of a provider's credentials schema.

          :param tenant_id: workspace id
          :param provider_id: provider id in the form ``{plugin_id}/{provider}``
          :return: names of all credential form fields whose type is ``FormType.SECRET_INPUT``
          """
          datasource_provider = self.provider_manager.fetch_datasource_provider(
              tenant_id=tenant_id, provider_id=provider_id
          )
          credential_form_schemas = datasource_provider.declaration.credentials_schema
          return [
              credential_form_schema.name
              for credential_form_schema in credential_form_schemas
              if credential_form_schema.type.value == FormType.SECRET_INPUT.value
          ]

      def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
          """
          Get the stored credentials of a datasource provider with secret fields obfuscated.

          :param tenant_id: workspace id
          :param provider: provider name
          :param plugin_id: plugin id
          :return: one dict per stored authorization with keys ``credentials``, ``type`` and ``name``;
              an empty list if nothing is stored
          """
          # Get all provider configurations of the current workspace
          datasource_providers: list[DatasourceProvider] = (
              db.session.query(DatasourceProvider)
              .filter(
                  DatasourceProvider.tenant_id == tenant_id,
                  DatasourceProvider.provider == provider,
                  DatasourceProvider.plugin_id == plugin_id,
              )
              .all()
          )
          if not datasource_providers:
              return []

          # The secret field names depend only on (tenant, plugin, provider), so fetch them once.
          credential_secret_variables = self.extract_secret_variables(
              tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
          )
          # NOTE(review): the value handed to obfuscated_token is the stored (encrypted) value,
          # not the plaintext - confirm that this is the intended display value.
          return [
              {
                  "credentials": self._map_secret_fields(
                      datasource_provider.encrypted_credentials,
                      credential_secret_variables,
                      lambda _key, value: encrypter.obfuscated_token(value),
                  ),
                  "type": datasource_provider.auth_type,
                  "name": datasource_provider.name,
              }
              for datasource_provider in datasource_providers
          ]

      def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
          """
          Get every installed datasource provider of a workspace together with its obfuscated credentials.

          :param tenant_id: workspace id
          :return: one dict per installed datasource provider carrying its identity fields and
              the result of :meth:`get_datasource_credentials` under ``credentials``
          """
          # get all plugin providers
          datasources = self.provider_manager.fetch_installed_datasource_providers(tenant_id)
          datasource_credentials = []
          for datasource in datasources:
              identity = datasource.declaration.identity
              credentials = self.get_datasource_credentials(
                  tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
              )
              datasource_credentials.append(
                  {
                      "provider": datasource.provider,
                      "plugin_id": datasource.plugin_id,
                      "plugin_unique_identifier": datasource.plugin_unique_identifier,
                      "icon": identity.icon,
                      "name": identity.name,
                      "description": identity.description.model_dump(),
                      "author": identity.author,
                      "credentials": credentials,
                  }
              )
          return datasource_credentials

      def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
          """
          Get the stored credentials of a datasource provider with secret fields decrypted.

          :param tenant_id: workspace id
          :param provider: provider name
          :param plugin_id: plugin id
          :return: one dict per stored authorization with keys ``credentials`` and ``type``
              (no ``name`` key); an empty list if nothing is stored
          """
          # Get all provider configurations of the current workspace
          datasource_providers: list[DatasourceProvider] = (
              db.session.query(DatasourceProvider)
              .filter(
                  DatasourceProvider.tenant_id == tenant_id,
                  DatasourceProvider.provider == provider,
                  DatasourceProvider.plugin_id == plugin_id,
              )
              .all()
          )
          if not datasource_providers:
              return []

          # The secret field names depend only on (tenant, plugin, provider), so fetch them once.
          credential_secret_variables = self.extract_secret_variables(
              tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
          )
          # Decrypt provider credentials
          return [
              {
                  "credentials": self._map_secret_fields(
                      datasource_provider.encrypted_credentials,
                      credential_secret_variables,
                      lambda _key, value: encrypter.decrypt_token(tenant_id, value),
                  ),
                  "type": datasource_provider.auth_type,
              }
              for datasource_provider in datasource_providers
          ]

      def update_datasource_credentials(
          self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
      ) -> None:
          """
          Validate new credentials and overwrite those of an existing authorization.

          :param tenant_id: workspace id
          :param auth_id: id of the stored authorization to update
          :param provider: provider name
          :param plugin_id: plugin id
          :param credentials: submitted credential values; a secret field equal to ``HIDDEN_VALUE``
              keeps the value that is currently stored for that field
          :raises CredentialsValidateFailedError: if credential validation does not succeed
          :raises ValueError: if the authorization does not exist
          """
          # NOTE(review): validation runs on the submitted values, so a HIDDEN_VALUE placeholder
          # is passed to the validator as-is - confirm the plugin side tolerates that.
          credential_valid = self.provider_manager.validate_provider_credentials(
              tenant_id=tenant_id,
              user_id=current_user.id,
              provider=provider,
              plugin_id=plugin_id,
              credentials=credentials,
          )
          if not credential_valid:
              raise CredentialsValidateFailedError()

          datasource_provider = (
              db.session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
              .first()
          )
          if not datasource_provider:
              raise ValueError("Datasource provider not found")

          provider_credential_secret_variables = self.extract_secret_variables(
              tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
          )
          original_credentials = datasource_provider.encrypted_credentials

          def seal(key, value):
              # if send [__HIDDEN__] in secret input, it will be same as original value.
              # The stored value is already encrypted, so it is reused unchanged; encrypting it
              # again would make it unrecoverable by the single decrypt_token on the read path.
              if value == HIDDEN_VALUE and key in original_credentials:
                  return original_credentials[key]
              return encrypter.encrypt_token(tenant_id, value)

          # Build a fresh dict so the caller's ``credentials`` is left untouched.
          datasource_provider.encrypted_credentials = self._map_secret_fields(
              credentials, provider_credential_secret_variables, seal
          )
          db.session.commit()

      def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
          """
          Remove a stored authorization; does nothing if no matching record exists.

          :param tenant_id: workspace id
          :param auth_id: id of the stored authorization to remove
          :param provider: provider name
          :param plugin_id: plugin id
          :return:
          """
          datasource_provider = (
              db.session.query(DatasourceProvider)
              .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
              .first()
          )
          if datasource_provider:
              db.session.delete(datasource_provider)
              db.session.commit()