Explorar el Código

add typing to all wraps (#25405)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.9.0
Asuka Minato hace 1 mes
padre
commit
38057b1b0e
No account linked to committer's email address

+ 7
- 4
api/controllers/console/app/wraps.py Ver fichero

from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Optional, Union
from typing import Optional, ParamSpec, TypeVar, Union


from controllers.console.app.error import AppNotFoundError from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from models import App, AppMode from models import App, AppMode
from models.account import Account from models.account import Account


P = ParamSpec("P")
R = TypeVar("R")



def _load_app_model(app_id: str) -> Optional[App]: def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
return app_model return app_model




def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"): if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters") raise ValueError("missing app_id in path parameters")



+ 13
- 10
api/controllers/inner_api/plugin/wraps.py Ver fichero

from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Optional
from typing import Optional, ParamSpec, TypeVar


from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from models.account import Tenant from models.account import Tenant
from models.model import EndUser from models.model import EndUser


P = ParamSpec("P")
R = TypeVar("R")



def get_user(tenant_id: str, user_id: str | None) -> EndUser: def get_user(tenant_id: str, user_id: str | None) -> EndUser:
""" """
return user_model return user_model




def get_user_tenant(view: Optional[Callable] = None):
def decorator(view_func):
def get_user_tenant(view: Optional[Callable[P, R]] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body # fetch json body
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("tenant_id", type=str, required=True, location="json")
parser.add_argument("user_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json")


kwargs = parser.parse_args()
p = parser.parse_args()


user_id = kwargs.get("user_id")
tenant_id = kwargs.get("tenant_id")
user_id: Optional[str] = p.get("user_id")
tenant_id: str = p.get("tenant_id")


if not tenant_id: if not tenant_id:
raise ValueError("tenant_id is required") raise ValueError("tenant_id is required")
return decorator(view) return decorator(view)




def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
def decorator(view_func):
def decorated_view(*args, **kwargs):
def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try: try:
data = request.get_json() data = request.get_json()
except Exception: except Exception:

+ 2
- 2
api/controllers/inner_api/wraps.py Ver fichero

return decorated return decorated




def enterprise_inner_api_user_auth(view):
def enterprise_inner_api_user_auth(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API: if not dify_config.INNER_API:
return view(*args, **kwargs) return view(*args, **kwargs)



+ 1
- 1
api/controllers/service_api/workspace/models.py Ver fichero

} }
) )
@validate_dataset_token @validate_dataset_token
def get(self, _, model_type):
def get(self, _, model_type: str):
"""Get available models by model type. """Get available models by model type.


Returns a list of available models for the specified model type. Returns a list of available models for the specified model type.

+ 8
- 7
api/controllers/service_api/wraps.py Ver fichero

from datetime import timedelta from datetime import timedelta
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Optional, ParamSpec, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeVar


from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in


P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T")




class WhereisUserArg(StrEnum): class WhereisUserArg(StrEnum):
required: bool = False required: bool = False




def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("app") api_token = validate_and_get_api_token("app")


app_model = db.session.query(App).where(App.id == api_token.app_id).first() app_model = db.session.query(App).where(App.id == api_token.app_id).first()
return interceptor return interceptor




def validate_dataset_token(view=None):
def decorator(view):
def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset") api_token = validate_and_get_api_token("dataset")
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)

+ 5
- 5
api/controllers/web/wraps.py Ver fichero

from collections.abc import Callable
from datetime import UTC, datetime from datetime import UTC, datetime
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeVar


from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
R = TypeVar("R") R = TypeVar("R")




def validate_jwt_token(view=None):
def decorator(view):
def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None):
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
app_model, end_user = decode_jwt_token() app_model, end_user = decode_jwt_token()

return view(app_model, end_user, *args, **kwargs) return view(app_model, end_user, *args, **kwargs)


return decorated return decorated

+ 15
- 12
api/core/rag/datasource/vdb/matrixone/matrixone_vector.py Ver fichero

import json import json
import logging import logging
import uuid import uuid
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any, Optional
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar


from mo_vector.client import MoVectorClient # type: ignore from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from models.dataset import Dataset from models.dataset import Dataset


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar


P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
return values return values




def ensure_client(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)

return wrapper


class MatrixoneVector(BaseVector): class MatrixoneVector(BaseVector):
""" """
Matrixone vector storage implementation. Matrixone vector storage implementation.
self.client.delete() self.client.delete()




T = TypeVar("T", bound=MatrixoneVector)


def ensure_client(func: Callable[Concatenate[T, P], R]):
@wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)

return wrapper


class MatrixoneVectorFactory(AbstractVectorFactory): class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict: if dataset.index_struct_dict:

+ 10
- 5
api/services/enterprise/plugin_manager_service.py Ver fichero

from services.enterprise.base import EnterprisePluginManagerRequest from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError from services.errors.base import BaseServiceError


logger = logging.getLogger(__name__)


class PluginCredentialType(enum.Enum):
MODEL = 0
TOOL = 1

class PluginCredentialType(enum.IntEnum):
MODEL = enum.auto()
TOOL = enum.auto()


def to_number(self): def to_number(self):
return self.value return self.value
if not ret.get("result", False): if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")


logging.debug(
f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}"
logger.debug(
"Credential policy compliance checked for %s with credential %s, result: %s",
body.provider,
body.dify_credential_id,
ret.get("result", False),
) )

Cargando…
Cancelar
Guardar