選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

datasource_auth.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from fastapi.encoders import jsonable_encoder
  2. from flask import make_response, redirect, request
  3. from flask_login import current_user # type: ignore
  4. from flask_restful import ( # type: ignore
  5. Resource, # type: ignore
  6. reqparse,
  7. )
  8. from werkzeug.exceptions import Forbidden, NotFound
  9. from configs import dify_config
  10. from controllers.console import api
  11. from controllers.console.wraps import (
  12. account_initialization_required,
  13. setup_required,
  14. )
  15. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  16. from core.plugin.entities.plugin import DatasourceProviderID
  17. from core.plugin.impl.oauth import OAuthHandler
  18. from libs.helper import StrLen
  19. from libs.login import login_required
  20. from services.datasource_provider_service import DatasourceProviderService
  21. from services.plugin.oauth_service import OAuthProxyService
  22. class DatasourcePluginOAuthAuthorizationUrl(Resource):
  23. @setup_required
  24. @login_required
  25. @account_initialization_required
  26. def get(self, provider_id: str):
  27. user = current_user
  28. tenant_id = user.current_tenant_id
  29. if not current_user.is_editor:
  30. raise Forbidden()
  31. credential_id = request.args.get("credential_id")
  32. datasource_provider_id = DatasourceProviderID(provider_id)
  33. provider_name = datasource_provider_id.provider_name
  34. plugin_id = datasource_provider_id.plugin_id
  35. oauth_config = DatasourceProviderService().get_oauth_client(
  36. tenant_id=tenant_id,
  37. datasource_provider_id=datasource_provider_id,
  38. )
  39. if not oauth_config:
  40. raise ValueError(f"No OAuth Client Config for {provider_id}")
  41. context_id = OAuthProxyService.create_proxy_context(
  42. user_id=current_user.id,
  43. tenant_id=tenant_id,
  44. plugin_id=plugin_id,
  45. provider=provider_name,
  46. credential_id=credential_id,
  47. )
  48. oauth_handler = OAuthHandler()
  49. redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
  50. authorization_url_response = oauth_handler.get_authorization_url(
  51. tenant_id=tenant_id,
  52. user_id=user.id,
  53. plugin_id=plugin_id,
  54. provider=provider_name,
  55. redirect_uri=redirect_uri,
  56. system_credentials=oauth_config,
  57. )
  58. response = make_response(jsonable_encoder(authorization_url_response))
  59. response.set_cookie(
  60. "context_id",
  61. context_id,
  62. httponly=True,
  63. samesite="Lax",
  64. max_age=OAuthProxyService.__MAX_AGE__,
  65. )
  66. return response
  67. class DatasourceOAuthCallback(Resource):
  68. @setup_required
  69. def get(self, provider_id: str):
  70. context_id = request.cookies.get("context_id") or request.args.get("context_id")
  71. if not context_id:
  72. raise Forbidden("context_id not found")
  73. context = OAuthProxyService.use_proxy_context(context_id)
  74. if context is None:
  75. raise Forbidden("Invalid context_id")
  76. user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
  77. datasource_provider_id = DatasourceProviderID(provider_id)
  78. plugin_id = datasource_provider_id.plugin_id
  79. datasource_provider_service = DatasourceProviderService()
  80. oauth_client_params = datasource_provider_service.get_oauth_client(
  81. tenant_id=tenant_id,
  82. datasource_provider_id=datasource_provider_id,
  83. )
  84. if not oauth_client_params:
  85. raise NotFound()
  86. redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
  87. oauth_handler = OAuthHandler()
  88. oauth_response = oauth_handler.get_credentials(
  89. tenant_id=tenant_id,
  90. user_id=user_id,
  91. plugin_id=plugin_id,
  92. provider=datasource_provider_id.provider_name,
  93. redirect_uri=redirect_uri,
  94. system_credentials=oauth_client_params,
  95. request=request,
  96. )
  97. credential_id = context.get("credential_id")
  98. if credential_id:
  99. datasource_provider_service.reauthorize_datasource_oauth_provider(
  100. tenant_id=tenant_id,
  101. provider_id=datasource_provider_id,
  102. avatar_url=oauth_response.metadata.get("avatar_url") or None,
  103. name=oauth_response.metadata.get("name") or None,
  104. expire_at=oauth_response.expires_at,
  105. credentials=dict(oauth_response.credentials),
  106. credential_id=context.get("credential_id"),
  107. )
  108. else:
  109. datasource_provider_service.add_datasource_oauth_provider(
  110. tenant_id=tenant_id,
  111. provider_id=datasource_provider_id,
  112. avatar_url=oauth_response.metadata.get("avatar_url") or None,
  113. name=oauth_response.metadata.get("name") or None,
  114. expire_at=oauth_response.expires_at,
  115. credentials=dict(oauth_response.credentials),
  116. )
  117. return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
  118. class DatasourceAuth(Resource):
  119. @setup_required
  120. @login_required
  121. @account_initialization_required
  122. def post(self, provider_id: str):
  123. if not current_user.is_editor:
  124. raise Forbidden()
  125. parser = reqparse.RequestParser()
  126. parser.add_argument(
  127. "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
  128. )
  129. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  130. args = parser.parse_args()
  131. datasource_provider_id = DatasourceProviderID(provider_id)
  132. datasource_provider_service = DatasourceProviderService()
  133. try:
  134. datasource_provider_service.add_datasource_api_key_provider(
  135. tenant_id=current_user.current_tenant_id,
  136. provider_id=datasource_provider_id,
  137. credentials=args["credentials"],
  138. name=args["name"],
  139. )
  140. except CredentialsValidateFailedError as ex:
  141. raise ValueError(str(ex))
  142. return {"result": "success"}, 200
  143. @setup_required
  144. @login_required
  145. @account_initialization_required
  146. def get(self, provider_id: str):
  147. datasource_provider_id = DatasourceProviderID(provider_id)
  148. datasource_provider_service = DatasourceProviderService()
  149. datasources = datasource_provider_service.list_datasource_credentials(
  150. tenant_id=current_user.current_tenant_id,
  151. provider=datasource_provider_id.provider_name,
  152. plugin_id=datasource_provider_id.plugin_id,
  153. )
  154. return {"result": datasources}, 200
  155. class DatasourceAuthDeleteApi(Resource):
  156. @setup_required
  157. @login_required
  158. @account_initialization_required
  159. def post(self, provider_id: str):
  160. datasource_provider_id = DatasourceProviderID(provider_id)
  161. plugin_id = datasource_provider_id.plugin_id
  162. provider_name = datasource_provider_id.provider_name
  163. if not current_user.is_editor:
  164. raise Forbidden()
  165. parser = reqparse.RequestParser()
  166. parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
  167. args = parser.parse_args()
  168. datasource_provider_service = DatasourceProviderService()
  169. datasource_provider_service.remove_datasource_credentials(
  170. tenant_id=current_user.current_tenant_id,
  171. auth_id=args["credential_id"],
  172. provider=provider_name,
  173. plugin_id=plugin_id,
  174. )
  175. return {"result": "success"}, 200
  176. class DatasourceAuthUpdateApi(Resource):
  177. @setup_required
  178. @login_required
  179. @account_initialization_required
  180. def post(self, provider_id: str):
  181. datasource_provider_id = DatasourceProviderID(provider_id)
  182. parser = reqparse.RequestParser()
  183. parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
  184. parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
  185. parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
  186. args = parser.parse_args()
  187. if not current_user.is_editor:
  188. raise Forbidden()
  189. datasource_provider_service = DatasourceProviderService()
  190. datasource_provider_service.update_datasource_credentials(
  191. tenant_id=current_user.current_tenant_id,
  192. auth_id=args["credential_id"],
  193. provider=datasource_provider_id.provider_name,
  194. plugin_id=datasource_provider_id.plugin_id,
  195. credentials=args.get("credentials", {}),
  196. name=args.get("name", None),
  197. )
  198. return {"result": "success"}, 201
  199. class DatasourceAuthListApi(Resource):
  200. @setup_required
  201. @login_required
  202. @account_initialization_required
  203. def get(self):
  204. datasource_provider_service = DatasourceProviderService()
  205. datasources = datasource_provider_service.get_all_datasource_credentials(
  206. tenant_id=current_user.current_tenant_id
  207. )
  208. return {"result": jsonable_encoder(datasources)}, 200
  209. class DatasourceHardCodeAuthListApi(Resource):
  210. @setup_required
  211. @login_required
  212. @account_initialization_required
  213. def get(self):
  214. datasource_provider_service = DatasourceProviderService()
  215. datasources = datasource_provider_service.get_hard_code_datasource_credentials(
  216. tenant_id=current_user.current_tenant_id
  217. )
  218. return {"result": jsonable_encoder(datasources)}, 200
  219. class DatasourceAuthOauthCustomClient(Resource):
  220. @setup_required
  221. @login_required
  222. @account_initialization_required
  223. def post(self, provider_id: str):
  224. if not current_user.is_editor:
  225. raise Forbidden()
  226. parser = reqparse.RequestParser()
  227. parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
  228. parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
  229. args = parser.parse_args()
  230. datasource_provider_id = DatasourceProviderID(provider_id)
  231. datasource_provider_service = DatasourceProviderService()
  232. datasource_provider_service.setup_oauth_custom_client_params(
  233. tenant_id=current_user.current_tenant_id,
  234. datasource_provider_id=datasource_provider_id,
  235. client_params=args.get("client_params", {}),
  236. enabled=args.get("enable_oauth_custom_client", False),
  237. )
  238. return {"result": "success"}, 200
  239. @setup_required
  240. @login_required
  241. @account_initialization_required
  242. def delete(self, provider_id: str):
  243. datasource_provider_id = DatasourceProviderID(provider_id)
  244. datasource_provider_service = DatasourceProviderService()
  245. datasource_provider_service.remove_oauth_custom_client_params(
  246. tenant_id=current_user.current_tenant_id,
  247. datasource_provider_id=datasource_provider_id,
  248. )
  249. return {"result": "success"}, 200
  250. class DatasourceAuthDefaultApi(Resource):
  251. @setup_required
  252. @login_required
  253. @account_initialization_required
  254. def post(self, provider_id: str):
  255. if not current_user.is_editor:
  256. raise Forbidden()
  257. parser = reqparse.RequestParser()
  258. parser.add_argument("id", type=str, required=True, nullable=False, location="json")
  259. args = parser.parse_args()
  260. datasource_provider_id = DatasourceProviderID(provider_id)
  261. datasource_provider_service = DatasourceProviderService()
  262. datasource_provider_service.set_default_datasource_provider(
  263. tenant_id=current_user.current_tenant_id,
  264. datasource_provider_id=datasource_provider_id,
  265. credential_id=args["id"],
  266. )
  267. return {"result": "success"}, 200
  268. class DatasourceUpdateProviderNameApi(Resource):
  269. @setup_required
  270. @login_required
  271. @account_initialization_required
  272. def post(self, provider_id: str):
  273. if not current_user.is_editor:
  274. raise Forbidden()
  275. parser = reqparse.RequestParser()
  276. parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
  277. parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
  278. args = parser.parse_args()
  279. datasource_provider_id = DatasourceProviderID(provider_id)
  280. datasource_provider_service = DatasourceProviderService()
  281. datasource_provider_service.update_datasource_provider_name(
  282. tenant_id=current_user.current_tenant_id,
  283. datasource_provider_id=datasource_provider_id,
  284. name=args["name"],
  285. credential_id=args["credential_id"],
  286. )
  287. return {"result": "success"}, 200
  288. api.add_resource(
  289. DatasourcePluginOAuthAuthorizationUrl,
  290. "/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
  291. )
  292. api.add_resource(
  293. DatasourceOAuthCallback,
  294. "/oauth/plugin/<path:provider_id>/datasource/callback",
  295. )
  296. api.add_resource(
  297. DatasourceAuth,
  298. "/auth/plugin/datasource/<path:provider_id>",
  299. )
  300. api.add_resource(
  301. DatasourceAuthUpdateApi,
  302. "/auth/plugin/datasource/<path:provider_id>/update",
  303. )
  304. api.add_resource(
  305. DatasourceAuthDeleteApi,
  306. "/auth/plugin/datasource/<path:provider_id>/delete",
  307. )
  308. api.add_resource(
  309. DatasourceAuthListApi,
  310. "/auth/plugin/datasource/list",
  311. )
  312. api.add_resource(
  313. DatasourceHardCodeAuthListApi,
  314. "/auth/plugin/datasource/default-list",
  315. )
  316. api.add_resource(
  317. DatasourceAuthOauthCustomClient,
  318. "/auth/plugin/datasource/<path:provider_id>/custom-client",
  319. )
  320. api.add_resource(
  321. DatasourceAuthDefaultApi,
  322. "/auth/plugin/datasource/<path:provider_id>/default",
  323. )
  324. api.add_resource(
  325. DatasourceUpdateProviderNameApi,
  326. "/auth/plugin/datasource/<path:provider_id>/update-name",
  327. )