You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from flask import redirect, request
  2. from flask_login import current_user # type: ignore
  3. from flask_restful import ( # type: ignore
  4. Resource, # type: ignore
  5. reqparse,
  6. )
  7. from werkzeug.exceptions import Forbidden, NotFound
  8. from configs import dify_config
  9. from controllers.console import api
  10. from controllers.console.wraps import (
  11. account_initialization_required,
  12. setup_required,
  13. )
  14. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  15. from core.plugin.impl.oauth import OAuthHandler
  16. from extensions.ext_database import db
  17. from libs.login import login_required
  18. from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
  19. from services.datasource_provider_service import DatasourceProviderService
  20. class DatasourcePluginOauthApi(Resource):
  21. @setup_required
  22. @login_required
  23. @account_initialization_required
  24. def get(self, provider, plugin_id):
  25. # Check user role first
  26. if not current_user.is_editor:
  27. raise Forbidden()
  28. # get all plugin oauth configs
  29. plugin_oauth_config = (
  30. db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
  31. )
  32. if not plugin_oauth_config:
  33. raise NotFound()
  34. oauth_handler = OAuthHandler()
  35. redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback"
  36. system_credentials = plugin_oauth_config.system_credentials
  37. if system_credentials:
  38. system_credentials["redirect_url"] = redirect_url
  39. response = oauth_handler.get_authorization_url(
  40. current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
  41. )
  42. return response.model_dump()
  43. class DatasourceOauthCallback(Resource):
  44. @setup_required
  45. @login_required
  46. @account_initialization_required
  47. def get(self, provider, plugin_id):
  48. oauth_handler = OAuthHandler()
  49. plugin_oauth_config = (
  50. db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
  51. )
  52. if not plugin_oauth_config:
  53. raise NotFound()
  54. credentials = oauth_handler.get_credentials(
  55. current_user.current_tenant.id,
  56. current_user.id,
  57. plugin_id,
  58. provider,
  59. system_credentials=plugin_oauth_config.system_credentials,
  60. request=request,
  61. )
  62. datasource_provider = DatasourceProvider(
  63. plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
  64. )
  65. db.session.add(datasource_provider)
  66. db.session.commit()
  67. return redirect(f"{dify_config.CONSOLE_WEB_URL}")
  68. class DatasourceAuth(Resource):
  69. @setup_required
  70. @login_required
  71. @account_initialization_required
  72. def post(self, provider, plugin_id):
  73. if not current_user.is_editor:
  74. raise Forbidden()
  75. parser = reqparse.RequestParser()
  76. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  77. args = parser.parse_args()
  78. datasource_provider_service = DatasourceProviderService()
  79. try:
  80. datasource_provider_service.datasource_provider_credentials_validate(
  81. tenant_id=current_user.current_tenant_id,
  82. provider=provider,
  83. plugin_id=plugin_id,
  84. credentials=args["credentials"],
  85. )
  86. except CredentialsValidateFailedError as ex:
  87. raise ValueError(str(ex))
  88. return {"result": "success"}, 201
  89. @setup_required
  90. @login_required
  91. @account_initialization_required
  92. def get(self, provider, plugin_id):
  93. datasource_provider_service = DatasourceProviderService()
  94. datasources = datasource_provider_service.get_datasource_credentials(
  95. tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
  96. )
  97. return {"result": datasources}, 200
  98. class DatasourceAuthDeleteApi(Resource):
  99. @setup_required
  100. @login_required
  101. @account_initialization_required
  102. def delete(self, provider, plugin_id):
  103. if not current_user.is_editor:
  104. raise Forbidden()
  105. datasource_provider_service = DatasourceProviderService()
  106. datasource_provider_service.remove_datasource_credentials(
  107. tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
  108. )
  109. return {"result": "success"}, 200
  110. # Import Rag Pipeline
  111. api.add_resource(
  112. DatasourcePluginOauthApi,
  113. "/oauth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
  114. )
  115. api.add_resource(
  116. DatasourceOauthCallback,
  117. "/oauth/datasource/provider/<string:provider>/plugin/<string:plugin_id>/callback",
  118. )
  119. api.add_resource(
  120. DatasourceAuth,
  121. "/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
  122. )