| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- import enum
- import secrets
- from datetime import UTC, datetime, timedelta
- from typing import Any, Optional, cast
-
- from werkzeug.exceptions import NotFound, Unauthorized
-
- from configs import dify_config
- from extensions.ext_database import db
- from libs.helper import TokenManager
- from libs.passport import PassportService
- from libs.password import compare_password
- from models.account import Account, AccountStatus
- from models.model import App, EndUser, Site
- from services.app_service import AppService
- from services.enterprise.enterprise_service import EnterpriseService
- from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
- from tasks.mail_email_code_login import send_email_code_login_mail_task
-
-
- class WebAppAuthType(enum.StrEnum):
- """Enum for web app authentication types."""
-
- PUBLIC = "public"
- INTERNAL = "internal"
- EXTERNAL = "external"
-
-
- class WebAppAuthService:
- """Service for web app authentication."""
-
- @staticmethod
- def authenticate(email: str, password: str) -> Account:
- """authenticate account with email and password"""
- account = db.session.query(Account).filter_by(email=email).first()
- if not account:
- raise AccountNotFoundError()
-
- if account.status == AccountStatus.BANNED.value:
- raise AccountLoginError("Account is banned.")
-
- if account.password is None or not compare_password(password, account.password, account.password_salt):
- raise AccountPasswordError("Invalid email or password.")
-
- return cast(Account, account)
-
- @classmethod
- def login(cls, account: Account) -> str:
- access_token = cls._get_account_jwt_token(account=account)
-
- return access_token
-
- @classmethod
- def get_user_through_email(cls, email: str):
- account = db.session.query(Account).filter(Account.email == email).first()
- if not account:
- return None
-
- if account.status == AccountStatus.BANNED.value:
- raise Unauthorized("Account is banned.")
-
- return account
-
- @classmethod
- def send_email_code_login_email(
- cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
- ):
- email = account.email if account else email
- if email is None:
- raise ValueError("Email must be provided.")
-
- code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
- token = TokenManager.generate_token(
- account=account, email=email, token_type="email_code_login", additional_data={"code": code}
- )
- send_email_code_login_mail_task.delay(
- language=language,
- to=account.email if account else email,
- code=code,
- )
-
- return token
-
- @classmethod
- def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
- return TokenManager.get_token_data(token, "email_code_login")
-
- @classmethod
- def revoke_email_code_login_token(cls, token: str):
- TokenManager.revoke_token(token, "email_code_login")
-
- @classmethod
- def create_end_user(cls, app_code, email) -> EndUser:
- site = db.session.query(Site).filter(Site.code == app_code).first()
- if not site:
- raise NotFound("Site not found.")
- app_model = db.session.query(App).filter(App.id == site.app_id).first()
- if not app_model:
- raise NotFound("App not found.")
- end_user = EndUser(
- tenant_id=app_model.tenant_id,
- app_id=app_model.id,
- type="browser",
- is_anonymous=False,
- session_id=email,
- name="enterpriseuser",
- external_user_id="enterpriseuser",
- )
- db.session.add(end_user)
- db.session.commit()
-
- return end_user
-
- @classmethod
- def _get_account_jwt_token(cls, account: Account) -> str:
- exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
- exp = int(exp_dt.timestamp())
-
- payload = {
- "sub": "Web API Passport",
- "user_id": account.id,
- "session_id": account.email,
- "token_source": "webapp_login_token",
- "auth_type": "internal",
- "exp": exp,
- }
-
- token: str = PassportService().issue(payload)
- return token
-
- @classmethod
- def is_app_require_permission_check(
- cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
- ) -> bool:
- """
- Check if the app requires permission check based on its access mode.
- """
- modes_requiring_permission_check = [
- "private",
- "private_all",
- ]
- if access_mode:
- return access_mode in modes_requiring_permission_check
-
- if not app_code and not app_id:
- raise ValueError("Either app_code or app_id must be provided.")
-
- if app_code:
- app_id = AppService.get_app_id_by_code(app_code)
- if not app_id:
- raise ValueError("App ID could not be determined from the provided app_code.")
-
- webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
- if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
- return True
- return False
-
- @classmethod
- def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType:
- """
- Get the authentication type for the app based on its access mode.
- """
- if not app_code and not access_mode:
- raise ValueError("Either app_code or access_mode must be provided.")
-
- if access_mode:
- if access_mode == "public":
- return WebAppAuthType.PUBLIC
- elif access_mode in ["private", "private_all"]:
- return WebAppAuthType.INTERNAL
- elif access_mode == "sso_verified":
- return WebAppAuthType.EXTERNAL
-
- if app_code:
- webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
- return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)
-
- raise ValueError("Could not determine app authentication type.")
|