Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/2.0.0-beta.2^2
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import request | |||
| from flask_restx import Resource, reqparse | |||
| @@ -6,6 +8,8 @@ from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| from configs import dify_config | |||
| from constants.languages import supported_language | |||
| from controllers.console import api | |||
| @@ -14,9 +18,9 @@ from extensions.ext_database import db | |||
| from models.model import App, InstalledApp, RecommendedApp | |||
| def admin_required(view): | |||
| def admin_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if not dify_config.ADMIN_API_KEY: | |||
| raise Unauthorized("API key is invalid.") | |||
| @@ -1,5 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import cast | |||
| from typing import Concatenate, ParamSpec, TypeVar, cast | |||
| import flask_login | |||
| from flask import jsonify, request | |||
| @@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, | |||
| from .. import api | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| T = TypeVar("T") | |||
| def oauth_server_client_id_required(view): | |||
| def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(self: T, *args: P.args, **kwargs: P.kwargs): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("client_id", type=str, required=True, location="json") | |||
| parsed_args = parser.parse_args() | |||
| @@ -30,18 +35,15 @@ def oauth_server_client_id_required(view): | |||
| if not oauth_provider_app: | |||
| raise NotFound("client_id is invalid") | |||
| kwargs["oauth_provider_app"] = oauth_provider_app | |||
| return view(*args, **kwargs) | |||
| return view(self, oauth_provider_app, *args, **kwargs) | |||
| return decorated | |||
| def oauth_server_access_token_required(view): | |||
| def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| oauth_provider_app = kwargs.get("oauth_provider_app") | |||
| if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): | |||
| def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): | |||
| if not isinstance(oauth_provider_app, OAuthProviderApp): | |||
| raise BadRequest("Invalid oauth_provider_app") | |||
| authorization_header = request.headers.get("Authorization") | |||
| @@ -79,9 +81,7 @@ def oauth_server_access_token_required(view): | |||
| response.headers["WWW-Authenticate"] = "Bearer" | |||
| return response | |||
| kwargs["account"] = account | |||
| return view(*args, **kwargs) | |||
| return view(self, oauth_provider_app, account, *args, **kwargs) | |||
| return decorated | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Concatenate, Optional, ParamSpec, TypeVar | |||
| from flask_login import current_user | |||
| from flask_restx import Resource | |||
| @@ -13,19 +15,15 @@ from services.app_service import AppService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.feature_service import FeatureService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| T = TypeVar("T") | |||
| def installed_app_required(view=None): | |||
| def decorator(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not kwargs.get("installed_app_id"): | |||
| raise ValueError("missing installed_app_id in path parameters") | |||
| installed_app_id = kwargs.get("installed_app_id") | |||
| installed_app_id = str(installed_app_id) | |||
| del kwargs["installed_app_id"] | |||
| def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): | |||
| installed_app = ( | |||
| db.session.query(InstalledApp) | |||
| .where( | |||
| @@ -52,10 +50,10 @@ def installed_app_required(view=None): | |||
| return decorator | |||
| def user_allowed_to_access_app(view=None): | |||
| def decorator(view): | |||
| def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(installed_app: InstalledApp, *args, **kwargs): | |||
| def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): | |||
| feature = FeatureService.get_system_features() | |||
| if feature.webapp_auth.enabled: | |||
| app_id = installed_app.app_id | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask_login import current_user | |||
| from sqlalchemy.orm import Session | |||
| @@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden | |||
| from extensions.ext_database import db | |||
| from models.account import TenantPluginPermission | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def plugin_permission_required( | |||
| install_required: bool = False, | |||
| debug_required: bool = False, | |||
| ): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| @@ -2,7 +2,9 @@ import contextlib | |||
| import json | |||
| import os | |||
| import time | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import abort, request | |||
| from flask_login import current_user | |||
| @@ -19,10 +21,13 @@ from services.operation_service import OperationService | |||
| from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def account_initialization_required(view): | |||
| def account_initialization_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| # check account initialization | |||
| account = current_user | |||
| @@ -34,9 +39,9 @@ def account_initialization_required(view): | |||
| return decorated | |||
| def only_edition_cloud(view): | |||
| def only_edition_cloud(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if dify_config.EDITION != "CLOUD": | |||
| abort(404) | |||
| @@ -45,9 +50,9 @@ def only_edition_cloud(view): | |||
| return decorated | |||
| def only_edition_enterprise(view): | |||
| def only_edition_enterprise(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if not dify_config.ENTERPRISE_ENABLED: | |||
| abort(404) | |||
| @@ -56,9 +61,9 @@ def only_edition_enterprise(view): | |||
| return decorated | |||
| def only_edition_self_hosted(view): | |||
| def only_edition_self_hosted(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if dify_config.EDITION != "SELF_HOSTED": | |||
| abort(404) | |||
| @@ -67,9 +72,9 @@ def only_edition_self_hosted(view): | |||
| return decorated | |||
| def cloud_edition_billing_enabled(view): | |||
| def cloud_edition_billing_enabled(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if not features.billing.enabled: | |||
| abort(403, "Billing feature is not enabled.") | |||
| @@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view): | |||
| def cloud_edition_billing_resource_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| members = features.members | |||
| @@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str): | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| if resource == "add_segment": | |||
| @@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| def cloud_edition_billing_rate_limit_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if resource == "knowledge": | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| @@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str): | |||
| return interceptor | |||
| def cloud_utm_record(view): | |||
| def cloud_utm_record(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| with contextlib.suppress(Exception): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| @@ -194,9 +199,9 @@ def cloud_utm_record(view): | |||
| return decorated | |||
| def setup_required(view): | |||
| def setup_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| # check setup | |||
| if ( | |||
| dify_config.EDITION == "SELF_HOSTED" | |||
| @@ -212,9 +217,9 @@ def setup_required(view): | |||
| return decorated | |||
| def enterprise_license_required(view): | |||
| def enterprise_license_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| settings = FeatureService.get_system_features() | |||
| if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: | |||
| raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") | |||
| @@ -224,9 +229,9 @@ def enterprise_license_required(view): | |||
| return decorated | |||
| def email_password_login_enabled(view): | |||
| def email_password_login_enabled(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_email_password_login: | |||
| return view(*args, **kwargs) | |||
| @@ -237,9 +242,9 @@ def email_password_login_enabled(view): | |||
| return decorated | |||
| def enable_change_email(view): | |||
| def enable_change_email(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_change_email: | |||
| return view(*args, **kwargs) | |||
| @@ -250,9 +255,9 @@ def enable_change_email(view): | |||
| return decorated | |||
| def is_allow_transfer_owner(view): | |||
| def is_allow_transfer_owner(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.is_allow_transfer_workspace: | |||
| return view(*args, **kwargs) | |||
| @@ -3,7 +3,7 @@ from collections.abc import Callable | |||
| from datetime import timedelta | |||
| from enum import StrEnum, auto | |||
| from functools import wraps | |||
| from typing import Optional | |||
| from typing import Optional, ParamSpec, TypeVar | |||
| from flask import current_app, request | |||
| from flask_login import user_logged_in | |||
| @@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog | |||
| from models.model import ApiToken, App, EndUser | |||
| from services.feature_service import FeatureService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| class WhereisUserArg(StrEnum): | |||
| """ | |||
| @@ -118,8 +121,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def decorated(*args, **kwargs): | |||
| def interceptor(view: Callable[P, R]): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| @@ -148,9 +151,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -170,9 +173,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s | |||
| def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| if resource == "knowledge": | |||
| @@ -1,5 +1,6 @@ | |||
| from datetime import UTC, datetime | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import request | |||
| from flask_restx import Resource | |||
| @@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def validate_jwt_token(view=None): | |||
| def decorator(view): | |||
| @@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| logger = logging.getLogger(__name__) | |||
| from typing import ParamSpec, TypeVar | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| class MatrixoneConfig(BaseModel): | |||
| @@ -1,3 +1,4 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Union, cast | |||
| @@ -12,9 +13,13 @@ from models.model import EndUser | |||
| #: A proxy for the current user. If no user is logged in, this will be an | |||
| #: anonymous user | |||
| current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) | |||
| from typing import ParamSpec, TypeVar | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def login_required(func): | |||
| def login_required(func: Callable[P, R]): | |||
| """ | |||
| If you decorate a view with this, it will ensure that the current user is | |||
| logged in and authenticated before calling the actual view. (If they are | |||
| @@ -49,17 +54,12 @@ def login_required(func): | |||
| """ | |||
| @wraps(func) | |||
| def decorated_view(*args, **kwargs): | |||
| def decorated_view(*args: P.args, **kwargs: P.kwargs): | |||
| if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: | |||
| pass | |||
| elif current_user is not None and not current_user.is_authenticated: | |||
| return current_app.login_manager.unauthorized() # type: ignore | |||
| # flask 1.x compatibility | |||
| # current_app.ensure_sync is only available in Flask >= 2.0 | |||
| if callable(getattr(current_app, "ensure_sync", None)): | |||
| return current_app.ensure_sync(func)(*args, **kwargs) | |||
| return func(*args, **kwargs) | |||
| return current_app.ensure_sync(func)(*args, **kwargs) | |||
| return decorated_view | |||