| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- from collections.abc import Callable
- from functools import wraps
- from typing import Optional, ParamSpec, TypeVar, cast
-
- from flask import current_app, request
- from flask_login import user_logged_in
- from flask_restx import reqparse
- from pydantic import BaseModel
- from sqlalchemy.orm import Session
-
- from core.file.constants import DEFAULT_SERVICE_API_USER_ID
- from extensions.ext_database import db
- from libs.login import current_user
- from models.account import Tenant
- from models.model import EndUser
-
- P = ParamSpec("P")
- R = TypeVar("R")
-
-
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
    """
    Get the end user for a tenant, creating the record if it does not exist yet.

    NOTE: user_id is not trusted, it could be maliciously set to any value.
    As a result, it could only be considered as an end user id.

    :param tenant_id: id of the tenant the end user belongs to
    :param user_id: caller-supplied identifier, matched against ``EndUser.session_id``;
        a falsy value falls back to ``DEFAULT_SERVICE_API_USER_ID``
    :return: the existing or newly created ``EndUser``
    :raises ValueError: if the lookup or the creation fails for any reason; the
        original exception is attached as ``__cause__``
    """
    # Resolving the default cannot fail, so keep it outside the try body.
    if not user_id:
        user_id = DEFAULT_SERVICE_API_USER_ID

    try:
        with Session(db.engine) as session:
            user_model = (
                session.query(EndUser)
                .where(
                    EndUser.session_id == user_id,
                    EndUser.tenant_id == tenant_id,
                )
                .first()
            )
            if not user_model:
                # First request for this (tenant, session id) pair: create the end user.
                user_model = EndUser(
                    tenant_id=tenant_id,
                    type="service_api",
                    is_anonymous=user_id == DEFAULT_SERVICE_API_USER_ID,
                    session_id=user_id,
                )
                session.add(user_model)
                session.commit()
                # Reload the row after the commit, while the session is still open,
                # because the instance is returned after the session is closed.
                session.refresh(user_model)

    except Exception as e:
        # Same exception type and message as before; chaining keeps the real
        # failure (e.g. a database error) visible in tracebacks.
        raise ValueError("user not found") from e

    return user_model
-
-
def get_user_tenant(view: Optional[Callable[P, R]] = None):
    """
    Decorator that resolves the tenant and end user named in the request JSON body.

    The wrapped view receives two extra keyword arguments:

    - ``tenant_model``: the ``Tenant`` row whose id equals the ``tenant_id`` field
    - ``user_model``: the ``EndUser`` returned by ``get_user`` for the ``user_id`` field

    The resolved end user is also handed to the Flask-Login manager of the current
    app and the ``user_logged_in`` signal is sent before the view runs.

    Can be applied bare (``@get_user_tenant``) or called (``@get_user_tenant()``).

    :param view: the view function when used bare, otherwise ``None``
    :raises ValueError: if ``tenant_id`` is empty or the tenant cannot be loaded
        (raised when the decorated view is called)
    """

    def decorator(view_func: Callable[P, R]):
        @wraps(view_func)
        def decorated_view(*args: P.args, **kwargs: P.kwargs):
            # fetch json body
            parser = reqparse.RequestParser()
            parser.add_argument("tenant_id", type=str, required=True, location="json")
            parser.add_argument("user_id", type=str, required=True, location="json")

            p = parser.parse_args()

            user_id = cast(str, p.get("user_id"))
            tenant_id = cast(str, p.get("tenant_id"))

            # Both fields are required by the parser, but may still be empty strings.
            if not tenant_id:
                raise ValueError("tenant_id is required")

            if not user_id:
                user_id = DEFAULT_SERVICE_API_USER_ID

            try:
                tenant_model = (
                    db.session.query(Tenant)
                    .where(
                        Tenant.id == tenant_id,
                    )
                    .first()
                )
            except Exception as e:
                # Same exception type and message as before; chaining keeps the
                # underlying query failure visible in tracebacks.
                raise ValueError("tenant not found") from e

            if not tenant_model:
                raise ValueError("tenant not found")

            kwargs["tenant_model"] = tenant_model

            user = get_user(tenant_id, user_id)
            kwargs["user_model"] = user

            # Register the end user with Flask-Login for this request, then notify listeners.
            current_app.login_manager._update_request_context_with_user(user)  # type: ignore
            user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)
-
-
- def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]):
- def decorator(view_func: Callable[P, R]):
- def decorated_view(*args: P.args, **kwargs: P.kwargs):
- try:
- data = request.get_json()
- except Exception:
- raise ValueError("invalid json")
-
- try:
- payload = payload_type(**data)
- except Exception as e:
- raise ValueError(f"invalid payload: {str(e)}")
-
- kwargs["payload"] = payload
- return view_func(*args, **kwargs)
-
- return decorated_view
-
- if view is None:
- return decorator
- else:
- return decorator(view)
|