You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

datasource_auth.py 12KB

5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
4 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
3 mesi fa
3 mesi fa
5 mesi fa
3 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
5 mesi fa
3 mesi fa
3 mesi fa
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. from fastapi.encoders import jsonable_encoder
  2. from flask import make_response, redirect, request
  3. from flask_login import current_user # type: ignore
  4. from flask_restful import ( # type: ignore
  5. Resource, # type: ignore
  6. reqparse,
  7. )
  8. from werkzeug.exceptions import Forbidden, NotFound
  9. from configs import dify_config
  10. from controllers.console import api
  11. from controllers.console.wraps import (
  12. account_initialization_required,
  13. setup_required,
  14. )
  15. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  16. from core.plugin.entities.plugin import DatasourceProviderID
  17. from core.plugin.impl.oauth import OAuthHandler
  18. from extensions.ext_database import db
  19. from libs.login import login_required
  20. from models.oauth import DatasourceOauthParamConfig
  21. from services.datasource_provider_service import DatasourceProviderService
  22. from services.plugin.oauth_service import OAuthProxyService
  class DatasourcePluginOAuthAuthorizationUrl(Resource):
      """Console endpoint that starts the OAuth flow for a datasource plugin."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_id: str):
          """Return the provider's authorization URL and set the proxy-context cookie.

          Args:
              provider_id: Datasource provider identifier taken from the URL path;
                  parsed with ``DatasourceProviderID``.

          Returns:
              A Flask response whose body is the JSON-encoded authorization URL
              response and which carries an HTTP-only ``context_id`` cookie.

          Raises:
              Forbidden: The current user is not an editor.
              ValueError: No OAuth client configuration is available for the provider.
          """
          # Reject non-editors before doing any lookups or creating a proxy context.
          if not current_user.is_editor:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id
          datasource_provider_id = DatasourceProviderID(provider_id)
          provider_name = datasource_provider_id.provider_name
          plugin_id = datasource_provider_id.plugin_id

          # Resolve the client parameters through the same service call that the
          # callback endpoint uses, so the authorization request and the later
          # credential retrieval are performed with the same OAuth client.
          oauth_client_params = DatasourceProviderService().get_oauth_client(
              tenant_id=tenant_id,
              datasource_provider_id=datasource_provider_id,
          )
          if not oauth_client_params:
              raise ValueError(f"No OAuth Client Config for {provider_id}")

          context_id = OAuthProxyService.create_proxy_context(
              user_id=current_user.id,
              tenant_id=tenant_id,
              plugin_id=plugin_id,
              provider=provider_name,
          )

          redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
          authorization_url_response = OAuthHandler().get_authorization_url(
              tenant_id=tenant_id,
              user_id=current_user.id,
              plugin_id=plugin_id,
              provider=provider_name,
              redirect_uri=redirect_uri,
              system_credentials=oauth_client_params,
          )

          response = make_response(jsonable_encoder(authorization_url_response))
          # The callback endpoint reads this cookie to recover the user and tenant.
          response.set_cookie(
              "context_id",
              context_id,
              httponly=True,
              samesite="Lax",
              max_age=OAuthProxyService.__MAX_AGE__,
          )
          return response
  class DatasourceOAuthCallback(Resource):
      """Console endpoint the OAuth provider redirects back to for datasource plugins."""

      @setup_required
      def get(self, provider_id: str):
          """Store the credentials obtained from the OAuth callback and redirect.

          The user and tenant are not taken from a login session here; they are read
          from the proxy context identified by the ``context_id`` cookie (or, failing
          that, the ``context_id`` query argument).

          Args:
              provider_id: Datasource provider identifier taken from the URL path.

          Returns:
              A redirect to the console web ``/oauth-callback`` page.

          Raises:
              Forbidden: ``context_id`` is missing or does not resolve to a context.
              NotFound: No OAuth client parameters exist for the tenant and provider.
          """
          proxy_context_id = request.cookies.get("context_id") or request.args.get("context_id")
          if not proxy_context_id:
              raise Forbidden("context_id not found")

          proxy_context = OAuthProxyService.use_proxy_context(proxy_context_id)
          if proxy_context is None:
              raise Forbidden("Invalid context_id")
          user_id = proxy_context.get("user_id")
          tenant_id = proxy_context.get("tenant_id")

          parsed_provider = DatasourceProviderID(provider_id)
          plugin_id = parsed_provider.plugin_id
          provider_service = DatasourceProviderService()

          client_params = provider_service.get_oauth_client(
              tenant_id=tenant_id,
              datasource_provider_id=parsed_provider,
          )
          if not client_params:
              raise NotFound()

          callback_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
          handler = OAuthHandler()
          credentials_response = handler.get_credentials(
              tenant_id=tenant_id,
              user_id=user_id,
              plugin_id=plugin_id,
              provider=parsed_provider.provider_name,
              redirect_uri=callback_uri,
              system_credentials=client_params,
              request=request,
          )

          # Empty or missing metadata values are normalised to None before storing.
          provider_service.add_datasource_oauth_provider(
              tenant_id=tenant_id,
              provider_id=parsed_provider,
              avatar_url=credentials_response.metadata.get("avatar_url") or None,
              name=credentials_response.metadata.get("name") or None,
              credentials=dict(credentials_response.credentials),
          )

          target = f"{dify_config.CONSOLE_WEB_URL}/oauth-callback"
          return redirect(target)
  class DatasourceAuth(Resource):
      """Console endpoint for adding API-key credentials and listing credentials."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_id: str):
          """Add API-key credentials for a datasource provider.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.

          JSON body:
              name: Optional display name (may be null).
              credentials: Required dict of credential values.

          Returns:
              A redirect to the console web ``/oauth-callback`` page.

          Raises:
              Forbidden: The current user is not an editor.
              ValueError: The service rejected the credentials
                  (``CredentialsValidateFailedError``); the original error is chained.
          """
          if not current_user.is_editor:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("name", type=str, required=False, nullable=True, location="json", default=None)
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          args = parser.parse_args()

          datasource_provider_id = DatasourceProviderID(provider_id)
          datasource_provider_service = DatasourceProviderService()
          try:
              datasource_provider_service.add_datasource_api_key_provider(
                  tenant_id=current_user.current_tenant_id,
                  provider_id=datasource_provider_id,
                  credentials=args["credentials"],
                  name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              # Chain the cause so the validation failure stays visible in tracebacks.
              raise ValueError(str(ex)) from ex

          # NOTE(review): a redirect is an unusual response for a JSON POST and looks
          # copied from the OAuth callback; kept as-is for compatibility - confirm
          # whether a JSON result is what clients actually expect.
          return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_id: str):
          """List the current tenant's credentials for a datasource provider.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.

          Returns:
              A ``({"result": <credentials>}, 200)`` tuple.
          """
          datasource_provider_id = DatasourceProviderID(provider_id)
          datasource_provider_service = DatasourceProviderService()
          datasources = datasource_provider_service.get_datasource_credentials(
              tenant_id=current_user.current_tenant_id,
              provider=datasource_provider_id.provider_name,
              plugin_id=datasource_provider_id.plugin_id,
          )
          return {"result": datasources}, 200
  class DatasourceAuthUpdateDeleteApi(Resource):
      """Console endpoint for deleting or updating one stored datasource credential."""

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider_id: str, auth_id: str):
          """Remove a stored credential.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.
              auth_id: Identifier of the credential to remove.

          Returns:
              A ``({"result": "success"}, 200)`` tuple.

          Raises:
              Forbidden: The current user is not an editor.
          """
          # Authorise first, matching the other editor-only handlers in this module.
          if not current_user.is_editor:
              raise Forbidden()

          datasource_provider_id = DatasourceProviderID(provider_id)
          datasource_provider_service = DatasourceProviderService()
          datasource_provider_service.remove_datasource_credentials(
              tenant_id=current_user.current_tenant_id,
              auth_id=auth_id,
              provider=datasource_provider_id.provider_name,
              plugin_id=datasource_provider_id.plugin_id,
          )
          return {"result": "success"}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider_id: str, auth_id: str):
          """Update the values of a stored credential.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.
              auth_id: Identifier of the credential to update.

          JSON body:
              credentials: Required dict of new credential values.

          Returns:
              A ``({"result": "success"}, 201)`` tuple.

          Raises:
              Forbidden: The current user is not an editor.
              ValueError: The service rejected the credentials
                  (``CredentialsValidateFailedError``); the original error is chained.
          """
          # Authorise before parsing, so non-editors get 403 rather than
          # request-validation details.
          if not current_user.is_editor:
              raise Forbidden()

          datasource_provider_id = DatasourceProviderID(provider_id)
          parser = reqparse.RequestParser()
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          args = parser.parse_args()

          datasource_provider_service = DatasourceProviderService()
          try:
              datasource_provider_service.update_datasource_credentials(
                  tenant_id=current_user.current_tenant_id,
                  auth_id=auth_id,
                  provider=datasource_provider_id.provider_name,
                  plugin_id=datasource_provider_id.plugin_id,
                  credentials=args["credentials"],
              )
          except CredentialsValidateFailedError as ex:
              # Chain the cause so the validation failure stays visible in tracebacks.
              raise ValueError(str(ex)) from ex

          # NOTE(review): 201 is unusual for an update; kept for compatibility.
          return {"result": "success"}, 201
  class DatasourceAuthListApi(Resource):
      """Console endpoint listing every datasource credential of the current tenant."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return all datasource credentials for the current tenant.

          Returns:
              A ``({"result": <JSON-encoded credentials>}, 200)`` tuple.
          """
          service = DatasourceProviderService()
          all_credentials = service.get_all_datasource_credentials(tenant_id=current_user.current_tenant_id)
          payload = {"result": jsonable_encoder(all_credentials)}
          return payload, 200
  class DatasourceAuthOauthCustomClient(Resource):
      """Console endpoint for configuring a tenant's custom OAuth client."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_id: str):
          """Store custom OAuth client parameters and/or the enabled flag.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.

          JSON body:
              client_params: Optional dict (may be null).
              enabled: Optional bool (may be null).

          Returns:
              A ``({"result": "success"}, 200)`` tuple.

          Raises:
              Forbidden: The current user is not an editor.
          """
          if not current_user.is_editor:
              raise Forbidden()

          body_parser = reqparse.RequestParser()
          # Both fields are optional and nullable; they differ only in name and type.
          for field_name, field_type in (("client_params", dict), ("enabled", bool)):
              body_parser.add_argument(field_name, type=field_type, required=False, nullable=True, location="json")
          parsed = body_parser.parse_args()

          parsed_provider = DatasourceProviderID(provider_id)
          service = DatasourceProviderService()
          # NOTE(review): the fallbacks below apply only when the key is absent from
          # the parsed arguments - confirm how the parser reports omitted fields.
          service.setup_oauth_custom_client_params(
              tenant_id=current_user.current_tenant_id,
              datasource_provider_id=parsed_provider,
              client_params=parsed.get("client_params", {}),
              enabled=parsed.get("enabled", False),
          )
          return {"result": "success"}, 200
  class DatasourceAuthDefaultApi(Resource):
      """Console endpoint for choosing the default credential of a datasource provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_id: str):
          """Mark one credential as the provider's default for the current tenant.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.

          JSON body:
              credential_id: Required string identifying the credential.

          Returns:
              A ``({"result": "success"}, 200)`` tuple.

          Raises:
              Forbidden: The current user is not an editor.
          """
          if not current_user.is_editor:
              raise Forbidden()

          body_parser = reqparse.RequestParser()
          body_parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
          parsed = body_parser.parse_args()

          parsed_provider = DatasourceProviderID(provider_id)
          service = DatasourceProviderService()
          service.set_default_datasource_provider(
              tenant_id=current_user.current_tenant_id,
              datasource_provider_id=parsed_provider,
              credential_id=parsed["credential_id"],
          )
          return {"result": "success"}, 200
  class DatasourceUpdateProviderNameApi(Resource):
      """Console endpoint for renaming a stored datasource credential."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_id: str):
          """Update the display name attached to a credential.

          Args:
              provider_id: Datasource provider identifier taken from the URL path.

          JSON body:
              name: Required string, the new name.
              credential_id: Required string identifying the credential.

          Returns:
              A ``({"result": "success"}, 200)`` tuple.

          Raises:
              Forbidden: The current user is not an editor.
          """
          if not current_user.is_editor:
              raise Forbidden()

          body_parser = reqparse.RequestParser()
          # Both fields are required, non-null strings read from the JSON body.
          for field_name in ("name", "credential_id"):
              body_parser.add_argument(field_name, type=str, required=True, nullable=False, location="json")
          parsed = body_parser.parse_args()

          parsed_provider = DatasourceProviderID(provider_id)
          service = DatasourceProviderService()
          service.update_datasource_provider_name(
              tenant_id=current_user.current_tenant_id,
              datasource_provider_id=parsed_provider,
              name=parsed["name"],
              credential_id=parsed["credential_id"],
          )
          return {"result": "success"}, 200
  # Route table for the datasource auth endpoints. Resources are registered on the
  # console API in the order listed here.
  _DATASOURCE_AUTH_ROUTES = (
      (DatasourcePluginOAuthAuthorizationUrl, "/oauth/plugin/<path:provider_id>/datasource/get-authorization-url"),
      (DatasourceOAuthCallback, "/oauth/plugin/<path:provider_id>/datasource/callback"),
      (DatasourceAuth, "/auth/plugin/datasource/<path:provider_id>"),
      (DatasourceAuthUpdateDeleteApi, "/auth/plugin/datasource/<path:provider_id>/<string:auth_id>"),
      (DatasourceAuthListApi, "/auth/plugin/datasource/list"),
      (DatasourceAuthOauthCustomClient, "/auth/plugin/datasource/<path:provider_id>/custom-client"),
      (DatasourceAuthDefaultApi, "/auth/plugin/datasource/<path:provider_id>/default"),
      (DatasourceUpdateProviderNameApi, "/auth/plugin/datasource/<path:provider_id>/update-name"),
  )

  for _resource_cls, _url_rule in _DATASOURCE_AUTH_ROUTES:
      api.add_resource(_resource_cls, _url_rule)