Selaa lähdekoodia

add typing to all wraps (#25405)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.9.0
Asuka Minato 1 kuukausi sitten
vanhempi
commit
38057b1b0e
No account linked to committer's email address

+ 7
- 4
api/controllers/console/app/wraps.py Näytä tiedosto

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional, Union
from typing import Optional, ParamSpec, TypeVar, Union

from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
@@ -8,6 +8,9 @@ from libs.login import current_user
from models import App, AppMode
from models.account import Account

P = ParamSpec("P")
R = TypeVar("R")


def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account)
@@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]:
return app_model


def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")


+ 13
- 10
api/controllers/inner_api/plugin/wraps.py Näytä tiedosto

@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from typing import Optional, ParamSpec, TypeVar

from flask import current_app, request
from flask_login import user_logged_in
@@ -14,6 +14,9 @@ from libs.login import _get_user
from models.account import Tenant
from models.model import EndUser

P = ParamSpec("P")
R = TypeVar("R")


def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
@@ -52,19 +55,19 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model


def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func):
def get_user_tenant(view: Optional[Callable[P, R]] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body
parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json")

kwargs = parser.parse_args()
p = parser.parse_args()

user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
user_id: Optional[str] = p.get("user_id")
tenant_id: str = p.get("tenant_id")

if not tenant_id:
raise ValueError("tenant_id is required")
@@ -107,9 +110,9 @@ def get_user_tenant(view: Optional[Callable] = None):
return decorator(view)


def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func):
def decorated_view(*args, **kwargs):
def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()
except Exception:

+ 2
- 2
api/controllers/inner_api/wraps.py Näytä tiedosto

@@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
return decorated


def enterprise_inner_api_user_auth(view):
def enterprise_inner_api_user_auth(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
return view(*args, **kwargs)


+ 1
- 1
api/controllers/service_api/workspace/models.py Näytä tiedosto

@@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource):
}
)
@validate_dataset_token
def get(self, _, model_type):
def get(self, _, model_type: str):
"""Get available models by model type.

Returns a list of available models for the specified model type.

+ 8
- 7
api/controllers/service_api/wraps.py Näytä tiedosto

@@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Optional, ParamSpec, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeVar

from flask import current_app, request
from flask_login import user_logged_in
@@ -25,6 +25,7 @@ from services.feature_service import FeatureService

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


class WhereisUserArg(StrEnum):
@@ -42,10 +43,10 @@ class FetchUserArg(BaseModel):
required: bool = False


def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("app")

app_model = db.session.query(App).where(App.id == api_token.app_id).first()
@@ -189,10 +190,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor


def validate_dataset_token(view=None):
def decorator(view):
def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)

+ 5
- 5
api/controllers/web/wraps.py Näytä tiedosto

@@ -1,6 +1,7 @@
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import ParamSpec, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeVar

from flask import request
from flask_restx import Resource
@@ -20,12 +21,11 @@ P = ParamSpec("P")
R = TypeVar("R")


def validate_jwt_token(view=None):
def decorator(view):
def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None):
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
app_model, end_user = decode_jwt_token()

return view(app_model, end_user, *args, **kwargs)

return decorated

+ 15
- 12
api/core/rag/datasource/vdb/matrixone/matrixone_vector.py Näytä tiedosto

@@ -1,8 +1,9 @@
import json
import logging
import uuid
from collections.abc import Callable
from functools import wraps
from typing import Any, Optional
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar

from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
@@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")
@@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel):
return values


def ensure_client(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)

return wrapper


class MatrixoneVector(BaseVector):
"""
Matrixone vector storage implementation.
@@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector):
self.client.delete()


T = TypeVar("T", bound=MatrixoneVector)


def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)

return wrapper


class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:

+ 10
- 5
api/services/enterprise/plugin_manager_service.py Näytä tiedosto

@@ -6,10 +6,12 @@ from pydantic import BaseModel
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError

logger = logging.getLogger(__name__)

class PluginCredentialType(enum.Enum):
MODEL = 0
TOOL = 1

class PluginCredentialType(enum.IntEnum):
MODEL = enum.auto()
TOOL = enum.auto()

def to_number(self):
return self.value
@@ -47,6 +49,9 @@ class PluginManagerService:
if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")

logging.debug(
f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}"
logger.debug(
"Credential policy compliance checked for %s with credential %s, result: %s",
body.provider,
body.dify_credential_id,
ret.get("result", False),
)

Loading…
Peruuta
Tallenna