Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

datasource_auth.py 7.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  from urllib.parse import quote, urlencode

  from flask import redirect, request
  from flask_login import current_user  # type: ignore
  from flask_restful import (  # type: ignore
      Resource,  # type: ignore
      reqparse,
  )
  from werkzeug.exceptions import Forbidden, NotFound

  from configs import dify_config
  from controllers.console import api
  from controllers.console.wraps import (
      account_initialization_required,
      setup_required,
  )
  from core.model_runtime.errors.validate import CredentialsValidateFailedError
  from core.plugin.impl.oauth import OAuthHandler
  from extensions.ext_database import db
  from libs.login import login_required
  from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
  from services.datasource_provider_service import DatasourceProviderService
  class DatasourcePluginOauthApi(Resource):
      """Start the OAuth authorization flow for a datasource plugin."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return the OAuth authorization URL for a datasource provider.

          Query args:
              provider: datasource provider name.
              plugin_id: id of the plugin that implements the provider.

          Returns:
              ``model_dump()`` of the response from
              ``OAuthHandler.get_authorization_url``.

          Raises:
              Forbidden: the current user is not an editor.
              NotFound: no ``DatasourceOauthParamConfig`` exists for the
                  (provider, plugin_id) pair.
          """
          # Check the role before touching request arguments so that
          # non-editors consistently receive 403.
          if not current_user.is_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
          parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
          args = parser.parse_args()
          provider = args["provider"]
          plugin_id = args["plugin_id"]

          # Fetch the (single) OAuth parameter config for this provider/plugin pair.
          plugin_oauth_config = (
              db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
          )
          if not plugin_oauth_config:
              raise NotFound()

          # Percent-encode the user-supplied values so they cannot inject extra
          # query parameters; "/" stays literal so ids like "org/plugin" yield
          # the same URL as plain interpolation would.
          query = urlencode({"provider": provider, "plugin_id": plugin_id}, safe="/", quote_via=quote)
          redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?{query}"

          system_credentials = plugin_oauth_config.system_credentials
          if system_credentials:
              # Build a copy rather than mutating the ORM-loaded value in place,
              # which could otherwise leak the per-request redirect URL into the
              # stored config.
              system_credentials = {**system_credentials, "redirect_url": redirect_url}

          oauth_handler = OAuthHandler()
          response = oauth_handler.get_authorization_url(
              current_user.current_tenant.id,
              current_user.id,
              plugin_id,
              provider,
              system_credentials=system_credentials,
          )
          return response.model_dump()
  class DatasourceOauthCallback(Resource):
      """Finish the datasource OAuth flow and persist the obtained credentials."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Exchange the OAuth callback request for credentials and store them.

          Query args:
              provider: datasource provider name.
              plugin_id: id of the plugin that implements the provider.

          Returns:
              A redirect to ``dify_config.CONSOLE_WEB_URL``.

          Raises:
              Forbidden: the current user is not an editor.
              NotFound: no ``DatasourceOauthParamConfig`` exists for the
                  (provider, plugin_id) pair.
          """
          # Persisting credentials is a write; require the same role as the
          # other mutating endpoints in this module.
          if not current_user.is_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
          parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
          args = parser.parse_args()
          provider = args["provider"]
          plugin_id = args["plugin_id"]

          plugin_oauth_config = (
              db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
          )
          if not plugin_oauth_config:
              raise NotFound()

          # Rebuild the exact redirect URL used when the authorization URL was
          # requested: OAuth 2.0 token exchange requires the same redirect URI
          # (RFC 6749 section 4.1.3). "/" stays literal to match ids like "org/plugin".
          query = urlencode({"provider": provider, "plugin_id": plugin_id}, safe="/", quote_via=quote)
          redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?{query}"
          system_credentials = plugin_oauth_config.system_credentials
          if system_credentials:
              # Copy so the ORM-loaded config value is not mutated.
              system_credentials = {**system_credentials, "redirect_url": redirect_url}

          oauth_handler = OAuthHandler()
          credentials = oauth_handler.get_credentials(
              current_user.current_tenant.id,
              current_user.id,
              plugin_id,
              provider,
              system_credentials=system_credentials,
              request=request,
          )
          # NOTE(review): the handler's result is stored as-is in
          # `encrypted_credentials`; no encryption is visible here — confirm
          # whether it should be encrypted (and whether the handler returns a
          # plain mapping or a response model) before storing.
          datasource_provider = DatasourceProvider(
              # Scope the record to the caller's tenant like the other endpoints.
              tenant_id=current_user.current_tenant_id,
              plugin_id=plugin_id,
              provider=provider,
              auth_type="oauth",
              encrypted_credentials=credentials,
          )
          db.session.add(datasource_provider)
          db.session.commit()
          return redirect(f"{dify_config.CONSOLE_WEB_URL}")
  class DatasourceAuth(Resource):
      """Submit and list credentials for datasource plugins in the current tenant."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Validate credentials for a datasource provider via the service.

          JSON body:
              provider: datasource provider name.
              plugin_id: id of the plugin that implements the provider.
              credentials: credential object to validate.

          Returns:
              ``({"result": "success"}, 201)``. The 201 status suggests the
              service also stores the credentials — TODO confirm.

          Raises:
              Forbidden: the current user is not an editor.
              ValueError: the service raised ``CredentialsValidateFailedError``.
          """
          if not current_user.is_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
          parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          args = parser.parse_args()

          datasource_provider_service = DatasourceProviderService()
          try:
              datasource_provider_service.datasource_provider_credentials_validate(
                  tenant_id=current_user.current_tenant_id,
                  provider=args["provider"],
                  plugin_id=args["plugin_id"],
                  credentials=args["credentials"],
              )
          except CredentialsValidateFailedError as ex:
              # Chain the original error so its traceback is not lost.
              raise ValueError(str(ex)) from ex
          return {"result": "success"}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """List the service's credentials for a provider in the current tenant.

          Query args:
              provider: datasource provider name.
              plugin_id: id of the plugin that implements the provider.

          Returns:
              ``({"result": <service result>}, 200)``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
          parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
          args = parser.parse_args()

          datasource_provider_service = DatasourceProviderService()
          datasources = datasource_provider_service.get_datasource_credentials(
              tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
          )
          return {"result": datasources}, 200
  class DatasourceAuthUpdateDeleteApi(Resource):
      """Delete or update a single datasource credential identified by ``auth_id``."""

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, auth_id: str):
          """Remove credential ``auth_id`` for a provider in the current tenant.

          Query args:
              provider: datasource provider name.
              plugin_id: id of the plugin that implements the provider.

          Returns:
              ``({"result": "success"}, 200)``.

          Raises:
              Forbidden: the current user is not an editor.
          """
          # Check the role before parsing, matching DatasourceAuth.post.
          if not current_user.is_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
          parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
          args = parser.parse_args()

          datasource_provider_service = DatasourceProviderService()
          datasource_provider_service.remove_datasource_credentials(
              tenant_id=current_user.current_tenant_id,
              auth_id=auth_id,
              provider=args["provider"],
              plugin_id=args["plugin_id"],
          )
          return {"result": "success"}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, auth_id: str):
          """Replace the credentials stored under ``auth_id``.

          Query args:
              provider: datasource provider name.
              plugin_id: id of the plugin that implements the provider.

          JSON body:
              credentials: the new credential object.

          Returns:
              ``({"result": "success"}, 201)``.

          Raises:
              Forbidden: the current user is not an editor.
              ValueError: the service raised ``CredentialsValidateFailedError``.
          """
          # Check the role before parsing, matching DatasourceAuth.post.
          if not current_user.is_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
          parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          args = parser.parse_args()

          datasource_provider_service = DatasourceProviderService()
          # Only the service call can raise CredentialsValidateFailedError.
          try:
              datasource_provider_service.update_datasource_credentials(
                  tenant_id=current_user.current_tenant_id,
                  auth_id=auth_id,
                  provider=args["provider"],
                  plugin_id=args["plugin_id"],
                  credentials=args["credentials"],
              )
          except CredentialsValidateFailedError as ex:
              # Chain the original error so its traceback is not lost.
              raise ValueError(str(ex)) from ex
          return {"result": "success"}, 201
  # Register the datasource plugin OAuth and credential-management endpoints.
  api.add_resource(
      DatasourcePluginOauthApi,
      "/oauth/plugin/datasource",
  )
  api.add_resource(
      DatasourceOauthCallback,
      "/oauth/plugin/datasource/callback",
  )
  api.add_resource(
      DatasourceAuth,
      "/auth/plugin/datasource",
  )
  api.add_resource(
      DatasourceAuthUpdateDeleteApi,
      "/auth/plugin/datasource/<string:auth_id>",
  )