Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

passport.py 7.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import uuid
  2. from datetime import UTC, datetime, timedelta
  3. from flask import request
  4. from flask_restful import Resource
  5. from werkzeug.exceptions import NotFound, Unauthorized
  6. from configs import dify_config
  7. from controllers.web import api
  8. from controllers.web.error import WebAppAuthRequiredError
  9. from extensions.ext_database import db
  10. from libs.passport import PassportService
  11. from models.model import App, EndUser, Site
  12. from services.enterprise.enterprise_service import EnterpriseService
  13. from services.feature_service import FeatureService
  14. from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
  15. class PassportResource(Resource):
  16. """Base resource for passport."""
  17. def get(self):
  18. system_features = FeatureService.get_system_features()
  19. app_code = request.headers.get("X-App-Code")
  20. user_id = request.args.get("user_id")
  21. web_app_access_token = request.args.get("web_app_access_token")
  22. if app_code is None:
  23. raise Unauthorized("X-App-Code header is missing.")
  24. # exchange token for enterprise logined web user
  25. enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
  26. if enterprise_user_decoded:
  27. # a web user has already logged in, exchange a token for this app without redirecting to the login page
  28. return exchange_token_for_existing_web_user(
  29. app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
  30. )
  31. if system_features.webapp_auth.enabled:
  32. app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
  33. if not app_settings or not app_settings.access_mode == "public":
  34. raise WebAppAuthRequiredError()
  35. # get site from db and check if it is normal
  36. site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
  37. if not site:
  38. raise NotFound()
  39. # get app from db and check if it is normal and enable_site
  40. app_model = db.session.query(App).filter(App.id == site.app_id).first()
  41. if not app_model or app_model.status != "normal" or not app_model.enable_site:
  42. raise NotFound()
  43. if user_id:
  44. end_user = (
  45. db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
  46. )
  47. if end_user:
  48. pass
  49. else:
  50. end_user = EndUser(
  51. tenant_id=app_model.tenant_id,
  52. app_id=app_model.id,
  53. type="browser",
  54. is_anonymous=True,
  55. session_id=user_id,
  56. )
  57. db.session.add(end_user)
  58. db.session.commit()
  59. else:
  60. end_user = EndUser(
  61. tenant_id=app_model.tenant_id,
  62. app_id=app_model.id,
  63. type="browser",
  64. is_anonymous=True,
  65. session_id=generate_session_id(),
  66. )
  67. db.session.add(end_user)
  68. db.session.commit()
  69. payload = {
  70. "iss": site.app_id,
  71. "sub": "Web API Passport",
  72. "app_id": site.app_id,
  73. "app_code": app_code,
  74. "end_user_id": end_user.id,
  75. }
  76. tk = PassportService().issue(payload)
  77. return {
  78. "access_token": tk,
  79. }
  80. api.add_resource(PassportResource, "/passport")
  81. def decode_enterprise_webapp_user_id(jwt_token: str | None):
  82. """
  83. Decode the enterprise user session from the Authorization header.
  84. """
  85. if not jwt_token:
  86. return None
  87. decoded = PassportService().verify(jwt_token)
  88. source = decoded.get("token_source")
  89. if not source or source != "webapp_login_token":
  90. raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
  91. return decoded
  92. def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
  93. """
  94. Exchange a token for an existing web user session.
  95. """
  96. user_id = enterprise_user_decoded.get("user_id")
  97. end_user_id = enterprise_user_decoded.get("end_user_id")
  98. session_id = enterprise_user_decoded.get("session_id")
  99. user_auth_type = enterprise_user_decoded.get("auth_type")
  100. if not user_auth_type:
  101. raise Unauthorized("Missing auth_type in the token.")
  102. site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
  103. if not site:
  104. raise NotFound()
  105. app_model = db.session.query(App).filter(App.id == site.app_id).first()
  106. if not app_model or app_model.status != "normal" or not app_model.enable_site:
  107. raise NotFound()
  108. app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
  109. if app_auth_type == WebAppAuthType.PUBLIC:
  110. return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
  111. elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
  112. raise WebAppAuthRequiredError("Please login as external user.")
  113. elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
  114. raise WebAppAuthRequiredError("Please login as internal user.")
  115. end_user = None
  116. if end_user_id:
  117. end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
  118. if session_id:
  119. end_user = (
  120. db.session.query(EndUser)
  121. .filter(
  122. EndUser.session_id == session_id,
  123. EndUser.tenant_id == app_model.tenant_id,
  124. EndUser.app_id == app_model.id,
  125. )
  126. .first()
  127. )
  128. if not end_user:
  129. if not session_id:
  130. raise NotFound("Missing session_id for existing web user.")
  131. end_user = EndUser(
  132. tenant_id=app_model.tenant_id,
  133. app_id=app_model.id,
  134. type="browser",
  135. is_anonymous=True,
  136. session_id=session_id,
  137. )
  138. db.session.add(end_user)
  139. db.session.commit()
  140. exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
  141. exp = int(exp_dt.timestamp())
  142. payload = {
  143. "iss": site.id,
  144. "sub": "Web API Passport",
  145. "app_id": site.app_id,
  146. "app_code": site.code,
  147. "user_id": user_id,
  148. "end_user_id": end_user.id,
  149. "auth_type": user_auth_type,
  150. "granted_at": int(datetime.now(UTC).timestamp()),
  151. "token_source": "webapp",
  152. "exp": exp,
  153. }
  154. token: str = PassportService().issue(payload)
  155. return {
  156. "access_token": token,
  157. }
  158. def _exchange_for_public_app_token(app_model, site, token_decoded):
  159. user_id = token_decoded.get("user_id")
  160. end_user = None
  161. if user_id:
  162. end_user = (
  163. db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
  164. )
  165. if not end_user:
  166. end_user = EndUser(
  167. tenant_id=app_model.tenant_id,
  168. app_id=app_model.id,
  169. type="browser",
  170. is_anonymous=True,
  171. session_id=generate_session_id(),
  172. )
  173. db.session.add(end_user)
  174. db.session.commit()
  175. payload = {
  176. "iss": site.app_id,
  177. "sub": "Web API Passport",
  178. "app_id": site.app_id,
  179. "app_code": site.code,
  180. "end_user_id": end_user.id,
  181. }
  182. tk = PassportService().issue(payload)
  183. return {
  184. "access_token": tk,
  185. }
  186. def generate_session_id():
  187. """
  188. Generate a unique session ID.
  189. """
  190. while True:
  191. session_id = str(uuid.uuid4())
  192. existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
  193. if existing_count == 0:
  194. return session_id