Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

webapp_auth_service.py 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import random
  2. from datetime import UTC, datetime, timedelta
  3. from typing import Any, Optional, cast
  4. from werkzeug.exceptions import NotFound, Unauthorized
  5. from configs import dify_config
  6. from controllers.web.error import WebAppAuthAccessDeniedError
  7. from extensions.ext_database import db
  8. from libs.helper import TokenManager
  9. from libs.passport import PassportService
  10. from libs.password import compare_password
  11. from models.account import Account, AccountStatus
  12. from models.model import App, EndUser, Site
  13. from services.enterprise.enterprise_service import EnterpriseService
  14. from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
  15. from services.feature_service import FeatureService
  16. from tasks.mail_email_code_login import send_email_code_login_mail_task
  17. class WebAppAuthService:
  18. """Service for web app authentication."""
  19. @staticmethod
  20. def authenticate(email: str, password: str) -> Account:
  21. """authenticate account with email and password"""
  22. account = Account.query.filter_by(email=email).first()
  23. if not account:
  24. raise AccountNotFoundError()
  25. if account.status == AccountStatus.BANNED.value:
  26. raise AccountLoginError("Account is banned.")
  27. if account.password is None or not compare_password(password, account.password, account.password_salt):
  28. raise AccountPasswordError("Invalid email or password.")
  29. return cast(Account, account)
  30. @classmethod
  31. def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
  32. site = db.session.query(Site).filter(Site.code == app_code).first()
  33. if not site:
  34. raise NotFound("Site not found.")
  35. access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
  36. return access_token
  37. @classmethod
  38. def get_user_through_email(cls, email: str):
  39. account = db.session.query(Account).filter(Account.email == email).first()
  40. if not account:
  41. return None
  42. if account.status == AccountStatus.BANNED.value:
  43. raise Unauthorized("Account is banned.")
  44. return account
  45. @classmethod
  46. def send_email_code_login_email(
  47. cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
  48. ):
  49. email = account.email if account else email
  50. if email is None:
  51. raise ValueError("Email must be provided.")
  52. code = "".join([str(random.randint(0, 9)) for _ in range(6)])
  53. token = TokenManager.generate_token(
  54. account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
  55. )
  56. send_email_code_login_mail_task.delay(
  57. language=language,
  58. to=account.email if account else email,
  59. code=code,
  60. )
  61. return token
  62. @classmethod
  63. def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
  64. return TokenManager.get_token_data(token, "webapp_email_code_login")
  65. @classmethod
  66. def revoke_email_code_login_token(cls, token: str):
  67. TokenManager.revoke_token(token, "webapp_email_code_login")
  68. @classmethod
  69. def create_end_user(cls, app_code, email) -> EndUser:
  70. site = db.session.query(Site).filter(Site.code == app_code).first()
  71. if not site:
  72. raise NotFound("Site not found.")
  73. app_model = db.session.query(App).filter(App.id == site.app_id).first()
  74. if not app_model:
  75. raise NotFound("App not found.")
  76. end_user = EndUser(
  77. tenant_id=app_model.tenant_id,
  78. app_id=app_model.id,
  79. type="browser",
  80. is_anonymous=False,
  81. session_id=email,
  82. name="enterpriseuser",
  83. external_user_id="enterpriseuser",
  84. )
  85. db.session.add(end_user)
  86. db.session.commit()
  87. return end_user
  88. @classmethod
  89. def _validate_user_accessibility(cls, account: Account, app_code: str):
  90. """Check if the user is allowed to access the app."""
  91. system_features = FeatureService.get_system_features()
  92. if system_features.webapp_auth.enabled:
  93. app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
  94. if (
  95. app_settings.access_mode != "public"
  96. and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
  97. ):
  98. raise WebAppAuthAccessDeniedError()
  99. @classmethod
  100. def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
  101. exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
  102. exp = int(exp_dt.timestamp())
  103. payload = {
  104. "iss": site.id,
  105. "sub": "Web API Passport",
  106. "app_id": site.app_id,
  107. "app_code": site.code,
  108. "user_id": account.id,
  109. "end_user_id": end_user_id,
  110. "token_source": "webapp",
  111. "exp": exp,
  112. }
  113. token: str = PassportService().issue(payload)
  114. return token