Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import uuid
  2. from datetime import UTC, datetime, timedelta
  3. from flask import request
  4. from flask_restful import Resource
  5. from sqlalchemy import func, select
  6. from werkzeug.exceptions import NotFound, Unauthorized
  7. from configs import dify_config
  8. from controllers.web import api
  9. from controllers.web.error import WebAppAuthRequiredError
  10. from extensions.ext_database import db
  11. from libs.passport import PassportService
  12. from models.model import App, EndUser, Site
  13. from services.enterprise.enterprise_service import EnterpriseService
  14. from services.feature_service import FeatureService
  15. from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
  16. class PassportResource(Resource):
  17. """Base resource for passport."""
  18. def get(self):
  19. system_features = FeatureService.get_system_features()
  20. app_code = request.headers.get("X-App-Code")
  21. user_id = request.args.get("user_id")
  22. web_app_access_token = request.args.get("web_app_access_token")
  23. if app_code is None:
  24. raise Unauthorized("X-App-Code header is missing.")
  25. # exchange token for enterprise logined web user
  26. enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
  27. if enterprise_user_decoded:
  28. # a web user has already logged in, exchange a token for this app without redirecting to the login page
  29. return exchange_token_for_existing_web_user(
  30. app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
  31. )
  32. if system_features.webapp_auth.enabled:
  33. app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
  34. if not app_settings or not app_settings.access_mode == "public":
  35. raise WebAppAuthRequiredError()
  36. # get site from db and check if it is normal
  37. site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
  38. if not site:
  39. raise NotFound()
  40. # get app from db and check if it is normal and enable_site
  41. app_model = db.session.scalar(select(App).where(App.id == site.app_id))
  42. if not app_model or app_model.status != "normal" or not app_model.enable_site:
  43. raise NotFound()
  44. if user_id:
  45. end_user = db.session.scalar(
  46. select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
  47. )
  48. if end_user:
  49. pass
  50. else:
  51. end_user = EndUser(
  52. tenant_id=app_model.tenant_id,
  53. app_id=app_model.id,
  54. type="browser",
  55. is_anonymous=True,
  56. session_id=user_id,
  57. )
  58. db.session.add(end_user)
  59. db.session.commit()
  60. else:
  61. end_user = EndUser(
  62. tenant_id=app_model.tenant_id,
  63. app_id=app_model.id,
  64. type="browser",
  65. is_anonymous=True,
  66. session_id=generate_session_id(),
  67. )
  68. db.session.add(end_user)
  69. db.session.commit()
  70. payload = {
  71. "iss": site.app_id,
  72. "sub": "Web API Passport",
  73. "app_id": site.app_id,
  74. "app_code": app_code,
  75. "end_user_id": end_user.id,
  76. }
  77. tk = PassportService().issue(payload)
  78. return {
  79. "access_token": tk,
  80. }
  81. api.add_resource(PassportResource, "/passport")
  82. def decode_enterprise_webapp_user_id(jwt_token: str | None):
  83. """
  84. Decode the enterprise user session from the Authorization header.
  85. """
  86. if not jwt_token:
  87. return None
  88. decoded = PassportService().verify(jwt_token)
  89. source = decoded.get("token_source")
  90. if not source or source != "webapp_login_token":
  91. raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
  92. return decoded
  93. def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
  94. """
  95. Exchange a token for an existing web user session.
  96. """
  97. user_id = enterprise_user_decoded.get("user_id")
  98. end_user_id = enterprise_user_decoded.get("end_user_id")
  99. session_id = enterprise_user_decoded.get("session_id")
  100. user_auth_type = enterprise_user_decoded.get("auth_type")
  101. if not user_auth_type:
  102. raise Unauthorized("Missing auth_type in the token.")
  103. site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
  104. if not site:
  105. raise NotFound()
  106. app_model = db.session.scalar(select(App).where(App.id == site.app_id))
  107. if not app_model or app_model.status != "normal" or not app_model.enable_site:
  108. raise NotFound()
  109. app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
  110. if app_auth_type == WebAppAuthType.PUBLIC:
  111. return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
  112. elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
  113. raise WebAppAuthRequiredError("Please login as external user.")
  114. elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
  115. raise WebAppAuthRequiredError("Please login as internal user.")
  116. end_user = None
  117. if end_user_id:
  118. end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
  119. if session_id:
  120. end_user = db.session.scalar(
  121. select(EndUser).where(
  122. EndUser.session_id == session_id,
  123. EndUser.tenant_id == app_model.tenant_id,
  124. EndUser.app_id == app_model.id,
  125. )
  126. )
  127. if not end_user:
  128. if not session_id:
  129. raise NotFound("Missing session_id for existing web user.")
  130. end_user = EndUser(
  131. tenant_id=app_model.tenant_id,
  132. app_id=app_model.id,
  133. type="browser",
  134. is_anonymous=True,
  135. session_id=session_id,
  136. )
  137. db.session.add(end_user)
  138. db.session.commit()
  139. exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
  140. exp = int(exp_dt.timestamp())
  141. payload = {
  142. "iss": site.id,
  143. "sub": "Web API Passport",
  144. "app_id": site.app_id,
  145. "app_code": site.code,
  146. "user_id": user_id,
  147. "end_user_id": end_user.id,
  148. "auth_type": user_auth_type,
  149. "granted_at": int(datetime.now(UTC).timestamp()),
  150. "token_source": "webapp",
  151. "exp": exp,
  152. }
  153. token: str = PassportService().issue(payload)
  154. return {
  155. "access_token": token,
  156. }
  157. def _exchange_for_public_app_token(app_model, site, token_decoded):
  158. user_id = token_decoded.get("user_id")
  159. end_user = None
  160. if user_id:
  161. end_user = db.session.scalar(
  162. select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
  163. )
  164. if not end_user:
  165. end_user = EndUser(
  166. tenant_id=app_model.tenant_id,
  167. app_id=app_model.id,
  168. type="browser",
  169. is_anonymous=True,
  170. session_id=generate_session_id(),
  171. )
  172. db.session.add(end_user)
  173. db.session.commit()
  174. payload = {
  175. "iss": site.app_id,
  176. "sub": "Web API Passport",
  177. "app_id": site.app_id,
  178. "app_code": site.code,
  179. "end_user_id": end_user.id,
  180. }
  181. tk = PassportService().issue(payload)
  182. return {
  183. "access_token": tk,
  184. }
  185. def generate_session_id():
  186. """
  187. Generate a unique session ID.
  188. """
  189. while True:
  190. session_id = str(uuid.uuid4())
  191. existing_count = db.session.scalar(
  192. select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
  193. )
  194. if existing_count == 0:
  195. return session_id