You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

wraps.py 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from collections.abc import Callable
  2. from functools import wraps
  3. from typing import Optional, ParamSpec, TypeVar
  4. from flask import current_app, request
  5. from flask_login import user_logged_in
  6. from flask_restx import reqparse
  7. from pydantic import BaseModel
  8. from sqlalchemy.orm import Session
  9. from core.file.constants import DEFAULT_SERVICE_API_USER_ID
  10. from extensions.ext_database import db
  11. from libs.login import _get_user
  12. from models.account import Tenant
  13. from models.model import EndUser
  14. P = ParamSpec("P")
  15. R = TypeVar("R")
  16. def get_user(tenant_id: str, user_id: str | None) -> EndUser:
  17. """
  18. Get current user
  19. NOTE: user_id is not trusted, it could be maliciously set to any value.
  20. As a result, it could only be considered as an end user id.
  21. """
  22. try:
  23. with Session(db.engine) as session:
  24. if not user_id:
  25. user_id = DEFAULT_SERVICE_API_USER_ID
  26. user_model = (
  27. session.query(EndUser)
  28. .where(
  29. EndUser.session_id == user_id,
  30. EndUser.tenant_id == tenant_id,
  31. )
  32. .first()
  33. )
  34. if not user_model:
  35. user_model = EndUser(
  36. tenant_id=tenant_id,
  37. type="service_api",
  38. is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
  39. session_id=user_id,
  40. )
  41. session.add(user_model)
  42. session.commit()
  43. session.refresh(user_model)
  44. except Exception:
  45. raise ValueError("user not found")
  46. return user_model
  47. def get_user_tenant(view: Optional[Callable[P, R]] = None):
  48. def decorator(view_func: Callable[P, R]):
  49. @wraps(view_func)
  50. def decorated_view(*args: P.args, **kwargs: P.kwargs):
  51. # fetch json body
  52. parser = reqparse.RequestParser()
  53. parser.add_argument("tenant_id", type=str, required=True, location="json")
  54. parser.add_argument("user_id", type=str, required=True, location="json")
  55. p = parser.parse_args()
  56. user_id: Optional[str] = p.get("user_id")
  57. tenant_id: str = p.get("tenant_id")
  58. if not tenant_id:
  59. raise ValueError("tenant_id is required")
  60. if not user_id:
  61. user_id = DEFAULT_SERVICE_API_USER_ID
  62. del kwargs["tenant_id"]
  63. del kwargs["user_id"]
  64. try:
  65. tenant_model = (
  66. db.session.query(Tenant)
  67. .where(
  68. Tenant.id == tenant_id,
  69. )
  70. .first()
  71. )
  72. except Exception:
  73. raise ValueError("tenant not found")
  74. if not tenant_model:
  75. raise ValueError("tenant not found")
  76. kwargs["tenant_model"] = tenant_model
  77. user = get_user(tenant_id, user_id)
  78. kwargs["user_model"] = user
  79. current_app.login_manager._update_request_context_with_user(user) # type: ignore
  80. user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
  81. return view_func(*args, **kwargs)
  82. return decorated_view
  83. if view is None:
  84. return decorator
  85. else:
  86. return decorator(view)
  87. def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]):
  88. def decorator(view_func: Callable[P, R]):
  89. def decorated_view(*args: P.args, **kwargs: P.kwargs):
  90. try:
  91. data = request.get_json()
  92. except Exception:
  93. raise ValueError("invalid json")
  94. try:
  95. payload = payload_type(**data)
  96. except Exception as e:
  97. raise ValueError(f"invalid payload: {str(e)}")
  98. kwargs["payload"] = payload
  99. return view_func(*args, **kwargs)
  100. return decorated_view
  101. if view is None:
  102. return decorator
  103. else:
  104. return decorator(view)