Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

wraps.py 3.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from collections.abc import Callable
  2. from functools import wraps
  3. from typing import Optional
  4. from flask import current_app, request
  5. from flask_login import user_logged_in
  6. from flask_restful import reqparse
  7. from pydantic import BaseModel
  8. from sqlalchemy.orm import Session
  9. from extensions.ext_database import db
  10. from libs.login import _get_user
  11. from models.account import Account, Tenant
  12. from models.model import EndUser
  13. from services.account_service import AccountService
  14. def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
  15. try:
  16. with Session(db.engine) as session:
  17. if not user_id:
  18. user_id = "DEFAULT-USER"
  19. if user_id == "DEFAULT-USER":
  20. user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
  21. if not user_model:
  22. user_model = EndUser(
  23. tenant_id=tenant_id,
  24. type="service_api",
  25. is_anonymous=True if user_id == "DEFAULT-USER" else False,
  26. session_id=user_id,
  27. )
  28. session.add(user_model)
  29. session.commit()
  30. session.refresh(user_model)
  31. else:
  32. user_model = AccountService.load_user(user_id)
  33. if not user_model:
  34. user_model = session.query(EndUser).where(EndUser.id == user_id).first()
  35. if not user_model:
  36. raise ValueError("user not found")
  37. except Exception:
  38. raise ValueError("user not found")
  39. return user_model
  40. def get_user_tenant(view: Optional[Callable] = None):
  41. def decorator(view_func):
  42. @wraps(view_func)
  43. def decorated_view(*args, **kwargs):
  44. # fetch json body
  45. parser = reqparse.RequestParser()
  46. parser.add_argument("tenant_id", type=str, required=True, location="json")
  47. parser.add_argument("user_id", type=str, required=True, location="json")
  48. kwargs = parser.parse_args()
  49. user_id = kwargs.get("user_id")
  50. tenant_id = kwargs.get("tenant_id")
  51. if not tenant_id:
  52. raise ValueError("tenant_id is required")
  53. if not user_id:
  54. user_id = "DEFAULT-USER"
  55. del kwargs["tenant_id"]
  56. del kwargs["user_id"]
  57. try:
  58. tenant_model = (
  59. db.session.query(Tenant)
  60. .where(
  61. Tenant.id == tenant_id,
  62. )
  63. .first()
  64. )
  65. except Exception:
  66. raise ValueError("tenant not found")
  67. if not tenant_model:
  68. raise ValueError("tenant not found")
  69. kwargs["tenant_model"] = tenant_model
  70. user = get_user(tenant_id, user_id)
  71. kwargs["user_model"] = user
  72. current_app.login_manager._update_request_context_with_user(user) # type: ignore
  73. user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
  74. return view_func(*args, **kwargs)
  75. return decorated_view
  76. if view is None:
  77. return decorator
  78. else:
  79. return decorator(view)
  80. def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
  81. def decorator(view_func):
  82. def decorated_view(*args, **kwargs):
  83. try:
  84. data = request.get_json()
  85. except Exception:
  86. raise ValueError("invalid json")
  87. try:
  88. payload = payload_type(**data)
  89. except Exception as e:
  90. raise ValueError(f"invalid payload: {str(e)}")
  91. kwargs["payload"] = payload
  92. return view_func(*args, **kwargs)
  93. return decorated_view
  94. if view is None:
  95. return decorator
  96. else:
  97. return decorator(view)