You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from collections.abc import Callable
  2. from functools import wraps
  3. from typing import ParamSpec, TypeVar, cast
  4. from flask import current_app, request
  5. from flask_login import user_logged_in
  6. from flask_restx import reqparse
  7. from pydantic import BaseModel
  8. from sqlalchemy.orm import Session
  9. from extensions.ext_database import db
  10. from libs.login import current_user
  11. from models.account import Tenant
  12. from models.model import DefaultEndUserSessionID, EndUser
  13. P = ParamSpec("P")
  14. R = TypeVar("R")
  15. def get_user(tenant_id: str, user_id: str | None) -> EndUser:
  16. """
  17. Get current user
  18. NOTE: user_id is not trusted, it could be maliciously set to any value.
  19. As a result, it could only be considered as an end user id.
  20. """
  21. if not user_id:
  22. user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
  23. is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
  24. try:
  25. with Session(db.engine) as session:
  26. user_model = None
  27. if is_anonymous:
  28. user_model = (
  29. session.query(EndUser)
  30. .where(
  31. EndUser.session_id == user_id,
  32. EndUser.tenant_id == tenant_id,
  33. )
  34. .first()
  35. )
  36. else:
  37. user_model = (
  38. session.query(EndUser)
  39. .where(
  40. EndUser.id == user_id,
  41. EndUser.tenant_id == tenant_id,
  42. )
  43. .first()
  44. )
  45. if not user_model:
  46. user_model = EndUser(
  47. tenant_id=tenant_id,
  48. type="service_api",
  49. is_anonymous=is_anonymous,
  50. session_id=user_id,
  51. )
  52. session.add(user_model)
  53. session.commit()
  54. session.refresh(user_model)
  55. except Exception:
  56. raise ValueError("user not found")
  57. return user_model
  58. def get_user_tenant(view: Callable[P, R] | None = None):
  59. def decorator(view_func: Callable[P, R]):
  60. @wraps(view_func)
  61. def decorated_view(*args: P.args, **kwargs: P.kwargs):
  62. # fetch json body
  63. parser = reqparse.RequestParser()
  64. parser.add_argument("tenant_id", type=str, required=True, location="json")
  65. parser.add_argument("user_id", type=str, required=True, location="json")
  66. p = parser.parse_args()
  67. user_id = cast(str, p.get("user_id"))
  68. tenant_id = cast(str, p.get("tenant_id"))
  69. if not tenant_id:
  70. raise ValueError("tenant_id is required")
  71. if not user_id:
  72. user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
  73. try:
  74. tenant_model = (
  75. db.session.query(Tenant)
  76. .where(
  77. Tenant.id == tenant_id,
  78. )
  79. .first()
  80. )
  81. except Exception:
  82. raise ValueError("tenant not found")
  83. if not tenant_model:
  84. raise ValueError("tenant not found")
  85. kwargs["tenant_model"] = tenant_model
  86. user = get_user(tenant_id, user_id)
  87. kwargs["user_model"] = user
  88. current_app.login_manager._update_request_context_with_user(user) # type: ignore
  89. user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
  90. return view_func(*args, **kwargs)
  91. return decorated_view
  92. if view is None:
  93. return decorator
  94. else:
  95. return decorator(view)
  96. def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
  97. def decorator(view_func: Callable[P, R]):
  98. def decorated_view(*args: P.args, **kwargs: P.kwargs):
  99. try:
  100. data = request.get_json()
  101. except Exception:
  102. raise ValueError("invalid json")
  103. try:
  104. payload = payload_type(**data)
  105. except Exception as e:
  106. raise ValueError(f"invalid payload: {str(e)}")
  107. kwargs["payload"] = payload
  108. return view_func(*args, **kwargs)
  109. return decorated_view
  110. if view is None:
  111. return decorator
  112. else:
  113. return decorator(view)