You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

datasource_auth.py 8.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import random
  2. from flask import redirect, request
  3. from flask_login import current_user # type: ignore
  4. from flask_restful import ( # type: ignore
  5. Resource, # type: ignore
  6. reqparse,
  7. )
  8. from werkzeug.exceptions import Forbidden, NotFound
  9. from configs import dify_config
  10. from controllers.console import api
  11. from controllers.console.wraps import (
  12. account_initialization_required,
  13. setup_required,
  14. )
  15. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  16. from core.plugin.impl.oauth import OAuthHandler
  17. from extensions.ext_database import db
  18. from libs.login import login_required
  19. from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
  20. from services.datasource_provider_service import DatasourceProviderService
  21. class DatasourcePluginOauthApi(Resource):
  22. @setup_required
  23. @login_required
  24. @account_initialization_required
  25. def get(self):
  26. parser = reqparse.RequestParser()
  27. parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
  28. parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
  29. args = parser.parse_args()
  30. provider = args["provider"]
  31. plugin_id = args["plugin_id"]
  32. # Check user role first
  33. if not current_user.is_editor:
  34. raise Forbidden()
  35. # get all plugin oauth configs
  36. plugin_oauth_config = (
  37. db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
  38. )
  39. if not plugin_oauth_config:
  40. raise NotFound()
  41. oauth_handler = OAuthHandler()
  42. redirect_url = (
  43. f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
  44. )
  45. system_credentials = plugin_oauth_config.system_credentials
  46. if system_credentials:
  47. system_credentials["redirect_url"] = redirect_url
  48. response = oauth_handler.get_authorization_url(
  49. current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
  50. )
  51. return response.model_dump()
  52. class DatasourceOauthCallback(Resource):
  53. @setup_required
  54. @login_required
  55. @account_initialization_required
  56. def get(self):
  57. parser = reqparse.RequestParser()
  58. parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
  59. parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
  60. args = parser.parse_args()
  61. provider = args["provider"]
  62. plugin_id = args["plugin_id"]
  63. oauth_handler = OAuthHandler()
  64. plugin_oauth_config = (
  65. db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
  66. )
  67. if not plugin_oauth_config:
  68. raise NotFound()
  69. credentials = oauth_handler.get_credentials(
  70. current_user.current_tenant.id,
  71. current_user.id,
  72. plugin_id,
  73. provider,
  74. system_credentials=plugin_oauth_config.system_credentials,
  75. request=request,
  76. )
  77. datasource_provider = DatasourceProvider(
  78. plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
  79. )
  80. db.session.add(datasource_provider)
  81. db.session.commit()
  82. return redirect(f"{dify_config.CONSOLE_WEB_URL}")
  83. class DatasourceAuth(Resource):
  84. @setup_required
  85. @login_required
  86. @account_initialization_required
  87. def post(self):
  88. if not current_user.is_editor:
  89. raise Forbidden()
  90. parser = reqparse.RequestParser()
  91. parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
  92. parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
  93. parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
  94. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  95. args = parser.parse_args()
  96. datasource_provider_service = DatasourceProviderService()
  97. try:
  98. datasource_provider_service.datasource_provider_credentials_validate(
  99. tenant_id=current_user.current_tenant_id,
  100. provider=args["provider"],
  101. plugin_id=args["plugin_id"],
  102. credentials=args["credentials"],
  103. name="test" + str(random.randint(1, 1000000)),
  104. )
  105. except CredentialsValidateFailedError as ex:
  106. raise ValueError(str(ex))
  107. return {"result": "success"}, 201
  108. @setup_required
  109. @login_required
  110. @account_initialization_required
  111. def get(self):
  112. parser = reqparse.RequestParser()
  113. parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
  114. parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
  115. args = parser.parse_args()
  116. datasource_provider_service = DatasourceProviderService()
  117. datasources = datasource_provider_service.get_datasource_credentials(
  118. tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
  119. )
  120. return {"result": datasources}, 200
  121. class DatasourceAuthUpdateDeleteApi(Resource):
  122. @setup_required
  123. @login_required
  124. @account_initialization_required
  125. def delete(self, auth_id: str):
  126. parser = reqparse.RequestParser()
  127. parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
  128. parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
  129. args = parser.parse_args()
  130. if not current_user.is_editor:
  131. raise Forbidden()
  132. datasource_provider_service = DatasourceProviderService()
  133. datasource_provider_service.remove_datasource_credentials(
  134. tenant_id=current_user.current_tenant_id,
  135. auth_id=auth_id,
  136. provider=args["provider"],
  137. plugin_id=args["plugin_id"],
  138. )
  139. return {"result": "success"}, 200
  140. @setup_required
  141. @login_required
  142. @account_initialization_required
  143. def patch(self, auth_id: str):
  144. parser = reqparse.RequestParser()
  145. parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
  146. parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
  147. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  148. args = parser.parse_args()
  149. if not current_user.is_editor:
  150. raise Forbidden()
  151. try:
  152. datasource_provider_service = DatasourceProviderService()
  153. datasource_provider_service.update_datasource_credentials(
  154. tenant_id=current_user.current_tenant_id,
  155. auth_id=auth_id,
  156. provider=args["provider"],
  157. plugin_id=args["plugin_id"],
  158. credentials=args["credentials"],
  159. )
  160. except CredentialsValidateFailedError as ex:
  161. raise ValueError(str(ex))
  162. return {"result": "success"}, 201
  163. class DatasourceAuthListApi(Resource):
  164. @setup_required
  165. @login_required
  166. @account_initialization_required
  167. def get(self):
  168. datasource_provider_service = DatasourceProviderService()
  169. datasources = datasource_provider_service.get_all_datasource_credentials(
  170. tenant_id=current_user.current_tenant_id
  171. )
  172. return {"result": datasources}, 200
  173. # Import Rag Pipeline
  174. api.add_resource(
  175. DatasourcePluginOauthApi,
  176. "/oauth/plugin/datasource",
  177. )
  178. api.add_resource(
  179. DatasourceOauthCallback,
  180. "/oauth/plugin/datasource/callback",
  181. )
  182. api.add_resource(
  183. DatasourceAuth,
  184. "/auth/plugin/datasource",
  185. )
  186. api.add_resource(
  187. DatasourceAuthUpdateDeleteApi,
  188. "/auth/plugin/datasource/<string:auth_id>",
  189. )
  190. api.add_resource(
  191. DatasourceAuthListApi,
  192. "/auth/plugin/datasource/list",
  193. )