| @@ -198,6 +198,7 @@ sdks/python-client/dify_client.egg-info | |||
| !.vscode/launch.json.template | |||
| !.vscode/README.md | |||
| api/.vscode | |||
| web/.vscode | |||
| # vscode Code History Extension | |||
| .history | |||
| @@ -215,6 +216,13 @@ mise.toml | |||
| # Next.js build output | |||
| .next/ | |||
| # PWA generated files | |||
| web/public/sw.js | |||
| web/public/sw.js.map | |||
| web/public/workbox-*.js | |||
| web/public/workbox-*.js.map | |||
| web/public/fallback-*.js | |||
| # AI Assistant | |||
| .roo/ | |||
| api/.env.backup | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import request | |||
| from flask_restx import Resource, reqparse | |||
| @@ -6,6 +8,8 @@ from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| from configs import dify_config | |||
| from constants.languages import supported_language | |||
| from controllers.console import api | |||
| @@ -14,9 +18,9 @@ from extensions.ext_database import db | |||
| from models.model import App, InstalledApp, RecommendedApp | |||
| def admin_required(view): | |||
| def admin_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if not dify_config.ADMIN_API_KEY: | |||
| raise Unauthorized("API key is invalid.") | |||
| @@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource): | |||
| custom="max_keys_exceeded", | |||
| ) | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| key = ApiToken.generate_api_key(self.token_prefix or "", 24) | |||
| api_token = ApiToken() | |||
| setattr(api_token, self.resource_id_field, resource_id) | |||
| api_token.tenant_id = current_user.current_tenant_id | |||
| @@ -1,5 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import cast | |||
| from typing import Concatenate, ParamSpec, TypeVar, cast | |||
| import flask_login | |||
| from flask import jsonify, request | |||
| @@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, | |||
| from .. import api | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| T = TypeVar("T") | |||
| def oauth_server_client_id_required(view): | |||
| def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(self: T, *args: P.args, **kwargs: P.kwargs): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("client_id", type=str, required=True, location="json") | |||
| parsed_args = parser.parse_args() | |||
| @@ -30,18 +35,15 @@ def oauth_server_client_id_required(view): | |||
| if not oauth_provider_app: | |||
| raise NotFound("client_id is invalid") | |||
| kwargs["oauth_provider_app"] = oauth_provider_app | |||
| return view(*args, **kwargs) | |||
| return view(self, oauth_provider_app, *args, **kwargs) | |||
| return decorated | |||
| def oauth_server_access_token_required(view): | |||
| def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| oauth_provider_app = kwargs.get("oauth_provider_app") | |||
| if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): | |||
| def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): | |||
| if not isinstance(oauth_provider_app, OAuthProviderApp): | |||
| raise BadRequest("Invalid oauth_provider_app") | |||
| authorization_header = request.headers.get("Authorization") | |||
| @@ -79,9 +81,7 @@ def oauth_server_access_token_required(view): | |||
| response.headers["WWW-Authenticate"] = "Bearer" | |||
| return response | |||
| kwargs["account"] = account | |||
| return view(*args, **kwargs) | |||
| return view(self, oauth_provider_app, account, *args, **kwargs) | |||
| return decorated | |||
| @@ -1,9 +1,9 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required | |||
| from libs.login import login_required | |||
| from libs.login import current_user, login_required | |||
| from models.model import Account | |||
| from services.billing_service import BillingService | |||
| @@ -17,9 +17,10 @@ class Subscription(Resource): | |||
| parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) | |||
| parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) | |||
| args = parser.parse_args() | |||
| assert isinstance(current_user, Account) | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| assert current_user.current_tenant_id is not None | |||
| return BillingService.get_subscription( | |||
| args["plan"], args["interval"], current_user.email, current_user.current_tenant_id | |||
| ) | |||
| @@ -31,7 +32,9 @@ class Invoices(Resource): | |||
| @account_initialization_required | |||
| @only_edition_cloud | |||
| def get(self): | |||
| assert isinstance(current_user, Account) | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| assert current_user.current_tenant_id is not None | |||
| return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) | |||
| @@ -475,6 +475,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| data_source_info = document.data_source_info_dict | |||
| if document.data_source_type == "upload_file": | |||
| if not data_source_info: | |||
| continue | |||
| file_id = data_source_info["upload_file_id"] | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| @@ -491,6 +493,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "notion_import": | |||
| if not data_source_info: | |||
| continue | |||
| extract_setting = ExtractSetting( | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| @@ -503,6 +507,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "website_crawl": | |||
| if not data_source_info: | |||
| continue | |||
| extract_setting = ExtractSetting( | |||
| datasource_type=DatasourceType.WEBSITE.value, | |||
| website_info={ | |||
| @@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource): | |||
| def get(self, installed_app: InstalledApp): | |||
| """Get app meta""" | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise ValueError("App not found") | |||
| return AppService().get_app_meta(app_model) | |||
| @@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): | |||
| Run workflow | |||
| """ | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise NotWorkflowAppError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): | |||
| Stop workflow task | |||
| """ | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise NotWorkflowAppError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Concatenate, Optional, ParamSpec, TypeVar | |||
| from flask_login import current_user | |||
| from flask_restx import Resource | |||
| @@ -13,19 +15,15 @@ from services.app_service import AppService | |||
| from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.feature_service import FeatureService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| T = TypeVar("T") | |||
| def installed_app_required(view=None): | |||
| def decorator(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not kwargs.get("installed_app_id"): | |||
| raise ValueError("missing installed_app_id in path parameters") | |||
| installed_app_id = kwargs.get("installed_app_id") | |||
| installed_app_id = str(installed_app_id) | |||
| del kwargs["installed_app_id"] | |||
| def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): | |||
| installed_app = ( | |||
| db.session.query(InstalledApp) | |||
| .where( | |||
| @@ -52,10 +50,10 @@ def installed_app_required(view=None): | |||
| return decorator | |||
| def user_allowed_to_access_app(view=None): | |||
| def decorator(view): | |||
| def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||
| @wraps(view) | |||
| def decorated(installed_app: InstalledApp, *args, **kwargs): | |||
| def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): | |||
| feature = FeatureService.get_system_features() | |||
| if feature.webapp_auth.enabled: | |||
| app_id = installed_app.app_id | |||
| @@ -1,4 +1,6 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask_login import current_user | |||
| from sqlalchemy.orm import Session | |||
| @@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden | |||
| from extensions.ext_database import db | |||
| from models.account import TenantPluginPermission | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def plugin_permission_required( | |||
| install_required: bool = False, | |||
| debug_required: bool = False, | |||
| ): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| user = current_user | |||
| tenant_id = user.current_tenant_id | |||
| @@ -2,7 +2,9 @@ import contextlib | |||
| import json | |||
| import os | |||
| import time | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import abort, request | |||
| from flask_login import current_user | |||
| @@ -19,10 +21,13 @@ from services.operation_service import OperationService | |||
| from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def account_initialization_required(view): | |||
| def account_initialization_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| # check account initialization | |||
| account = current_user | |||
| @@ -34,9 +39,9 @@ def account_initialization_required(view): | |||
| return decorated | |||
| def only_edition_cloud(view): | |||
| def only_edition_cloud(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if dify_config.EDITION != "CLOUD": | |||
| abort(404) | |||
| @@ -45,9 +50,9 @@ def only_edition_cloud(view): | |||
| return decorated | |||
| def only_edition_enterprise(view): | |||
| def only_edition_enterprise(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if not dify_config.ENTERPRISE_ENABLED: | |||
| abort(404) | |||
| @@ -56,9 +61,9 @@ def only_edition_enterprise(view): | |||
| return decorated | |||
| def only_edition_self_hosted(view): | |||
| def only_edition_self_hosted(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if dify_config.EDITION != "SELF_HOSTED": | |||
| abort(404) | |||
| @@ -67,9 +72,9 @@ def only_edition_self_hosted(view): | |||
| return decorated | |||
| def cloud_edition_billing_enabled(view): | |||
| def cloud_edition_billing_enabled(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if not features.billing.enabled: | |||
| abort(403, "Billing feature is not enabled.") | |||
| @@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view): | |||
| def cloud_edition_billing_resource_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| members = features.members | |||
| @@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str): | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| if resource == "add_segment": | |||
| @@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): | |||
| def cloud_edition_billing_rate_limit_check(resource: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| if resource == "knowledge": | |||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) | |||
| if knowledge_rate_limit.enabled: | |||
| @@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str): | |||
| return interceptor | |||
| def cloud_utm_record(view): | |||
| def cloud_utm_record(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| with contextlib.suppress(Exception): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| @@ -194,9 +199,9 @@ def cloud_utm_record(view): | |||
| return decorated | |||
| def setup_required(view): | |||
| def setup_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| # check setup | |||
| if ( | |||
| dify_config.EDITION == "SELF_HOSTED" | |||
| @@ -212,9 +217,9 @@ def setup_required(view): | |||
| return decorated | |||
| def enterprise_license_required(view): | |||
| def enterprise_license_required(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| settings = FeatureService.get_system_features() | |||
| if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: | |||
| raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") | |||
| @@ -224,9 +229,9 @@ def enterprise_license_required(view): | |||
| return decorated | |||
| def email_password_login_enabled(view): | |||
| def email_password_login_enabled(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_email_password_login: | |||
| return view(*args, **kwargs) | |||
| @@ -237,9 +242,9 @@ def email_password_login_enabled(view): | |||
| return decorated | |||
| def enable_change_email(view): | |||
| def enable_change_email(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_system_features() | |||
| if features.enable_change_email: | |||
| return view(*args, **kwargs) | |||
| @@ -250,9 +255,9 @@ def enable_change_email(view): | |||
| return decorated | |||
| def is_allow_transfer_owner(view): | |||
| def is_allow_transfer_owner(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.is_allow_transfer_workspace: | |||
| return view(*args, **kwargs) | |||
| @@ -3,7 +3,7 @@ from collections.abc import Callable | |||
| from datetime import timedelta | |||
| from enum import StrEnum, auto | |||
| from functools import wraps | |||
| from typing import Optional | |||
| from typing import Optional, ParamSpec, TypeVar | |||
| from flask import current_app, request | |||
| from flask_login import user_logged_in | |||
| @@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog | |||
| from models.model import ApiToken, App, EndUser | |||
| from services.feature_service import FeatureService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| class WhereisUserArg(StrEnum): | |||
| """ | |||
| @@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| raise Forbidden("The workspace's status is archived.") | |||
| tenant_account_join = ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .where(Tenant.id == api_token.tenant_id) | |||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .where(TenantAccountJoin.role.in_(["owner"])) | |||
| .where(Tenant.status == TenantStatus.NORMAL) | |||
| .one_or_none() | |||
| ) # TODO: only owner information is required, so only one is returned. | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| account = db.session.query(Account).where(Account.id == ta.account_id).first() | |||
| # Login admin | |||
| if account: | |||
| account.current_tenant = tenant | |||
| current_app.login_manager._update_request_context_with_user(account) # type: ignore | |||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore | |||
| else: | |||
| raise Unauthorized("Tenant owner account does not exist.") | |||
| else: | |||
| raise Unauthorized("Tenant does not exist.") | |||
| kwargs["app_model"] = app_model | |||
| if fetch_user_arg: | |||
| @@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def decorated(*args, **kwargs): | |||
| def interceptor(view: Callable[P, R]): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| @@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s | |||
| def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): | |||
| def interceptor(view): | |||
| def interceptor(view: Callable[P, R]): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| if resource == "knowledge": | |||
| @@ -1,5 +1,6 @@ | |||
| from datetime import UTC, datetime | |||
| from functools import wraps | |||
| from typing import ParamSpec, TypeVar | |||
| from flask import request | |||
| from flask_restx import Resource | |||
| @@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def validate_jwt_token(view=None): | |||
| def decorator(view): | |||
| @@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| raise MessageNotExistsError() | |||
| current_app_model_config = app_model.app_model_config | |||
| if not current_app_model_config: | |||
| raise MoreLikeThisDisabledError() | |||
| more_like_this = current_app_model_config.more_like_this_dict | |||
| if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: | |||
| @@ -124,6 +124,7 @@ class TokenBufferMemory: | |||
| messages = list(reversed(thread_messages)) | |||
| curr_message_tokens = 0 | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| # Process user message with files | |||
| @@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| logger = logging.getLogger(__name__) | |||
| from typing import ParamSpec, TypeVar | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| class MatrixoneConfig(BaseModel): | |||
| @@ -334,7 +334,8 @@ class NotionExtractor(BaseExtractor): | |||
| last_edited_time = self.get_notion_last_edited_time() | |||
| data_source_info = document_model.data_source_info_dict | |||
| data_source_info["last_edited_time"] = last_edited_time | |||
| if data_source_info: | |||
| data_source_info["last_edited_time"] = last_edited_time | |||
| db.session.query(DocumentModel).filter_by(id=document_model.id).update( | |||
| {DocumentModel.data_source_info: json.dumps(data_source_info)} | |||
| @@ -1,5 +1,5 @@ | |||
| import json | |||
| from typing import Any, Optional | |||
| from typing import Any, Optional, Self | |||
| from core.mcp.types import Tool as RemoteMCPTool | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| @@ -48,7 +48,7 @@ class MCPToolProviderController(ToolProviderController): | |||
| return ToolProviderType.MCP | |||
| @classmethod | |||
| def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": | |||
| def from_db(cls, db_provider: MCPToolProvider) -> Self: | |||
| """ | |||
| from db provider | |||
| """ | |||
| @@ -777,7 +777,7 @@ class ToolManager: | |||
| if provider is None: | |||
| raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") | |||
| controller = MCPToolProviderController._from_db(provider) | |||
| controller = MCPToolProviderController.from_db(provider) | |||
| return controller | |||
| @@ -932,7 +932,7 @@ class ToolManager: | |||
| tenant_id: str, | |||
| provider_type: ToolProviderType, | |||
| provider_id: str, | |||
| ) -> Union[str, dict]: | |||
| ) -> Union[str, dict[str, Any]]: | |||
| """ | |||
| get the tool icon | |||
| @@ -3,7 +3,7 @@ from collections.abc import Generator, Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| from core.variables import ArrayVariable, IntegerVariable, NoneVariable | |||
| from core.variables import IntegerVariable, NoneSegment | |||
| from core.variables.segments import ArrayAnySegment, ArraySegment | |||
| from core.workflow.entities import VariablePool | |||
| from core.workflow.enums import ( | |||
| @@ -97,10 +97,10 @@ class IterationNode(Node): | |||
| if not variable: | |||
| raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") | |||
| if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): | |||
| if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): | |||
| raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") | |||
| if isinstance(variable, NoneVariable) or len(variable.value) == 0: | |||
| if isinstance(variable, NoneSegment) or len(variable.value) == 0: | |||
| # Try our best to preserve the type information. | |||
| if isinstance(variable, ArraySegment): | |||
| output = variable.model_copy(update={"value": []}) | |||
| @@ -50,6 +50,7 @@ from .exc import ( | |||
| ) | |||
| from .prompts import ( | |||
| CHAT_EXAMPLE, | |||
| CHAT_GENERATE_JSON_PROMPT, | |||
| CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, | |||
| COMPLETION_GENERATE_JSON_PROMPT, | |||
| FUNCTION_CALLING_EXTRACTOR_EXAMPLE, | |||
| @@ -746,7 +747,7 @@ class ParameterExtractorNode(Node): | |||
| if model_mode == ModelMode.CHAT: | |||
| system_prompt_messages = ChatModelMessage( | |||
| role=PromptMessageRole.SYSTEM, | |||
| text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), | |||
| text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction), | |||
| ) | |||
| user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) | |||
| return [system_prompt_messages, user_prompt_message] | |||
| @@ -1,3 +1,4 @@ | |||
| from collections.abc import Callable | |||
| from functools import wraps | |||
| from typing import Union, cast | |||
| @@ -12,9 +13,13 @@ from models.model import EndUser | |||
| #: A proxy for the current user. If no user is logged in, this will be an | |||
| #: anonymous user | |||
| current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) | |||
| from typing import ParamSpec, TypeVar | |||
| P = ParamSpec("P") | |||
| R = TypeVar("R") | |||
| def login_required(func): | |||
| def login_required(func: Callable[P, R]): | |||
| """ | |||
| If you decorate a view with this, it will ensure that the current user is | |||
| logged in and authenticated before calling the actual view. (If they are | |||
| @@ -49,17 +54,12 @@ def login_required(func): | |||
| """ | |||
| @wraps(func) | |||
| def decorated_view(*args, **kwargs): | |||
| def decorated_view(*args: P.args, **kwargs: P.kwargs): | |||
| if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: | |||
| pass | |||
| elif current_user is not None and not current_user.is_authenticated: | |||
| return current_app.login_manager.unauthorized() # type: ignore | |||
| # flask 1.x compatibility | |||
| # current_app.ensure_sync is only available in Flask >= 2.0 | |||
| if callable(getattr(current_app, "ensure_sync", None)): | |||
| return current_app.ensure_sync(func)(*args, **kwargs) | |||
| return func(*args, **kwargs) | |||
| return current_app.ensure_sync(func)(*args, **kwargs) | |||
| return decorated_view | |||
| @@ -1,10 +1,10 @@ | |||
| import enum | |||
| import json | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from typing import Any, Optional | |||
| import sqlalchemy as sa | |||
| from flask_login import UserMixin | |||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||
| from sqlalchemy import DateTime, String, func, select | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | |||
| @@ -225,11 +225,11 @@ class Tenant(Base): | |||
| ) | |||
| @property | |||
| def custom_config_dict(self): | |||
| def custom_config_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.custom_config) if self.custom_config else {} | |||
| @custom_config_dict.setter | |||
| def custom_config_dict(self, value: dict): | |||
| def custom_config_dict(self, value: dict[str, Any]) -> None: | |||
| self.custom_config = json.dumps(value) | |||
| @@ -286,7 +286,7 @@ class DatasetProcessRule(Base): | |||
| "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | |||
| } | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "dataset_id": self.dataset_id, | |||
| @@ -295,7 +295,7 @@ class DatasetProcessRule(Base): | |||
| } | |||
| @property | |||
| def rules_dict(self): | |||
| def rules_dict(self) -> dict[str, Any] | None: | |||
| try: | |||
| return json.loads(self.rules) if self.rules else None | |||
| except JSONDecodeError: | |||
| @@ -392,10 +392,10 @@ class Document(Base): | |||
| return status | |||
| @property | |||
| def data_source_info_dict(self): | |||
| def data_source_info_dict(self) -> dict[str, Any] | None: | |||
| if self.data_source_info: | |||
| try: | |||
| data_source_info_dict = json.loads(self.data_source_info) | |||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||
| except JSONDecodeError: | |||
| data_source_info_dict = {} | |||
| @@ -403,10 +403,10 @@ class Document(Base): | |||
| return None | |||
| @property | |||
| def data_source_detail_dict(self): | |||
| def data_source_detail_dict(self) -> dict[str, Any]: | |||
| if self.data_source_info: | |||
| if self.data_source_type == "upload_file": | |||
| data_source_info_dict = json.loads(self.data_source_info) | |||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| .where(UploadFile.id == data_source_info_dict["upload_file_id"]) | |||
| @@ -425,7 +425,8 @@ class Document(Base): | |||
| } | |||
| } | |||
| elif self.data_source_type in {"notion_import", "website_crawl"}: | |||
| return json.loads(self.data_source_info) | |||
| result: dict[str, Any] = json.loads(self.data_source_info) | |||
| return result | |||
| return {} | |||
| @property | |||
| @@ -471,7 +472,7 @@ class Document(Base): | |||
| return self.updated_at | |||
| @property | |||
| def doc_metadata_details(self): | |||
| def doc_metadata_details(self) -> list[dict[str, Any]] | None: | |||
| if self.doc_metadata: | |||
| document_metadatas = ( | |||
| db.session.query(DatasetMetadata) | |||
| @@ -481,9 +482,9 @@ class Document(Base): | |||
| ) | |||
| .all() | |||
| ) | |||
| metadata_list = [] | |||
| metadata_list: list[dict[str, Any]] = [] | |||
| for metadata in document_metadatas: | |||
| metadata_dict = { | |||
| metadata_dict: dict[str, Any] = { | |||
| "id": metadata.id, | |||
| "name": metadata.name, | |||
| "type": metadata.type, | |||
| @@ -497,13 +498,13 @@ class Document(Base): | |||
| return None | |||
| @property | |||
| def process_rule_dict(self): | |||
| if self.dataset_process_rule_id: | |||
| def process_rule_dict(self) -> dict[str, Any] | None: | |||
| if self.dataset_process_rule_id and self.dataset_process_rule: | |||
| return self.dataset_process_rule.to_dict() | |||
| return None | |||
| def get_built_in_fields(self): | |||
| built_in_fields = [] | |||
| def get_built_in_fields(self) -> list[dict[str, Any]]: | |||
| built_in_fields: list[dict[str, Any]] = [] | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| @@ -546,7 +547,7 @@ class Document(Base): | |||
| ) | |||
| return built_in_fields | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| @@ -592,13 +593,13 @@ class Document(Base): | |||
| "data_source_info_dict": self.data_source_info_dict, | |||
| "average_segment_length": self.average_segment_length, | |||
| "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | |||
| "dataset": self.dataset.to_dict() if self.dataset else None, | |||
| "dataset": None, # Dataset class doesn't have a to_dict method | |||
| "segment_count": self.segment_count, | |||
| "hit_count": self.hit_count, | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict): | |||
| def from_dict(cls, data: dict[str, Any]): | |||
| return cls( | |||
| id=data.get("id"), | |||
| tenant_id=data.get("tenant_id"), | |||
| @@ -711,46 +712,48 @@ class DocumentSegment(Base): | |||
| ) | |||
| @property | |||
| def child_chunks(self): | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule.mode == "hierarchical": | |||
| rules = Rule(**process_rule.rules_dict) | |||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| else: | |||
| return [] | |||
| else: | |||
| def child_chunks(self) -> list[Any]: | |||
| if not self.document: | |||
| return [] | |||
| def get_child_chunks(self): | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule.mode == "hierarchical": | |||
| rules = Rule(**process_rule.rules_dict) | |||
| if rules.parent_mode: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| else: | |||
| return [] | |||
| else: | |||
| if process_rule and process_rule.mode == "hierarchical": | |||
| rules_dict = process_rule.rules_dict | |||
| if rules_dict: | |||
| rules = Rule(**rules_dict) | |||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| return [] | |||
| def get_child_chunks(self) -> list[Any]: | |||
| if not self.document: | |||
| return [] | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule and process_rule.mode == "hierarchical": | |||
| rules_dict = process_rule.rules_dict | |||
| if rules_dict: | |||
| rules = Rule(**rules_dict) | |||
| if rules.parent_mode: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| return [] | |||
| @property | |||
| def sign_content(self): | |||
| def sign_content(self) -> str: | |||
| return self.get_sign_content() | |||
| def get_sign_content(self): | |||
| signed_urls = [] | |||
| def get_sign_content(self) -> str: | |||
| signed_urls: list[tuple[int, int, str]] = [] | |||
| text = self.content | |||
| # For data before v0.10.0 | |||
| @@ -890,17 +893,22 @@ class DatasetKeywordTable(Base): | |||
| ) | |||
| @property | |||
| def keyword_table_dict(self): | |||
| def keyword_table_dict(self) -> dict[str, set[Any]] | None: | |||
| class SetDecoder(json.JSONDecoder): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(object_hook=self.object_hook, *args, **kwargs) | |||
| def object_hook(self, dct): | |||
| if isinstance(dct, dict): | |||
| for keyword, node_idxs in dct.items(): | |||
| if isinstance(node_idxs, list): | |||
| dct[keyword] = set(node_idxs) | |||
| return dct | |||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||
| def object_hook(dct: Any) -> Any: | |||
| if isinstance(dct, dict): | |||
| result: dict[str, Any] = {} | |||
| items = cast(dict[str, Any], dct).items() | |||
| for keyword, node_idxs in items: | |||
| if isinstance(node_idxs, list): | |||
| result[keyword] = set(cast(list[Any], node_idxs)) | |||
| else: | |||
| result[keyword] = node_idxs | |||
| return result | |||
| return dct | |||
| super().__init__(object_hook=object_hook, *args, **kwargs) | |||
| # get dataset | |||
| dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | |||
| @@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base): | |||
| updated_by = mapped_column(StringUUID, nullable=True) | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| @@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base): | |||
| } | |||
| @property | |||
| def settings_dict(self): | |||
| def settings_dict(self) -> dict[str, Any] | None: | |||
| try: | |||
| return json.loads(self.settings) if self.settings else None | |||
| except JSONDecodeError: | |||
| return None | |||
| @property | |||
| def dataset_bindings(self): | |||
| def dataset_bindings(self) -> list[dict[str, Any]]: | |||
| external_knowledge_bindings = ( | |||
| db.session.query(ExternalKnowledgeBindings) | |||
| .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | |||
| @@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base): | |||
| ) | |||
| dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | |||
| datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | |||
| dataset_bindings = [] | |||
| dataset_bindings: list[dict[str, Any]] = [] | |||
| for dataset in datasets: | |||
| dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | |||
| @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast | |||
| import sqlalchemy as sa | |||
| from flask import request | |||
| from flask_login import UserMixin | |||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column | |||
| @@ -18,7 +18,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType | |||
| from core.file import helpers as file_helpers | |||
| from core.tools.signature import sign_tool_file | |||
| from core.workflow.enums import WorkflowExecutionStatus | |||
| from libs.helper import generate_string | |||
| from libs.helper import generate_string # type: ignore[import-not-found] | |||
| from .account import Account, Tenant | |||
| from .base import Base | |||
| @@ -96,7 +96,7 @@ class App(Base): | |||
| use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | |||
| @property | |||
| def desc_or_prompt(self): | |||
| def desc_or_prompt(self) -> str: | |||
| if self.description: | |||
| return self.description | |||
| else: | |||
| @@ -107,12 +107,12 @@ class App(Base): | |||
| return "" | |||
| @property | |||
| def site(self): | |||
| def site(self) -> Optional["Site"]: | |||
| site = db.session.query(Site).where(Site.app_id == self.id).first() | |||
| return site | |||
| @property | |||
| def app_model_config(self): | |||
| def app_model_config(self) -> Optional["AppModelConfig"]: | |||
| if self.app_model_config_id: | |||
| return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() | |||
| @@ -128,11 +128,11 @@ class App(Base): | |||
| return None | |||
| @property | |||
| def api_base_url(self): | |||
| def api_base_url(self) -> str: | |||
| return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | |||
| @property | |||
| def tenant(self): | |||
| def tenant(self) -> Optional[Tenant]: | |||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | |||
| return tenant | |||
| @@ -160,9 +160,8 @@ class App(Base): | |||
| return str(self.mode) | |||
| @property | |||
| def deleted_tools(self) -> list: | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from core.tools.tool_manager import ToolManager | |||
| def deleted_tools(self) -> list[dict[str, str]]: | |||
| from core.tools.tool_manager import ToolManager, ToolProviderType | |||
| from services.plugin.plugin_service import PluginService | |||
| # get agent mode tools | |||
| @@ -242,7 +241,7 @@ class App(Base): | |||
| provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) | |||
| } | |||
| deleted_tools = [] | |||
| deleted_tools: list[dict[str, str]] = [] | |||
| for tool in tools: | |||
| keys = list(tool.keys()) | |||
| @@ -275,7 +274,7 @@ class App(Base): | |||
| return deleted_tools | |||
| @property | |||
| def tags(self): | |||
| def tags(self) -> list["Tag"]: | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .join(TagBinding, Tag.id == TagBinding.tag_id) | |||
| @@ -291,7 +290,7 @@ class App(Base): | |||
| return tags or [] | |||
| @property | |||
| def author_name(self): | |||
| def author_name(self) -> Optional[str]: | |||
| if self.created_by: | |||
| account = db.session.query(Account).where(Account.id == self.created_by).first() | |||
| if account: | |||
| @@ -334,20 +333,20 @@ class AppModelConfig(Base): | |||
| file_upload = mapped_column(sa.Text) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| app = db.session.query(App).where(App.id == self.app_id).first() | |||
| return app | |||
| @property | |||
| def model_dict(self): | |||
| def model_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.model) if self.model else {} | |||
| @property | |||
| def suggested_questions_list(self): | |||
| def suggested_questions_list(self) -> list[str]: | |||
| return json.loads(self.suggested_questions) if self.suggested_questions else [] | |||
| @property | |||
| def suggested_questions_after_answer_dict(self): | |||
| def suggested_questions_after_answer_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.suggested_questions_after_answer) | |||
| if self.suggested_questions_after_answer | |||
| @@ -355,19 +354,19 @@ class AppModelConfig(Base): | |||
| ) | |||
| @property | |||
| def speech_to_text_dict(self): | |||
| def speech_to_text_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} | |||
| @property | |||
| def text_to_speech_dict(self): | |||
| def text_to_speech_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} | |||
| @property | |||
| def retriever_resource_dict(self): | |||
| def retriever_resource_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} | |||
| @property | |||
| def annotation_reply_dict(self): | |||
| def annotation_reply_dict(self) -> dict[str, Any]: | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() | |||
| ) | |||
| @@ -390,11 +389,11 @@ class AppModelConfig(Base): | |||
| return {"enabled": False} | |||
| @property | |||
| def more_like_this_dict(self): | |||
| def more_like_this_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | |||
| @property | |||
| def sensitive_word_avoidance_dict(self): | |||
| def sensitive_word_avoidance_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.sensitive_word_avoidance) | |||
| if self.sensitive_word_avoidance | |||
| @@ -402,15 +401,15 @@ class AppModelConfig(Base): | |||
| ) | |||
| @property | |||
| def external_data_tools_list(self) -> list[dict]: | |||
| def external_data_tools_list(self) -> list[dict[str, Any]]: | |||
| return json.loads(self.external_data_tools) if self.external_data_tools else [] | |||
| @property | |||
| def user_input_form_list(self): | |||
| def user_input_form_list(self) -> list[dict[str, Any]]: | |||
| return json.loads(self.user_input_form) if self.user_input_form else [] | |||
| @property | |||
| def agent_mode_dict(self): | |||
| def agent_mode_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.agent_mode) | |||
| if self.agent_mode | |||
| @@ -418,17 +417,17 @@ class AppModelConfig(Base): | |||
| ) | |||
| @property | |||
| def chat_prompt_config_dict(self): | |||
| def chat_prompt_config_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | |||
| @property | |||
| def completion_prompt_config_dict(self): | |||
| def completion_prompt_config_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | |||
| @property | |||
| def dataset_configs_dict(self): | |||
| def dataset_configs_dict(self) -> dict[str, Any]: | |||
| if self.dataset_configs: | |||
| dataset_configs: dict = json.loads(self.dataset_configs) | |||
| dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) | |||
| if "retrieval_model" not in dataset_configs: | |||
| return {"retrieval_model": "single"} | |||
| else: | |||
| @@ -438,7 +437,7 @@ class AppModelConfig(Base): | |||
| } | |||
| @property | |||
| def file_upload_dict(self): | |||
| def file_upload_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.file_upload) | |||
| if self.file_upload | |||
| @@ -452,7 +451,7 @@ class AppModelConfig(Base): | |||
| } | |||
| ) | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "opening_statement": self.opening_statement, | |||
| "suggested_questions": self.suggested_questions_list, | |||
| @@ -546,7 +545,7 @@ class RecommendedApp(Base): | |||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| app = db.session.query(App).where(App.id == self.app_id).first() | |||
| return app | |||
| @@ -570,12 +569,12 @@ class InstalledApp(Base): | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| app = db.session.query(App).where(App.id == self.app_id).first() | |||
| return app | |||
| @property | |||
| def tenant(self): | |||
| def tenant(self) -> Optional[Tenant]: | |||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | |||
| return tenant | |||
| @@ -622,7 +621,7 @@ class Conversation(Base): | |||
| mode: Mapped[str] = mapped_column(String(255)) | |||
| name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| summary = mapped_column(sa.Text) | |||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||
| introduction = mapped_column(sa.Text) | |||
| system_instruction = mapped_column(sa.Text) | |||
| system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | |||
| @@ -652,7 +651,7 @@ class Conversation(Base): | |||
| is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | |||
| @property | |||
| def inputs(self): | |||
| def inputs(self) -> dict[str, Any]: | |||
| inputs = self._inputs.copy() | |||
| # Convert file mapping to File object | |||
| @@ -660,22 +659,39 @@ class Conversation(Base): | |||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | |||
| from factories import file_factory | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value["tool_file_id"] = value["related_id"] | |||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value["upload_file_id"] = value["related_id"] | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||
| elif isinstance(value, list) and all( | |||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||
| if ( | |||
| isinstance(value, dict) | |||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| ): | |||
| inputs[key] = [] | |||
| for item in value: | |||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item["tool_file_id"] = item["related_id"] | |||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| item["upload_file_id"] = item["related_id"] | |||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||
| value_dict = cast(dict[str, Any], value) | |||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||
| elif isinstance(value, list): | |||
| value_list = cast(list[Any], value) | |||
| if all( | |||
| isinstance(item, dict) | |||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| for item in value_list | |||
| ): | |||
| file_list: list[File] = [] | |||
| for item in value_list: | |||
| if not isinstance(item, dict): | |||
| continue | |||
| item_dict = cast(dict[str, Any], item) | |||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||
| elif item_dict["transfer_method"] in [ | |||
| FileTransferMethod.LOCAL_FILE, | |||
| FileTransferMethod.REMOTE_URL, | |||
| ]: | |||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||
| inputs[key] = file_list | |||
| return inputs | |||
| @@ -685,8 +701,10 @@ class Conversation(Base): | |||
| for k, v in inputs.items(): | |||
| if isinstance(v, File): | |||
| inputs[k] = v.model_dump() | |||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||
| inputs[k] = [item.model_dump() for item in v] | |||
| elif isinstance(v, list): | |||
| v_list = cast(list[Any], v) | |||
| if all(isinstance(item, File) for item in v_list): | |||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||
| self._inputs = inputs | |||
| @property | |||
| @@ -826,7 +844,7 @@ class Conversation(Base): | |||
| ) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| return session.query(App).where(App.id == self.app_id).first() | |||
| @@ -840,7 +858,7 @@ class Conversation(Base): | |||
| return None | |||
| @property | |||
| def from_account_name(self): | |||
| def from_account_name(self) -> Optional[str]: | |||
| if self.from_account_id: | |||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | |||
| if account: | |||
| @@ -849,10 +867,10 @@ class Conversation(Base): | |||
| return None | |||
| @property | |||
| def in_debug_mode(self): | |||
| def in_debug_mode(self) -> bool: | |||
| return self.override_model_configs is not None | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "app_id": self.app_id, | |||
| @@ -898,7 +916,7 @@ class Message(Base): | |||
| model_id = mapped_column(String(255), nullable=True) | |||
| override_model_configs = mapped_column(sa.Text) | |||
| conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) | |||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||
| query: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||
| message = mapped_column(sa.JSON, nullable=False) | |||
| message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | |||
| @@ -925,28 +943,45 @@ class Message(Base): | |||
| workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) | |||
| @property | |||
| def inputs(self): | |||
| def inputs(self) -> dict[str, Any]: | |||
| inputs = self._inputs.copy() | |||
| for key, value in inputs.items(): | |||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | |||
| from factories import file_factory | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value["tool_file_id"] = value["related_id"] | |||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value["upload_file_id"] = value["related_id"] | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||
| elif isinstance(value, list) and all( | |||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||
| if ( | |||
| isinstance(value, dict) | |||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| ): | |||
| inputs[key] = [] | |||
| for item in value: | |||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item["tool_file_id"] = item["related_id"] | |||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| item["upload_file_id"] = item["related_id"] | |||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||
| value_dict = cast(dict[str, Any], value) | |||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||
| elif isinstance(value, list): | |||
| value_list = cast(list[Any], value) | |||
| if all( | |||
| isinstance(item, dict) | |||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| for item in value_list | |||
| ): | |||
| file_list: list[File] = [] | |||
| for item in value_list: | |||
| if not isinstance(item, dict): | |||
| continue | |||
| item_dict = cast(dict[str, Any], item) | |||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||
| elif item_dict["transfer_method"] in [ | |||
| FileTransferMethod.LOCAL_FILE, | |||
| FileTransferMethod.REMOTE_URL, | |||
| ]: | |||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||
| inputs[key] = file_list | |||
| return inputs | |||
| @inputs.setter | |||
| @@ -955,8 +990,10 @@ class Message(Base): | |||
| for k, v in inputs.items(): | |||
| if isinstance(v, File): | |||
| inputs[k] = v.model_dump() | |||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||
| inputs[k] = [item.model_dump() for item in v] | |||
| elif isinstance(v, list): | |||
| v_list = cast(list[Any], v) | |||
| if all(isinstance(item, File) for item in v_list): | |||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||
| self._inputs = inputs | |||
| @property | |||
| @@ -1084,15 +1121,15 @@ class Message(Base): | |||
| return None | |||
| @property | |||
| def in_debug_mode(self): | |||
| def in_debug_mode(self) -> bool: | |||
| return self.override_model_configs is not None | |||
| @property | |||
| def message_metadata_dict(self): | |||
| def message_metadata_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.message_metadata) if self.message_metadata else {} | |||
| @property | |||
| def agent_thoughts(self): | |||
| def agent_thoughts(self) -> list["MessageAgentThought"]: | |||
| return ( | |||
| db.session.query(MessageAgentThought) | |||
| .where(MessageAgentThought.message_id == self.id) | |||
| @@ -1101,11 +1138,11 @@ class Message(Base): | |||
| ) | |||
| @property | |||
| def retriever_resources(self): | |||
| def retriever_resources(self) -> Any | list[Any]: | |||
| return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] | |||
| @property | |||
| def message_files(self): | |||
| def message_files(self) -> list[dict[str, Any]]: | |||
| from factories import file_factory | |||
| message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | |||
| @@ -1113,7 +1150,7 @@ class Message(Base): | |||
| if not current_app: | |||
| raise ValueError(f"App {self.app_id} not found") | |||
| files = [] | |||
| files: list[File] = [] | |||
| for message_file in message_files: | |||
| if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: | |||
| if message_file.upload_file_id is None: | |||
| @@ -1160,7 +1197,7 @@ class Message(Base): | |||
| ) | |||
| files.append(file) | |||
| result = [ | |||
| result: list[dict[str, Any]] = [ | |||
| {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} | |||
| for (file, message_file) in zip(files, message_files) | |||
| ] | |||
| @@ -1177,7 +1214,7 @@ class Message(Base): | |||
| return None | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "app_id": self.app_id, | |||
| @@ -1201,7 +1238,7 @@ class Message(Base): | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict): | |||
| def from_dict(cls, data: dict[str, Any]) -> "Message": | |||
| return cls( | |||
| id=data["id"], | |||
| app_id=data["app_id"], | |||
| @@ -1251,7 +1288,7 @@ class MessageFeedback(Base): | |||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | |||
| return account | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": str(self.id), | |||
| "app_id": str(self.app_id), | |||
| @@ -1436,7 +1473,18 @@ class EndUser(Base, UserMixin): | |||
| type: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| external_user_id = mapped_column(String(255), nullable=True) | |||
| name = mapped_column(String(255)) | |||
| is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | |||
| _is_anonymous: Mapped[bool] = mapped_column( | |||
| "is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true") | |||
| ) | |||
| @property | |||
| def is_anonymous(self) -> Literal[False]: | |||
| return False | |||
| @is_anonymous.setter | |||
| def is_anonymous(self, value: bool) -> None: | |||
| self._is_anonymous = value | |||
| session_id: Mapped[str] = mapped_column() | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @@ -1462,7 +1510,7 @@ class AppMCPServer(Base): | |||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @staticmethod | |||
| def generate_server_code(n): | |||
| def generate_server_code(n: int) -> str: | |||
| while True: | |||
| result = generate_string(n) | |||
| while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: | |||
| @@ -1519,7 +1567,7 @@ class Site(Base): | |||
| self._custom_disclaimer = value | |||
| @staticmethod | |||
| def generate_code(n): | |||
| def generate_code(n: int) -> str: | |||
| while True: | |||
| result = generate_string(n) | |||
| while db.session.query(Site).where(Site.code == result).count() > 0: | |||
| @@ -1550,7 +1598,7 @@ class ApiToken(Base): | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @staticmethod | |||
| def generate_api_key(prefix, n): | |||
| def generate_api_key(prefix: str, n: int) -> str: | |||
| while True: | |||
| result = prefix + generate_string(n) | |||
| if db.session.scalar(select(exists().where(ApiToken.token == result))): | |||
| @@ -1690,7 +1738,7 @@ class MessageAgentThought(Base): | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| @property | |||
| def files(self): | |||
| def files(self) -> list[Any]: | |||
| if self.message_files: | |||
| return cast(list[Any], json.loads(self.message_files)) | |||
| else: | |||
| @@ -1701,32 +1749,32 @@ class MessageAgentThought(Base): | |||
| return self.tool.split(";") if self.tool else [] | |||
| @property | |||
| def tool_labels(self): | |||
| def tool_labels(self) -> dict[str, Any]: | |||
| try: | |||
| if self.tool_labels_str: | |||
| return cast(dict, json.loads(self.tool_labels_str)) | |||
| return cast(dict[str, Any], json.loads(self.tool_labels_str)) | |||
| else: | |||
| return {} | |||
| except Exception: | |||
| return {} | |||
| @property | |||
| def tool_meta(self): | |||
| def tool_meta(self) -> dict[str, Any]: | |||
| try: | |||
| if self.tool_meta_str: | |||
| return cast(dict, json.loads(self.tool_meta_str)) | |||
| return cast(dict[str, Any], json.loads(self.tool_meta_str)) | |||
| else: | |||
| return {} | |||
| except Exception: | |||
| return {} | |||
| @property | |||
| def tool_inputs_dict(self): | |||
| def tool_inputs_dict(self) -> dict[str, Any]: | |||
| tools = self.tools | |||
| try: | |||
| if self.tool_input: | |||
| data = json.loads(self.tool_input) | |||
| result = {} | |||
| result: dict[str, Any] = {} | |||
| for tool in tools: | |||
| if tool in data: | |||
| result[tool] = data[tool] | |||
| @@ -1742,12 +1790,12 @@ class MessageAgentThought(Base): | |||
| return {} | |||
| @property | |||
| def tool_outputs_dict(self): | |||
| def tool_outputs_dict(self) -> dict[str, Any]: | |||
| tools = self.tools | |||
| try: | |||
| if self.observation: | |||
| data = json.loads(self.observation) | |||
| result = {} | |||
| result: dict[str, Any] = {} | |||
| for tool in tools: | |||
| if tool in data: | |||
| result[tool] = data[tool] | |||
| @@ -1845,14 +1893,14 @@ class TraceAppConfig(Base): | |||
| is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | |||
| @property | |||
| def tracing_config_dict(self): | |||
| def tracing_config_dict(self) -> dict[str, Any]: | |||
| return self.tracing_config or {} | |||
| @property | |||
| def tracing_config_str(self): | |||
| def tracing_config_str(self) -> str: | |||
| return json.dumps(self.tracing_config_dict) | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "app_id": self.app_id, | |||
| @@ -17,7 +17,7 @@ class ProviderType(Enum): | |||
| SYSTEM = "system" | |||
| @staticmethod | |||
| def value_of(value): | |||
| def value_of(value: str) -> "ProviderType": | |||
| for member in ProviderType: | |||
| if member.value == value: | |||
| return member | |||
| @@ -35,7 +35,7 @@ class ProviderQuotaType(Enum): | |||
| """hosted trial quota""" | |||
| @staticmethod | |||
| def value_of(value): | |||
| def value_of(value: str) -> "ProviderQuotaType": | |||
| for member in ProviderQuotaType: | |||
| if member.value == value: | |||
| return member | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from datetime import datetime | |||
| from typing import TYPE_CHECKING, Optional, cast | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| from urllib.parse import urlparse | |||
| import sqlalchemy as sa | |||
| @@ -58,8 +58,8 @@ class ToolOAuthTenantClient(Base): | |||
| encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||
| @property | |||
| def oauth_params(self): | |||
| return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) | |||
| def oauth_params(self) -> dict[str, Any]: | |||
| return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}")) | |||
| class BuiltinToolProvider(Base): | |||
| @@ -100,8 +100,8 @@ class BuiltinToolProvider(Base): | |||
| expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) | |||
| @property | |||
| def credentials(self): | |||
| return cast(dict, json.loads(self.encrypted_credentials)) | |||
| def credentials(self) -> dict[str, Any]: | |||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) | |||
| class ApiToolProvider(Base): | |||
| @@ -154,8 +154,8 @@ class ApiToolProvider(Base): | |||
| return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | |||
| @property | |||
| def credentials(self): | |||
| return dict(json.loads(self.credentials_str)) | |||
| def credentials(self) -> dict[str, Any]: | |||
| return dict[str, Any](json.loads(self.credentials_str)) | |||
| @property | |||
| def user(self) -> Account | None: | |||
| @@ -299,9 +299,9 @@ class MCPToolProvider(Base): | |||
| return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | |||
| @property | |||
| def credentials(self): | |||
| def credentials(self) -> dict[str, Any]: | |||
| try: | |||
| return cast(dict, json.loads(self.encrypted_credentials)) or {} | |||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} | |||
| except Exception: | |||
| return {} | |||
| @@ -341,12 +341,12 @@ class MCPToolProvider(Base): | |||
| return mask_url(self.decrypted_server_url) | |||
| @property | |||
| def decrypted_credentials(self): | |||
| def decrypted_credentials(self) -> dict[str, Any]: | |||
| from core.helper.provider_cache import NoOpProviderCredentialCache | |||
| from core.tools.mcp_tool.provider import MCPToolProviderController | |||
| from core.tools.utils.encryption import create_provider_encrypter | |||
| provider_controller = MCPToolProviderController._from_db(self) | |||
| provider_controller = MCPToolProviderController.from_db(self) | |||
| encrypter, _ = create_provider_encrypter( | |||
| tenant_id=self.tenant_id, | |||
| @@ -354,7 +354,7 @@ class MCPToolProvider(Base): | |||
| cache=NoOpProviderCredentialCache(), | |||
| ) | |||
| return encrypter.decrypt(self.credentials) # type: ignore | |||
| return encrypter.decrypt(self.credentials) | |||
| class ToolModelInvoke(Base): | |||
| @@ -1,29 +1,34 @@ | |||
| import enum | |||
| from typing import Generic, TypeVar | |||
| import uuid | |||
| from typing import Any, Generic, TypeVar | |||
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator | |||
| from sqlalchemy.dialects.postgresql import UUID | |||
| from sqlalchemy.engine.interfaces import Dialect | |||
| from sqlalchemy.sql.type_api import TypeEngine | |||
| class StringUUID(TypeDecorator): | |||
| class StringUUID(TypeDecorator[uuid.UUID | str | None]): | |||
| impl = CHAR | |||
| cache_ok = True | |||
| def process_bind_param(self, value, dialect): | |||
| def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||
| if value is None: | |||
| return value | |||
| elif dialect.name == "postgresql": | |||
| return str(value) | |||
| else: | |||
| return value.hex | |||
| if isinstance(value, uuid.UUID): | |||
| return value.hex | |||
| return value | |||
| def load_dialect_impl(self, dialect): | |||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||
| if dialect.name == "postgresql": | |||
| return dialect.type_descriptor(UUID()) | |||
| else: | |||
| return dialect.type_descriptor(CHAR(36)) | |||
| def process_result_value(self, value, dialect): | |||
| def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||
| if value is None: | |||
| return value | |||
| return str(value) | |||
| @@ -32,7 +37,7 @@ class StringUUID(TypeDecorator): | |||
| _E = TypeVar("_E", bound=enum.StrEnum) | |||
| class EnumText(TypeDecorator, Generic[_E]): | |||
| class EnumText(TypeDecorator[_E | None], Generic[_E]): | |||
| impl = VARCHAR | |||
| cache_ok = True | |||
| @@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]): | |||
| # leave some rooms for future longer enum values. | |||
| self._length = max(max_enum_value_len, 20) | |||
| def process_bind_param(self, value: _E | str | None, dialect): | |||
| def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None: | |||
| if value is None: | |||
| return value | |||
| if isinstance(value, self._enum_class): | |||
| return value.value | |||
| elif isinstance(value, str): | |||
| self._enum_class(value) | |||
| return value | |||
| else: | |||
| raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") | |||
| # Since _E is bound to StrEnum which inherits from str, at this point value must be str | |||
| self._enum_class(value) | |||
| return value | |||
| def load_dialect_impl(self, dialect): | |||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||
| return dialect.type_descriptor(VARCHAR(self._length)) | |||
| def process_result_value(self, value, dialect) -> _E | None: | |||
| def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: | |||
| if value is None: | |||
| return value | |||
| if not isinstance(value, str): | |||
| raise TypeError(f"expected str, got {type(value)}") | |||
| # Type annotation guarantees value is str at this point | |||
| return self._enum_class(value) | |||
| def compare_values(self, x, y): | |||
| def compare_values(self, x: _E | None, y: _E | None) -> bool: | |||
| if x is None or y is None: | |||
| return x is y | |||
| return x == y | |||
| @@ -3,7 +3,7 @@ import logging | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime | |||
| from enum import Enum, StrEnum | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| from uuid import uuid4 | |||
| import sqlalchemy as sa | |||
| @@ -224,7 +224,7 @@ class Workflow(Base): | |||
| raise WorkflowDataError("nodes not found in workflow graph") | |||
| try: | |||
| node_config = next(filter(lambda node: node["id"] == node_id, nodes)) | |||
| node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) | |||
| except StopIteration: | |||
| raise NodeNotFoundError(node_id) | |||
| assert isinstance(node_config, dict) | |||
| @@ -289,7 +289,7 @@ class Workflow(Base): | |||
| def features_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.features) if self.features else {} | |||
| def user_input_form(self, to_old_structure: bool = False): | |||
| def user_input_form(self, to_old_structure: bool = False) -> list[Any]: | |||
| # get start node from graph | |||
| if not self.graph: | |||
| return [] | |||
| @@ -306,7 +306,7 @@ class Workflow(Base): | |||
| variables: list[Any] = start_node.get("data", {}).get("variables", []) | |||
| if to_old_structure: | |||
| old_structure_variables = [] | |||
| old_structure_variables: list[dict[str, Any]] = [] | |||
| for variable in variables: | |||
| old_structure_variables.append({variable["type"]: variable}) | |||
| @@ -346,9 +346,7 @@ class Workflow(Base): | |||
| @property | |||
| def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | |||
| # TODO: find some way to init `self._environment_variables` when instance created. | |||
| if self._environment_variables is None: | |||
| self._environment_variables = "{}" | |||
| # _environment_variables is guaranteed to be non-None due to server_default="{}" | |||
| # Use workflow.tenant_id to avoid relying on request user in background threads | |||
| tenant_id = self.tenant_id | |||
| @@ -362,17 +360,18 @@ class Workflow(Base): | |||
| ] | |||
| # decrypt secret variables value | |||
| def decrypt_func(var): | |||
| def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: | |||
| if isinstance(var, SecretVariable): | |||
| return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | |||
| elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): | |||
| return var | |||
| else: | |||
| raise AssertionError("this statement should be unreachable.") | |||
| # Other variable types are not supported for environment variables | |||
| raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") | |||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( | |||
| map(decrypt_func, results) | |||
| ) | |||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [ | |||
| decrypt_func(var) for var in results | |||
| ] | |||
| return decrypted_results | |||
| @environment_variables.setter | |||
| @@ -400,7 +399,7 @@ class Workflow(Base): | |||
| value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | |||
| # encrypt secret variables value | |||
| def encrypt_func(var): | |||
| def encrypt_func(var: Variable) -> Variable: | |||
| if isinstance(var, SecretVariable): | |||
| return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | |||
| else: | |||
| @@ -430,9 +429,7 @@ class Workflow(Base): | |||
| @property | |||
| def conversation_variables(self) -> Sequence[Variable]: | |||
| # TODO: find some way to init `self._conversation_variables` when instance created. | |||
| if self._conversation_variables is None: | |||
| self._conversation_variables = "{}" | |||
| # _conversation_variables is guaranteed to be non-None due to server_default="{}" | |||
| variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | |||
| results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] | |||
| @@ -577,7 +574,7 @@ class WorkflowRun(Base): | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict) -> "WorkflowRun": | |||
| def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": | |||
| return cls( | |||
| id=data.get("id"), | |||
| tenant_id=data.get("tenant_id"), | |||
| @@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base): | |||
| __tablename__ = "workflow_node_executions" | |||
| @declared_attr | |||
| def __table_args__(cls): # noqa | |||
| @classmethod | |||
| def __table_args__(cls) -> Any: | |||
| return ( | |||
| PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | |||
| Index( | |||
| @@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base): | |||
| # MyPy may flag the following line because it doesn't recognize that | |||
| # the `declared_attr` decorator passes the receiving class as the first | |||
| # argument to this method, allowing us to reference class attributes. | |||
| cls.created_at.desc(), # type: ignore | |||
| cls.created_at.desc(), | |||
| ), | |||
| ) | |||
| @@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base): | |||
| return json.loads(self.execution_metadata) if self.execution_metadata else {} | |||
| @property | |||
| def extras(self): | |||
| def extras(self) -> dict[str, Any]: | |||
| from core.tools.tool_manager import ToolManager | |||
| extras = {} | |||
| extras: dict[str, Any] = {} | |||
| if self.execution_metadata_dict: | |||
| from core.workflow.nodes import NodeType | |||
| if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | |||
| tool_info = self.execution_metadata_dict["tool_info"] | |||
| tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] | |||
| extras["icon"] = ToolManager.get_tool_icon( | |||
| tenant_id=self.tenant_id, | |||
| provider_type=tool_info["provider_type"], | |||
| @@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base): | |||
| # making this attribute harder to access from outside the class. | |||
| __value: Segment | None | |||
| def __init__(self, *args, **kwargs): | |||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||
| """ | |||
| The constructor of `WorkflowDraftVariable` is not intended for | |||
| direct use outside this file. Its solo purpose is setup private state | |||
| @@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base): | |||
| self.__value = None | |||
| def get_selector(self) -> list[str]: | |||
| selector = json.loads(self.selector) | |||
| selector: Any = json.loads(self.selector) | |||
| if not isinstance(selector, list): | |||
| logger.error( | |||
| "invalid selector loaded from database, type=%s, value=%s", | |||
| type(selector), | |||
| type(selector).__name__, | |||
| self.selector, | |||
| ) | |||
| raise ValueError("invalid selector.") | |||
| return selector | |||
| return cast(list[str], selector) | |||
| def _set_selector(self, value: list[str]): | |||
| self.selector = json.dumps(value) | |||
| @@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base): | |||
| # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. | |||
| if isinstance(value, dict): | |||
| if not maybe_file_object(value): | |||
| return value | |||
| return cast(Any, value) | |||
| return File.model_validate(value) | |||
| elif isinstance(value, list) and value: | |||
| first = value[0] | |||
| value_list = cast(list[Any], value) | |||
| first: Any = value_list[0] | |||
| if not maybe_file_object(first): | |||
| return value | |||
| return [File.model_validate(i) for i in value] | |||
| return cast(Any, value) | |||
| file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] | |||
| return cast(Any, file_list) | |||
| else: | |||
| return value | |||
| return cast(Any, value) | |||
| @classmethod | |||
| def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: | |||
| @@ -6,7 +6,6 @@ | |||
| "tests/", | |||
| "migrations/", | |||
| ".venv/", | |||
| "models/", | |||
| "core/", | |||
| "controllers/", | |||
| "tasks/", | |||
| @@ -1,8 +1,7 @@ | |||
| import threading | |||
| from typing import Optional | |||
| from typing import Any, Optional | |||
| import pytz | |||
| from flask_login import current_user | |||
| import contexts | |||
| from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager | |||
| @@ -10,6 +9,7 @@ from core.plugin.impl.agent import PluginAgentClient | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from core.tools.tool_manager import ToolManager | |||
| from extensions.ext_database import db | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App, Conversation, EndUser, Message, MessageAgentThought | |||
| @@ -61,14 +61,15 @@ class AgentService: | |||
| executor = executor.name | |||
| else: | |||
| executor = "Unknown" | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.timezone is not None | |||
| timezone = pytz.timezone(current_user.timezone) | |||
| app_model_config = app_model.app_model_config | |||
| if not app_model_config: | |||
| raise ValueError("App model config not found") | |||
| result = { | |||
| result: dict[str, Any] = { | |||
| "meta": { | |||
| "status": "success", | |||
| "executor": executor, | |||
| @@ -2,7 +2,6 @@ import uuid | |||
| from typing import Optional | |||
| import pandas as pd | |||
| from flask_login import current_user | |||
| from sqlalchemy import or_, select | |||
| from werkzeug.datastructures import FileStorage | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation | |||
| from services.feature_service import FeatureService | |||
| from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task | |||
| @@ -24,6 +25,7 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -62,6 +64,7 @@ class AppAnnotationService: | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() | |||
| assert current_user.current_tenant_id is not None | |||
| if annotation_setting: | |||
| add_annotation_to_index_task.delay( | |||
| annotation.id, | |||
| @@ -84,6 +87,8 @@ class AppAnnotationService: | |||
| enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | |||
| # send batch add segments task | |||
| redis_client.setnx(enable_app_annotation_job_key, "waiting") | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| enable_annotation_reply_task.delay( | |||
| str(job_id), | |||
| app_id, | |||
| @@ -97,6 +102,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def disable_app_annotation(cls, app_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" | |||
| cache_result = redis_client.get(disable_app_annotation_key) | |||
| if cache_result is not None: | |||
| @@ -113,6 +120,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -145,6 +154,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def export_annotation_list_by_app_id(cls, app_id: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -164,6 +175,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -193,6 +206,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -230,6 +245,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def delete_app_annotation(cls, app_id: str, annotation_id: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -269,6 +286,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -317,6 +336,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def batch_import_app_annotations(cls, app_id, file: FileStorage): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -355,6 +376,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get app info | |||
| app = ( | |||
| db.session.query(App) | |||
| @@ -425,6 +448,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_app_annotation_setting_by_app_id(cls, app_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get app info | |||
| app = ( | |||
| db.session.query(App) | |||
| @@ -451,6 +476,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get app info | |||
| app = ( | |||
| db.session.query(App) | |||
| @@ -491,6 +518,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def clear_all_annotations(cls, app_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -2,7 +2,6 @@ import json | |||
| import logging | |||
| from typing import Optional, TypedDict, cast | |||
| from flask_login import current_user | |||
| from flask_sqlalchemy.pagination import Pagination | |||
| from configs import dify_config | |||
| @@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager | |||
| from events.app_event import app_was_created | |||
| from extensions.ext_database import db | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App, AppMode, AppModelConfig, Site | |||
| from models.tools import ApiToolProvider | |||
| @@ -168,9 +168,13 @@ class AppService: | |||
| """ | |||
| Get App | |||
| """ | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get original app model config | |||
| if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | |||
| model_config = app.app_model_config | |||
| if not model_config: | |||
| return app | |||
| agent_mode = model_config.agent_mode_dict | |||
| # decrypt agent tool parameters if it's secret-input | |||
| for tool in agent_mode.get("tools") or []: | |||
| @@ -205,7 +209,8 @@ class AppService: | |||
| pass | |||
| # override agent mode | |||
| model_config.agent_mode = json.dumps(agent_mode) | |||
| if model_config: | |||
| model_config.agent_mode = json.dumps(agent_mode) | |||
| class ModifiedApp(App): | |||
| """ | |||
| @@ -239,6 +244,7 @@ class AppService: | |||
| :param args: request args | |||
| :return: App instance | |||
| """ | |||
| assert current_user is not None | |||
| app.name = args["name"] | |||
| app.description = args["description"] | |||
| app.icon_type = args["icon_type"] | |||
| @@ -259,6 +265,7 @@ class AppService: | |||
| :param name: new name | |||
| :return: App instance | |||
| """ | |||
| assert current_user is not None | |||
| app.name = name | |||
| app.updated_by = current_user.id | |||
| app.updated_at = naive_utc_now() | |||
| @@ -274,6 +281,7 @@ class AppService: | |||
| :param icon_background: new icon_background | |||
| :return: App instance | |||
| """ | |||
| assert current_user is not None | |||
| app.icon = icon | |||
| app.icon_background = icon_background | |||
| app.updated_by = current_user.id | |||
| @@ -291,7 +299,7 @@ class AppService: | |||
| """ | |||
| if enable_site == app.enable_site: | |||
| return app | |||
| assert current_user is not None | |||
| app.enable_site = enable_site | |||
| app.updated_by = current_user.id | |||
| app.updated_at = naive_utc_now() | |||
| @@ -308,6 +316,7 @@ class AppService: | |||
| """ | |||
| if enable_api == app.enable_api: | |||
| return app | |||
| assert current_user is not None | |||
| app.enable_api = enable_api | |||
| app.updated_by = current_user.id | |||
| @@ -12,7 +12,7 @@ from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from models.enums import MessageStatus | |||
| from models.model import App, AppMode, AppModelConfig, Message | |||
| from models.model import App, AppMode, Message | |||
| from services.errors.audio import ( | |||
| AudioTooLargeServiceError, | |||
| NoAudioUploadedServiceError, | |||
| @@ -40,7 +40,9 @@ class AudioService: | |||
| if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | |||
| raise ValueError("Speech to text is not enabled") | |||
| else: | |||
| app_model_config: AppModelConfig = app_model.app_model_config | |||
| app_model_config = app_model.app_model_config | |||
| if not app_model_config: | |||
| raise ValueError("Speech to text is not enabled") | |||
| if not app_model_config.speech_to_text_dict["enabled"]: | |||
| raise ValueError("Speech to text is not enabled") | |||
| @@ -70,7 +70,7 @@ class BillingService: | |||
| return response.json() | |||
| @staticmethod | |||
| def is_tenant_owner_or_admin(current_user): | |||
| def is_tenant_owner_or_admin(current_user: Account): | |||
| tenant_id = current_user.current_tenant_id | |||
| join: Optional[TenantAccountJoin] = ( | |||
| @@ -8,7 +8,7 @@ import uuid | |||
| from collections import Counter | |||
| from typing import Any, Literal, Optional | |||
| from flask_login import current_user | |||
| import sqlalchemy as sa | |||
| from sqlalchemy import exists, func, select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -26,6 +26,7 @@ from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs import helper | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import current_user | |||
| from models.account import Account, TenantAccountRole | |||
| from models.dataset import ( | |||
| AppDatasetJoin, | |||
| @@ -498,8 +499,11 @@ class DatasetService: | |||
| data: Update data dictionary | |||
| filtered_data: Filtered update data to modify | |||
| """ | |||
| # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None | |||
| try: | |||
| model_manager = ModelManager() | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=data["embedding_model_provider"], | |||
| @@ -611,8 +615,12 @@ class DatasetService: | |||
| data: Update data dictionary | |||
| filtered_data: Filtered update data to modify | |||
| """ | |||
| # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None | |||
| model_manager = ModelManager() | |||
| try: | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=data["embedding_model_provider"], | |||
| @@ -720,6 +728,8 @@ class DatasetService: | |||
| @staticmethod | |||
| def get_dataset_auto_disable_logs(dataset_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if not features.billing.enabled or features.billing.subscription.plan == "sandbox": | |||
| return { | |||
| @@ -924,6 +934,8 @@ class DocumentService: | |||
| @staticmethod | |||
| def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: | |||
| assert isinstance(current_user, Account) | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| @@ -973,7 +985,7 @@ class DocumentService: | |||
| file_ids = [ | |||
| document.data_source_info_dict["upload_file_id"] | |||
| for document in documents | |||
| if document.data_source_type == "upload_file" | |||
| if document.data_source_type == "upload_file" and document.data_source_info_dict | |||
| ] | |||
| batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | |||
| @@ -983,6 +995,8 @@ class DocumentService: | |||
| @staticmethod | |||
| def rename_document(dataset_id: str, document_id: str, name: str) -> Document: | |||
| assert isinstance(current_user, Account) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise ValueError("Dataset not found.") | |||
| @@ -1012,6 +1026,7 @@ class DocumentService: | |||
| if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: | |||
| raise DocumentIndexingError() | |||
| # update document to be paused | |||
| assert current_user is not None | |||
| document.is_paused = True | |||
| document.paused_by = current_user.id | |||
| document.paused_at = naive_utc_now() | |||
| @@ -1067,8 +1082,9 @@ class DocumentService: | |||
| # sync document indexing | |||
| document.indexing_status = "waiting" | |||
| data_source_info = document.data_source_info_dict | |||
| data_source_info["mode"] = "scrape" | |||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||
| if data_source_info: | |||
| data_source_info["mode"] = "scrape" | |||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -1097,6 +1113,9 @@ class DocumentService: | |||
| # check doc_form | |||
| DatasetService.check_doc_form(dataset, knowledge_config.doc_form) | |||
| # check document limit | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -1433,6 +1452,8 @@ class DocumentService: | |||
| @staticmethod | |||
| def get_tenant_documents_count(): | |||
| assert isinstance(current_user, Account) | |||
| documents_count = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| @@ -1453,6 +1474,8 @@ class DocumentService: | |||
| dataset_process_rule: Optional[DatasetProcessRule] = None, | |||
| created_from: str = "web", | |||
| ): | |||
| assert isinstance(current_user, Account) | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| document = DocumentService.get_document(dataset.id, document_data.original_document_id) | |||
| if document is None: | |||
| @@ -1512,7 +1535,7 @@ class DocumentService: | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .where( | |||
| db.and_( | |||
| sa.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| @@ -1573,6 +1596,9 @@ class DocumentService: | |||
| @staticmethod | |||
| def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -2012,6 +2038,9 @@ class SegmentService: | |||
| @classmethod | |||
| def create_segment(cls, args: dict, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| content = args["content"] | |||
| doc_id = str(uuid.uuid4()) | |||
| segment_hash = helper.generate_text_hash(content) | |||
| @@ -2074,6 +2103,9 @@ class SegmentService: | |||
| @classmethod | |||
| def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| lock_name = f"multi_add_segment_lock_document_id_{document.id}" | |||
| increment_word_count = 0 | |||
| with redis_client.lock(lock_name, timeout=600): | |||
| @@ -2157,6 +2189,9 @@ class SegmentService: | |||
| @classmethod | |||
| def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| indexing_cache_key = f"segment_{segment.id}_indexing" | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| @@ -2348,6 +2383,7 @@ class SegmentService: | |||
| @classmethod | |||
| def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| segments = ( | |||
| db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) | |||
| .where( | |||
| @@ -2378,6 +2414,8 @@ class SegmentService: | |||
| def update_segments_status( | |||
| cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document | |||
| ): | |||
| assert current_user is not None | |||
| # Check if segment_ids is not empty to avoid WHERE false condition | |||
| if not segment_ids or len(segment_ids) == 0: | |||
| return | |||
| @@ -2440,6 +2478,8 @@ class SegmentService: | |||
| def create_child_chunk( | |||
| cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset | |||
| ) -> ChildChunk: | |||
| assert isinstance(current_user, Account) | |||
| lock_name = f"add_child_lock_{segment.id}" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| index_node_id = str(uuid.uuid4()) | |||
| @@ -2487,6 +2527,8 @@ class SegmentService: | |||
| document: Document, | |||
| dataset: Dataset, | |||
| ) -> list[ChildChunk]: | |||
| assert isinstance(current_user, Account) | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where( | |||
| @@ -2561,6 +2603,8 @@ class SegmentService: | |||
| document: Document, | |||
| dataset: Dataset, | |||
| ) -> ChildChunk: | |||
| assert current_user is not None | |||
| try: | |||
| child_chunk.content = content | |||
| child_chunk.word_count = len(content) | |||
| @@ -2591,6 +2635,8 @@ class SegmentService: | |||
| def get_child_chunks( | |||
| cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None | |||
| ): | |||
| assert isinstance(current_user, Account) | |||
| query = ( | |||
| select(ChildChunk) | |||
| .filter_by( | |||
| @@ -114,8 +114,9 @@ class ExternalDatasetService: | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: | |||
| args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||
| settings = args.get("settings") | |||
| if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict: | |||
| settings["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||
| external_knowledge_api.name = args.get("name") | |||
| external_knowledge_api.description = args.get("description", "") | |||
| @@ -3,7 +3,6 @@ import os | |||
| import uuid | |||
| from typing import Any, Literal, Union | |||
| from flask_login import current_user | |||
| from werkzeug.exceptions import NotFound | |||
| from configs import dify_config | |||
| @@ -19,6 +18,7 @@ from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.helper import extract_tenant_id | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.enums import CreatorUserRole | |||
| from models.model import EndUser, UploadFile | |||
| @@ -111,6 +111,9 @@ class FileService: | |||
| @staticmethod | |||
| def upload_text(text: str, text_name: str) -> UploadFile: | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| if len(text_name) > 200: | |||
| text_name = text_name[:200] | |||
| # user uuid as file name | |||
| @@ -226,7 +226,7 @@ class MCPToolManageService: | |||
| def update_mcp_provider_credentials( | |||
| cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False | |||
| ): | |||
| provider_controller = MCPToolProviderController._from_db(mcp_provider) | |||
| provider_controller = MCPToolProviderController.from_db(mcp_provider) | |||
| tool_configuration = ProviderConfigEncrypter( | |||
| tenant_id=mcp_provider.tenant_id, | |||
| config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] | |||
| @@ -1,10 +1,11 @@ | |||
| import json | |||
| from unittest.mock import MagicMock, patch | |||
| from unittest.mock import MagicMock, create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from models.account import Account | |||
| from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought | |||
| from services.account_service import AccountService, TenantService | |||
| from services.agent_service import AgentService | |||
| @@ -21,7 +22,7 @@ class TestAgentService: | |||
| patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, | |||
| patch("services.agent_service.ToolManager") as mock_tool_manager, | |||
| patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, | |||
| patch("services.agent_service.current_user") as mock_current_user, | |||
| patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, | |||
| patch("services.app_service.FeatureService") as mock_feature_service, | |||
| patch("services.app_service.EnterpriseService") as mock_enterprise_service, | |||
| patch("services.app_service.ModelManager") as mock_model_manager, | |||
| @@ -1,9 +1,10 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from werkzeug.exceptions import NotFound | |||
| from models.account import Account | |||
| from models.model import MessageAnnotation | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.app_service import AppService | |||
| @@ -24,7 +25,9 @@ class TestAnnotationService: | |||
| patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, | |||
| patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, | |||
| patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, | |||
| patch("services.annotation_service.current_user") as mock_current_user, | |||
| patch( | |||
| "services.annotation_service.current_user", create_autospec(Account, instance=True) | |||
| ) as mock_current_user, | |||
| ): | |||
| # Setup default mock returns | |||
| mock_account_feature_service.get_features.return_value.billing.enabled = False | |||
| @@ -1,9 +1,10 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from constants.model_template import default_app_templates | |||
| from models.account import Account | |||
| from models.model import App, Site | |||
| from services.account_service import AccountService, TenantService | |||
| from services.app_service import AppService | |||
| @@ -161,8 +162,13 @@ class TestAppService: | |||
| app_service = AppService() | |||
| created_app = app_service.create_app(tenant.id, app_args, account) | |||
| # Get app using the service | |||
| retrieved_app = app_service.get_app(created_app) | |||
| # Get app using the service - needs current_user mock | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| retrieved_app = app_service.get_app(created_app) | |||
| # Verify retrieved app matches created app | |||
| assert retrieved_app.id == created_app.id | |||
| @@ -406,7 +412,11 @@ class TestAppService: | |||
| "use_icon_as_answer_icon": True, | |||
| } | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app(app, update_args) | |||
| # Verify updated fields | |||
| @@ -456,7 +466,11 @@ class TestAppService: | |||
| # Update app name | |||
| new_name = "New App Name" | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_name(app, new_name) | |||
| assert updated_app.name == new_name | |||
| @@ -504,7 +518,11 @@ class TestAppService: | |||
| # Update app icon | |||
| new_icon = "🌟" | |||
| new_icon_background = "#FFD93D" | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) | |||
| assert updated_app.icon == new_icon | |||
| @@ -551,13 +569,17 @@ class TestAppService: | |||
| original_site_status = app.enable_site | |||
| # Update site status to disabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_site_status(app, False) | |||
| assert updated_app.enable_site is False | |||
| assert updated_app.updated_by == account.id | |||
| # Update site status back to enabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_site_status(updated_app, True) | |||
| assert updated_app.enable_site is True | |||
| assert updated_app.updated_by == account.id | |||
| @@ -602,13 +624,17 @@ class TestAppService: | |||
| original_api_status = app.enable_api | |||
| # Update API status to disabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_api_status(app, False) | |||
| assert updated_app.enable_api is False | |||
| assert updated_app.updated_by == account.id | |||
| # Update API status back to enabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_api_status(updated_app, True) | |||
| assert updated_app.enable_api is True | |||
| assert updated_app.updated_by == account.id | |||
| @@ -1,6 +1,6 @@ | |||
| import hashlib | |||
| from io import BytesIO | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -417,11 +417,12 @@ class TestFileService: | |||
| text = "This is a test text content" | |||
| text_name = "test_text.txt" | |||
| # Mock current_user | |||
| with patch("services.file_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| # Mock current_user using create_autospec | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| with patch("services.file_service.current_user", mock_current_user): | |||
| upload_file = FileService.upload_text(text=text, text_name=text_name) | |||
| assert upload_file is not None | |||
| @@ -443,11 +444,12 @@ class TestFileService: | |||
| text = "test content" | |||
| long_name = "a" * 250 # Longer than 200 characters | |||
| # Mock current_user | |||
| with patch("services.file_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| # Mock current_user using create_autospec | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| with patch("services.file_service.current_user", mock_current_user): | |||
| upload_file = FileService.upload_text(text=text, text_name=long_name) | |||
| # Verify name was truncated | |||
| @@ -846,11 +848,12 @@ class TestFileService: | |||
| text = "" | |||
| text_name = "empty.txt" | |||
| # Mock current_user | |||
| with patch("services.file_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| # Mock current_user using create_autospec | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| with patch("services.file_service.current_user", mock_current_user): | |||
| upload_file = FileService.upload_text(text=text, text_name=text_name) | |||
| assert upload_file is not None | |||
| @@ -1,4 +1,4 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -17,7 +17,9 @@ class TestMetadataService: | |||
| def mock_external_service_dependencies(self): | |||
| """Mock setup for external service dependencies.""" | |||
| with ( | |||
| patch("services.metadata_service.current_user") as mock_current_user, | |||
| patch( | |||
| "services.metadata_service.current_user", create_autospec(Account, instance=True) | |||
| ) as mock_current_user, | |||
| patch("services.metadata_service.redis_client") as mock_redis_client, | |||
| patch("services.dataset_service.DocumentService") as mock_document_service, | |||
| ): | |||
| @@ -1,4 +1,4 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -17,7 +17,7 @@ class TestTagService: | |||
| def mock_external_service_dependencies(self): | |||
| """Mock setup for external service dependencies.""" | |||
| with ( | |||
| patch("services.tag_service.current_user") as mock_current_user, | |||
| patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, | |||
| ): | |||
| # Setup default mock returns | |||
| mock_current_user.current_tenant_id = "test-tenant-id" | |||
| @@ -1,5 +1,5 @@ | |||
| from datetime import datetime | |||
| from unittest.mock import MagicMock, patch | |||
| from unittest.mock import MagicMock, create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -231,9 +231,10 @@ class TestWebsiteService: | |||
| fake = Faker() | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="firecrawl", | |||
| @@ -285,9 +286,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="watercrawl", | |||
| @@ -336,9 +338,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request for single page crawling | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="jinareader", | |||
| @@ -389,9 +392,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request with invalid provider | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="invalid_provider", | |||
| @@ -419,9 +423,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") | |||
| @@ -463,9 +468,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") | |||
| @@ -502,9 +508,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") | |||
| @@ -544,9 +551,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request with invalid provider | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") | |||
| @@ -569,9 +577,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Mock missing credentials | |||
| mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None | |||
| @@ -597,9 +606,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Mock missing API key in config | |||
| mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { | |||
| "config": {"base_url": "https://api.example.com"} | |||
| @@ -995,9 +1005,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request for sub-page crawling | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="jinareader", | |||
| @@ -1054,9 +1065,10 @@ class TestWebsiteService: | |||
| mock_external_service_dependencies["requests"].get.return_value = mock_failed_response | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="jinareader", | |||
| @@ -1096,9 +1108,10 @@ class TestWebsiteService: | |||
| mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") | |||
| @@ -154,7 +154,7 @@ class TestEnumText: | |||
| TestCase( | |||
| name="session insert with invalid type", | |||
| action=lambda s: _session_insert_with_value(s, 1), | |||
| exc_type=TypeError, | |||
| exc_type=ValueError, | |||
| ), | |||
| TestCase( | |||
| name="insert with invalid value", | |||
| @@ -164,7 +164,7 @@ class TestEnumText: | |||
| TestCase( | |||
| name="insert with invalid type", | |||
| action=lambda s: _insert_with_user(s, 1), | |||
| exc_type=TypeError, | |||
| exc_type=ValueError, | |||
| ), | |||
| ] | |||
| for idx, c in enumerate(cases, 1): | |||
| @@ -2,11 +2,12 @@ import datetime | |||
| from typing import Any, Optional | |||
| # Mock redis_client before importing dataset_service | |||
| from unittest.mock import Mock, patch | |||
| from unittest.mock import Mock, create_autospec, patch | |||
| import pytest | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.account import Account | |||
| from models.dataset import Dataset, ExternalKnowledgeBindings | |||
| from services.dataset_service import DatasetService | |||
| from services.errors.account import NoPermissionError | |||
| @@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory: | |||
| @staticmethod | |||
| def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: | |||
| """Create a mock current user.""" | |||
| current_user = Mock() | |||
| current_user = create_autospec(Account, instance=True) | |||
| current_user.current_tenant_id = tenant_id | |||
| return current_user | |||
| @@ -135,7 +136,9 @@ class TestDatasetServiceUpdateDataset: | |||
| "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" | |||
| ) as mock_get_binding, | |||
| patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, | |||
| patch("services.dataset_service.current_user") as mock_current_user, | |||
| patch( | |||
| "services.dataset_service.current_user", create_autospec(Account, instance=True) | |||
| ) as mock_current_user, | |||
| ): | |||
| mock_current_user.current_tenant_id = "tenant-123" | |||
| yield { | |||
| @@ -1,9 +1,10 @@ | |||
| from unittest.mock import Mock, patch | |||
| from unittest.mock import Mock, create_autospec, patch | |||
| import pytest | |||
| from flask_restx import reqparse | |||
| from werkzeug.exceptions import BadRequest | |||
| from models.account import Account | |||
| from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | |||
| from services.metadata_service import MetadataService | |||
| @@ -35,19 +36,21 @@ class TestMetadataBugCompleteValidation: | |||
| mock_metadata_args.name = None | |||
| mock_metadata_args.type = "string" | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # Should crash with TypeError | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | |||
| # Test update method as well | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | |||
| @@ -1,8 +1,9 @@ | |||
| from unittest.mock import Mock, patch | |||
| from unittest.mock import Mock, create_autospec, patch | |||
| import pytest | |||
| from flask_restx import reqparse | |||
| from models.account import Account | |||
| from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | |||
| from services.metadata_service import MetadataService | |||
| @@ -24,20 +25,22 @@ class TestMetadataNullableBug: | |||
| mock_metadata_args.name = None # This will cause len() to crash | |||
| mock_metadata_args.type = "string" | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # This should crash with TypeError when calling len(None) | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | |||
| def test_metadata_service_update_with_none_name_crashes(self): | |||
| """Test that MetadataService.update_metadata_name crashes when name is None.""" | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # This should crash with TypeError when calling len(None) | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | |||
| @@ -81,10 +84,11 @@ class TestMetadataNullableBug: | |||
| mock_metadata_args.name = None # From args["name"] | |||
| mock_metadata_args.type = None # From args["type"] | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # Step 4: Service layer crashes on len(None) | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | |||
| @@ -72,6 +72,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx | |||
| const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false) | |||
| const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false) | |||
| const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([]) | |||
| const [showExportWarning, setShowExportWarning] = useState(false) | |||
| const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ | |||
| name, | |||
| @@ -159,6 +160,14 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx | |||
| onExport() | |||
| return | |||
| } | |||
| setShowExportWarning(true) | |||
| } | |||
| const handleConfirmExport = async () => { | |||
| if (!appDetail) | |||
| return | |||
| setShowExportWarning(false) | |||
| try { | |||
| const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) | |||
| const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') | |||
| @@ -407,6 +416,16 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx | |||
| onClose={() => setSecretEnvList([])} | |||
| /> | |||
| )} | |||
| {showExportWarning && ( | |||
| <Confirm | |||
| type="info" | |||
| isShow={showExportWarning} | |||
| title={t('workflow.sidebar.exportWarning')} | |||
| content={t('workflow.sidebar.exportWarningDesc')} | |||
| onConfirm={handleConfirmExport} | |||
| onCancel={() => setShowExportWarning(false)} | |||
| /> | |||
| )} | |||
| </div> | |||
| ) | |||
| } | |||
| @@ -32,6 +32,7 @@ export type ActionButtonProps = { | |||
| size?: 'xs' | 's' | 'm' | 'l' | 'xl' | |||
| state?: ActionButtonState | |||
| styleCss?: CSSProperties | |||
| ref?: React.Ref<HTMLButtonElement> | |||
| } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof actionButtonVariants> | |||
| function getActionButtonState(state: ActionButtonState) { | |||
| @@ -49,24 +50,22 @@ function getActionButtonState(state: ActionButtonState) { | |||
| } | |||
| } | |||
| const ActionButton = React.forwardRef<HTMLButtonElement, ActionButtonProps>( | |||
| ({ className, size, state = ActionButtonState.Default, styleCss, children, ...props }, ref) => { | |||
| return ( | |||
| <button | |||
| type='button' | |||
| className={classNames( | |||
| actionButtonVariants({ className, size }), | |||
| getActionButtonState(state), | |||
| )} | |||
| ref={ref} | |||
| style={styleCss} | |||
| {...props} | |||
| > | |||
| {children} | |||
| </button> | |||
| ) | |||
| }, | |||
| ) | |||
| const ActionButton = ({ className, size, state = ActionButtonState.Default, styleCss, children, ref, ...props }: ActionButtonProps) => { | |||
| return ( | |||
| <button | |||
| type='button' | |||
| className={classNames( | |||
| actionButtonVariants({ className, size }), | |||
| getActionButtonState(state), | |||
| )} | |||
| ref={ref} | |||
| style={styleCss} | |||
| {...props} | |||
| > | |||
| {children} | |||
| </button> | |||
| ) | |||
| } | |||
| ActionButton.displayName = 'ActionButton' | |||
| export default ActionButton | |||
| @@ -35,27 +35,26 @@ export type ButtonProps = { | |||
| loading?: boolean | |||
| styleCss?: CSSProperties | |||
| spinnerClassName?: string | |||
| ref?: React.Ref<HTMLButtonElement> | |||
| } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof buttonVariants> | |||
| const Button = React.forwardRef<HTMLButtonElement, ButtonProps>( | |||
| ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ...props }, ref) => { | |||
| return ( | |||
| <button | |||
| type='button' | |||
| className={classNames( | |||
| buttonVariants({ variant, size, className }), | |||
| destructive && 'btn-destructive', | |||
| )} | |||
| ref={ref} | |||
| style={styleCss} | |||
| {...props} | |||
| > | |||
| {children} | |||
| {loading && <Spinner loading={loading} className={classNames('!ml-1 !h-3 !w-3 !border-2 !text-white', spinnerClassName)} />} | |||
| </button> | |||
| ) | |||
| }, | |||
| ) | |||
| const Button = ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ref, ...props }: ButtonProps) => { | |||
| return ( | |||
| <button | |||
| type='button' | |||
| className={classNames( | |||
| buttonVariants({ variant, size, className }), | |||
| destructive && 'btn-destructive', | |||
| )} | |||
| ref={ref} | |||
| style={styleCss} | |||
| {...props} | |||
| > | |||
| {children} | |||
| {loading && <Spinner loading={loading} className={classNames('!ml-1 !h-3 !w-3 !border-2 !text-white', spinnerClassName)} />} | |||
| </button> | |||
| ) | |||
| } | |||
| Button.displayName = 'Button' | |||
| export default Button | |||
| @@ -30,9 +30,10 @@ export type InputProps = { | |||
| wrapperClassName?: string | |||
| styleCss?: CSSProperties | |||
| unit?: string | |||
| ref?: React.Ref<HTMLInputElement> | |||
| } & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants> | |||
| const Input = React.forwardRef<HTMLInputElement, InputProps>(({ | |||
| const Input = ({ | |||
| size, | |||
| disabled, | |||
| destructive, | |||
| @@ -46,8 +47,9 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({ | |||
| placeholder, | |||
| onChange = noop, | |||
| unit, | |||
| ref, | |||
| ...props | |||
| }, ref) => { | |||
| }: InputProps) => { | |||
| const { t } = useTranslation() | |||
| return ( | |||
| <div className={cn('relative w-full', wrapperClassName)}> | |||
| @@ -93,7 +95,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({ | |||
| } | |||
| </div> | |||
| ) | |||
| }) | |||
| } | |||
| Input.displayName = 'Input' | |||
| @@ -107,10 +107,13 @@ const initMermaid = () => { | |||
| return isMermaidInitialized | |||
| } | |||
| const Flowchart = React.forwardRef((props: { | |||
| type FlowchartProps = { | |||
| PrimitiveCode: string | |||
| theme?: 'light' | 'dark' | |||
| }, ref) => { | |||
| ref?: React.Ref<HTMLDivElement> | |||
| } | |||
| const Flowchart = (props: FlowchartProps) => { | |||
| const { t } = useTranslation() | |||
| const [svgString, setSvgString] = useState<string | null>(null) | |||
| const [look, setLook] = useState<'classic' | 'handDrawn'>('classic') | |||
| @@ -490,7 +493,7 @@ const Flowchart = React.forwardRef((props: { | |||
| } | |||
| return ( | |||
| <div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}> | |||
| <div ref={props.ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}> | |||
| <div className={themeClasses.segmented}> | |||
| <div className="msh-segmented-group"> | |||
| <label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1"> | |||
| @@ -572,7 +575,7 @@ const Flowchart = React.forwardRef((props: { | |||
| )} | |||
| </div> | |||
| ) | |||
| }) | |||
| } | |||
| Flowchart.displayName = 'Flowchart' | |||
| @@ -24,30 +24,29 @@ export type TextareaProps = { | |||
| disabled?: boolean | |||
| destructive?: boolean | |||
| styleCss?: CSSProperties | |||
| ref?: React.Ref<HTMLTextAreaElement> | |||
| } & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants> | |||
| const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>( | |||
| ({ className, value, onChange, disabled, size, destructive, styleCss, ...props }, ref) => { | |||
| return ( | |||
| <textarea | |||
| ref={ref} | |||
| style={styleCss} | |||
| className={cn( | |||
| 'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs', | |||
| textareaVariants({ size }), | |||
| disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled', | |||
| destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive', | |||
| className, | |||
| )} | |||
| value={value ?? ''} | |||
| onChange={onChange} | |||
| disabled={disabled} | |||
| {...props} | |||
| > | |||
| </textarea> | |||
| ) | |||
| }, | |||
| ) | |||
| const Textarea = ({ className, value, onChange, disabled, size, destructive, styleCss, ref, ...props }: TextareaProps) => { | |||
| return ( | |||
| <textarea | |||
| ref={ref} | |||
| style={styleCss} | |||
| className={cn( | |||
| 'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs', | |||
| textareaVariants({ size }), | |||
| disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled', | |||
| destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive', | |||
| className, | |||
| )} | |||
| value={value ?? ''} | |||
| onChange={onChange} | |||
| disabled={disabled} | |||
| {...props} | |||
| > | |||
| </textarea> | |||
| ) | |||
| } | |||
| Textarea.displayName = 'Textarea' | |||
| export default Textarea | |||
| @@ -1,14 +1,14 @@ | |||
| import type { ComponentProps, FC, ReactNode } from 'react' | |||
| import { forwardRef } from 'react' | |||
| import classNames from '@/utils/classnames' | |||
| export type PreviewContainerProps = ComponentProps<'div'> & { | |||
| header: ReactNode | |||
| mainClassName?: string | |||
| ref?: React.Ref<HTMLDivElement> | |||
| } | |||
| export const PreviewContainer: FC<PreviewContainerProps> = forwardRef((props, ref) => { | |||
| const { children, className, header, mainClassName, ...rest } = props | |||
| export const PreviewContainer: FC<PreviewContainerProps> = (props) => { | |||
| const { children, className, header, mainClassName, ref, ...rest } = props | |||
| return <div className={className}> | |||
| <div | |||
| {...rest} | |||
| @@ -25,5 +25,5 @@ export const PreviewContainer: FC<PreviewContainerProps> = forwardRef((props, re | |||
| </main> | |||
| </div> | |||
| </div> | |||
| }) | |||
| } | |||
| PreviewContainer.displayName = 'PreviewContainer' | |||
| @@ -740,84 +740,6 @@ Workflow applications offers non-session support and is ideal for translation, a | |||
| --- | |||
| <Heading | |||
| url='/files/:file_id/preview' | |||
| method='GET' | |||
| title='File Preview' | |||
| name='#file-preview' | |||
| /> | |||
| <Row> | |||
| <Col> | |||
| Preview or download uploaded files. This endpoint allows you to access files that have been previously uploaded via the File Upload API. | |||
| <i>Files can only be accessed if they belong to messages within the requesting application.</i> | |||
| ### Path Parameters | |||
| - `file_id` (string) Required | |||
| The unique identifier of the file to preview, obtained from the File Upload API response. | |||
| ### Query Parameters | |||
| - `as_attachment` (boolean) Optional | |||
| Whether to force download the file as an attachment. Default is `false` (preview in browser). | |||
| ### Response | |||
| Returns the file content with appropriate headers for browser display or download. | |||
| - `Content-Type` Set based on file mime type | |||
| - `Content-Length` File size in bytes (if available) | |||
| - `Content-Disposition` Set to "attachment" if `as_attachment=true` | |||
| - `Cache-Control` Caching headers for performance | |||
| - `Accept-Ranges` Set to "bytes" for audio/video files | |||
| ### Errors | |||
| - 400, `invalid_param`, abnormal parameter input | |||
| - 403, `file_access_denied`, file access denied or file does not belong to current application | |||
| - 404, `file_not_found`, file not found or has been deleted | |||
| - 500, internal server error | |||
| </Col> | |||
| <Col sticky> | |||
| ### Request Example | |||
| <CodeGroup | |||
| title="Request" | |||
| tag="GET" | |||
| label="/files/:file_id/preview" | |||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\ | |||
| --header 'Authorization: Bearer {api_key}'`} | |||
| /> | |||
| ### Download as Attachment | |||
| <CodeGroup | |||
| title="Download Request" | |||
| tag="GET" | |||
| label="/files/:file_id/preview?as_attachment=true" | |||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\ | |||
| --header 'Authorization: Bearer {api_key}' \\ | |||
| --output downloaded_file.png`} | |||
| /> | |||
| ### Response Headers Example | |||
| <CodeGroup title="Response Headers"> | |||
| ```http {{ title: 'Headers - Image Preview' }} | |||
| Content-Type: image/png | |||
| Content-Length: 1024 | |||
| Cache-Control: public, max-age=3600 | |||
| ``` | |||
| </CodeGroup> | |||
| ### Download Response Headers | |||
| <CodeGroup title="Download Response Headers"> | |||
| ```http {{ title: 'Headers - File Download' }} | |||
| Content-Type: image/png | |||
| Content-Length: 1024 | |||
| Content-Disposition: attachment; filename*=UTF-8''example.png | |||
| Cache-Control: public, max-age=3600 | |||
| ``` | |||
| </CodeGroup> | |||
| </Col> | |||
| </Row> | |||
| --- | |||
| <Heading | |||
| url='/workflows/logs' | |||
| method='GET' | |||
| @@ -736,84 +736,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| --- | |||
| <Heading | |||
| url='/files/:file_id/preview' | |||
| method='GET' | |||
| title='ファイルプレビュー' | |||
| name='#file-preview' | |||
| /> | |||
| <Row> | |||
| <Col> | |||
| アップロードされたファイルをプレビューまたはダウンロードします。このエンドポイントを使用すると、以前にファイルアップロード API でアップロードされたファイルにアクセスできます。 | |||
| <i>ファイルは、リクエストしているアプリケーションのメッセージ範囲内にある場合のみアクセス可能です。</i> | |||
| ### パスパラメータ | |||
| - `file_id` (string) 必須 | |||
| プレビューするファイルの一意識別子。ファイルアップロード API レスポンスから取得します。 | |||
| ### クエリパラメータ | |||
| - `as_attachment` (boolean) オプション | |||
| ファイルを添付ファイルとして強制ダウンロードするかどうか。デフォルトは `false`(ブラウザでプレビュー)。 | |||
| ### レスポンス | |||
| ブラウザ表示またはダウンロード用の適切なヘッダー付きでファイル内容を返します。 | |||
| - `Content-Type` ファイル MIME タイプに基づいて設定 | |||
| - `Content-Length` ファイルサイズ(バイト、利用可能な場合) | |||
| - `Content-Disposition` `as_attachment=true` の場合は "attachment" に設定 | |||
| - `Cache-Control` パフォーマンス向上のためのキャッシュヘッダー | |||
| - `Accept-Ranges` 音声/動画ファイルの場合は "bytes" に設定 | |||
| ### エラー | |||
| - 400, `invalid_param`, パラメータ入力異常 | |||
| - 403, `file_access_denied`, ファイルアクセス拒否またはファイルが現在のアプリケーションに属していません | |||
| - 404, `file_not_found`, ファイルが見つからないか削除されています | |||
| - 500, サーバー内部エラー | |||
| </Col> | |||
| <Col sticky> | |||
| ### リクエスト例 | |||
| <CodeGroup | |||
| title="Request" | |||
| tag="GET" | |||
| label="/files/:file_id/preview" | |||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\ | |||
| --header 'Authorization: Bearer {api_key}'`} | |||
| /> | |||
| ### 添付ファイルとしてダウンロード | |||
| <CodeGroup | |||
| title="Download Request" | |||
| tag="GET" | |||
| label="/files/:file_id/preview?as_attachment=true" | |||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\ | |||
| --header 'Authorization: Bearer {api_key}' \\ | |||
| --output downloaded_file.png`} | |||
| /> | |||
| ### レスポンスヘッダー例 | |||
| <CodeGroup title="Response Headers"> | |||
| ```http {{ title: 'ヘッダー - 画像プレビュー' }} | |||
| Content-Type: image/png | |||
| Content-Length: 1024 | |||
| Cache-Control: public, max-age=3600 | |||
| ``` | |||
| </CodeGroup> | |||
| ### ダウンロードレスポンスヘッダー | |||
| <CodeGroup title="Download Response Headers"> | |||
| ```http {{ title: 'ヘッダー - ファイルダウンロード' }} | |||
| Content-Type: image/png | |||
| Content-Length: 1024 | |||
| Content-Disposition: attachment; filename*=UTF-8''example.png | |||
| Cache-Control: public, max-age=3600 | |||
| ``` | |||
| </CodeGroup> | |||
| </Col> | |||
| </Row> | |||
| --- | |||
| <Heading | |||
| url='/workflows/logs' | |||
| method='GET' | |||
| @@ -727,83 +727,6 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 | |||
| </Row> | |||
| --- | |||
| <Heading | |||
| url='/files/:file_id/preview' | |||
| method='GET' | |||
| title='文件预览' | |||
| name='#file-preview' | |||
| /> | |||
| <Row> | |||
| <Col> | |||
| 预览或下载已上传的文件。此端点允许您访问先前通过文件上传 API 上传的文件。 | |||
| <i>文件只能在属于请求应用程序的消息范围内访问。</i> | |||
| ### 路径参数 | |||
| - `file_id` (string) 必需 | |||
| 要预览的文件的唯一标识符,从文件上传 API 响应中获得。 | |||
| ### 查询参数 | |||
| - `as_attachment` (boolean) 可选 | |||
| 是否强制将文件作为附件下载。默认为 `false`(在浏览器中预览)。 | |||
| ### 响应 | |||
| 返回带有适当浏览器显示或下载标头的文件内容。 | |||
| - `Content-Type` 根据文件 MIME 类型设置 | |||
| - `Content-Length` 文件大小(以字节为单位,如果可用) | |||
| - `Content-Disposition` 如果 `as_attachment=true` 则设置为 "attachment" | |||
| - `Cache-Control` 用于性能的缓存标头 | |||
| - `Accept-Ranges` 对于音频/视频文件设置为 "bytes" | |||
| ### 错误 | |||
| - 400, `invalid_param`, 参数输入异常 | |||
| - 403, `file_access_denied`, 文件访问被拒绝或文件不属于当前应用程序 | |||
| - 404, `file_not_found`, 文件未找到或已被删除 | |||
| - 500, 服务内部错误 | |||
| </Col> | |||
| <Col sticky> | |||
| ### 请求示例 | |||
| <CodeGroup | |||
| title="Request" | |||
| tag="GET" | |||
| label="/files/:file_id/preview" | |||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\ | |||
| --header 'Authorization: Bearer {api_key}'`} | |||
| /> | |||
| ### 作为附件下载 | |||
| <CodeGroup | |||
| title="Request" | |||
| tag="GET" | |||
| label="/files/:file_id/preview?as_attachment=true" | |||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\ | |||
| --header 'Authorization: Bearer {api_key}' \\ | |||
| --output downloaded_file.png`} | |||
| /> | |||
| ### 响应标头示例 | |||
| <CodeGroup title="Response Headers"> | |||
| ```http {{ title: 'Headers - 图片预览' }} | |||
| Content-Type: image/png | |||
| Content-Length: 1024 | |||
| Cache-Control: public, max-age=3600 | |||
| ``` | |||
| </CodeGroup> | |||
| ### 文件下载响应标头 | |||
| <CodeGroup title="Download Response Headers"> | |||
| ```http {{ title: 'Headers - 文件下载' }} | |||
| Content-Type: image/png | |||
| Content-Length: 1024 | |||
| Content-Disposition: attachment; filename*=UTF-8''example.png | |||
| Cache-Control: public, max-age=3600 | |||
| ``` | |||
| </CodeGroup> | |||
| </Col> | |||
| </Row> | |||
| --- | |||
| <Heading | |||
| url='/workflows/logs' | |||
| method='GET' | |||
| @@ -1,5 +1,4 @@ | |||
| 'use client' | |||
| import type { ForwardRefRenderFunction } from 'react' | |||
| import { useImperativeHandle } from 'react' | |||
| import React, { useCallback, useEffect, useMemo, useState } from 'react' | |||
| import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types' | |||
| @@ -21,6 +20,7 @@ type Props = { | |||
| onDeSelectAll: () => void | |||
| onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void | |||
| isFromMarketPlace?: boolean | |||
| ref?: React.Ref<ExposeRefs> | |||
| } | |||
| export type ExposeRefs = { | |||
| @@ -28,7 +28,7 @@ export type ExposeRefs = { | |||
| deSelectAllPlugins: () => void | |||
| } | |||
| const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({ | |||
| const InstallByDSLList = ({ | |||
| allPlugins, | |||
| selectedPlugins, | |||
| onSelect, | |||
| @@ -36,7 +36,8 @@ const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({ | |||
| onDeSelectAll, | |||
| onLoadedAllPlugin, | |||
| isFromMarketPlace, | |||
| }, ref) => { | |||
| ref, | |||
| }: Props) => { | |||
| const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) | |||
| // DSL has id, to get plugin info to show more info | |||
| const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => { | |||
| @@ -268,4 +269,4 @@ const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({ | |||
| </> | |||
| ) | |||
| } | |||
| export default React.forwardRef(InstallByDSLList) | |||
| export default InstallByDSLList | |||
| @@ -82,9 +82,7 @@ const PluginTypeSwitch = ({ | |||
| }, [showSearchParams, handleActivePluginTypeChange]) | |||
| useEffect(() => { | |||
| window.addEventListener('popstate', () => { | |||
| handlePopState() | |||
| }) | |||
| window.addEventListener('popstate', handlePopState) | |||
| return () => { | |||
| window.removeEventListener('popstate', handlePopState) | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| 'use client' | |||
| import React, { forwardRef, useEffect, useImperativeHandle, useMemo, useRef } from 'react' | |||
| import React, { useEffect, useImperativeHandle, useMemo, useRef } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll' | |||
| import Item from './item' | |||
| @@ -17,18 +17,20 @@ export type ListProps = { | |||
| tags: string[] | |||
| toolContentClassName?: string | |||
| disableMaxWidth?: boolean | |||
| ref?: React.Ref<ListRef> | |||
| } | |||
| export type ListRef = { handleScroll: () => void } | |||
| const List = forwardRef<ListRef, ListProps>(({ | |||
| const List = ({ | |||
| wrapElemRef, | |||
| searchText, | |||
| tags, | |||
| list, | |||
| toolContentClassName, | |||
| disableMaxWidth = false, | |||
| }, ref) => { | |||
| ref, | |||
| }: ListProps) => { | |||
| const { t } = useTranslation() | |||
| const hasFilter = !searchText | |||
| const hasRes = list.length > 0 | |||
| @@ -125,7 +127,7 @@ const List = forwardRef<ListRef, ListProps>(({ | |||
| </div> | |||
| </> | |||
| ) | |||
| }) | |||
| } | |||
| List.displayName = 'List' | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noLastRunFound: 'Kein vorheriger Lauf gefunden', | |||
| lastOutput: 'Letzte Ausgabe', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Aktuelle gespeicherte Version exportieren', | |||
| exportWarningDesc: 'Dies wird die derzeit gespeicherte Version Ihres Workflows exportieren. Wenn Sie ungespeicherte Änderungen im Editor haben, speichern Sie diese bitte zuerst, indem Sie die Exportoption im Workflow-Canvas verwenden.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -139,6 +139,10 @@ const translation = { | |||
| export: 'Export DSL with secret values ', | |||
| }, | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Export Current Saved Version', | |||
| exportWarningDesc: 'This will export the current saved version of your workflow. If you have unsaved changes in the editor, please save them first by using the export option in the workflow canvas.', | |||
| }, | |||
| chatVariable: { | |||
| panelTitle: 'Conversation Variables', | |||
| panelDescription: 'Conversation Variables are used to store interactive information that LLM needs to remember, including conversation history, uploaded files, user preferences. They are read-write. ', | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noMatchingInputsFound: 'No se encontraron entradas coincidentes de la última ejecución.', | |||
| lastOutput: 'Última salida', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Exportar la versión guardada actual', | |||
| exportWarningDesc: 'Esto exportará la versión guardada actual de tu flujo de trabajo. Si tienes cambios no guardados en el editor, guárdalos primero utilizando la opción de exportar en el lienzo del flujo de trabajo.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| copyLastRunError: 'نتوانستم ورودیهای آخرین اجرای را کپی کنم', | |||
| lastOutput: 'آخرین خروجی', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'صادرات نسخه ذخیره شده فعلی', | |||
| exportWarningDesc: 'این نسخه فعلی ذخیره شده از کار خود را صادر خواهد کرد. اگر تغییرات غیرذخیره شدهای در ویرایشگر دارید، لطفاً ابتدا از گزینه صادرات در بوم کار برای ذخیره آنها استفاده کنید.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| copyLastRunError: 'Échec de la copie des entrées de la dernière exécution', | |||
| lastOutput: 'Dernière sortie', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Exporter la version enregistrée actuelle', | |||
| exportWarningDesc: 'Cela exportera la version actuelle enregistrée de votre flux de travail. Si vous avez des modifications non enregistrées dans l\'éditeur, veuillez d\'abord les enregistrer en utilisant l\'option d\'exportation dans le canevas du flux de travail.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1023,6 +1023,10 @@ const translation = { | |||
| copyLastRunError: 'अंतिम रन इनपुट को कॉपी करने में विफल', | |||
| lastOutput: 'अंतिम आउटपुट', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'वर्तमान सहेजा गया संस्करण निर्यात करें', | |||
| exportWarningDesc: 'यह आपके कार्यप्रवाह का वर्तमान सहेजा हुआ संस्करण निर्यात करेगा। यदि आपके संपादक में कोई असहेजा किए गए परिवर्तन हैं, तो कृपया पहले उन्हें सहेजें, कार्यप्रवाह कैनवास में निर्यात विकल्प का उपयोग करके।', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -967,6 +967,10 @@ const translation = { | |||
| lastOutput: 'Keluaran Terakhir', | |||
| noLastRunFound: 'Tidak ada eksekusi sebelumnya ditemukan', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Ekspor Versi Tersimpan Saat Ini', | |||
| exportWarningDesc: 'Ini akan mengekspor versi terkini dari alur kerja Anda yang telah disimpan. Jika Anda memiliki perubahan yang belum disimpan di editor, harap simpan terlebih dahulu dengan menggunakan opsi ekspor di kanvas alur kerja.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1029,6 +1029,10 @@ const translation = { | |||
| noLastRunFound: 'Nessuna esecuzione precedente trovata', | |||
| lastOutput: 'Ultimo output', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Esporta la versione salvata corrente', | |||
| exportWarningDesc: 'Questo esporterà l\'attuale versione salvata del tuo flusso di lavoro. Se hai modifiche non salvate nell\'editor, ti preghiamo di salvarle prima utilizzando l\'opzione di esportazione nel canvas del flusso di lavoro.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -139,6 +139,10 @@ const translation = { | |||
| export: 'シークレット値付きでエクスポート', | |||
| }, | |||
| }, | |||
| sidebar: { | |||
| exportWarning: '現在保存されているバージョンをエクスポート', | |||
| exportWarningDesc: 'これは現在保存されているワークフローのバージョンをエクスポートします。エディターで未保存の変更がある場合は、まずワークフローキャンバスのエクスポートオプションを使用して保存してください。', | |||
| }, | |||
| chatVariable: { | |||
| panelTitle: '会話変数', | |||
| panelDescription: '対話情報を保存・管理(会話履歴/ファイル/ユーザー設定など)。書き換えができます。', | |||
| @@ -1054,6 +1054,10 @@ const translation = { | |||
| copyLastRunError: '마지막 실행 입력을 복사하는 데 실패했습니다.', | |||
| lastOutput: '마지막 출력', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: '현재 저장된 버전 내보내기', | |||
| exportWarningDesc: '이 작업은 현재 저장된 워크플로우 버전을 내보냅니다. 편집기에서 저장되지 않은 변경 사항이 있는 경우, 먼저 워크플로우 캔버스의 내보내기 옵션을 사용하여 저장해 주세요.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| copyLastRunError: 'Nie udało się skopiować danych wejściowych z ostatniego uruchomienia', | |||
| lastOutput: 'Ostatni wynik', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Eksportuj obecną zapisaną wersję', | |||
| exportWarningDesc: 'To wyeksportuje aktualnie zapisaną wersję twojego przepływu pracy. Jeśli masz niezapisane zmiany w edytorze, najpierw je zapisz, korzystając z opcji eksportu w kanwie przepływu pracy.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| copyLastRun: 'Copiar Última Execução', | |||
| lastOutput: 'Última Saída', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Exportar a versão salva atual', | |||
| exportWarningDesc: 'Isto irá exportar a versão atual salva do seu fluxo de trabalho. Se você tiver alterações não salvas no editor, por favor, salve-as primeiro utilizando a opção de exportação na tela do fluxo de trabalho.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| copyLastRunError: 'Nu s-au putut copia ultimele intrări de rulare', | |||
| lastOutput: 'Ultimul rezultat', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Exportați versiunea salvată curentă', | |||
| exportWarningDesc: 'Aceasta va exporta versiunea curent salvată a fluxului dumneavoastră de lucru. Dacă aveți modificări nesalvate în editor, vă rugăm să le salvați mai întâi utilizând opțiunea de export din canvasul fluxului de lucru.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noMatchingInputsFound: 'Не найдено соответствующих входных данных из последнего запуска.', | |||
| lastOutput: 'Последний вывод', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Экспортировать текущую сохранённую версию', | |||
| exportWarningDesc: 'Это экспортирует текущую сохранённую версию вашего рабочего процесса. Если у вас есть несохранённые изменения в редакторе, сначала сохраните их с помощью опции экспорта на полотне рабочего процесса.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noMatchingInputsFound: 'Ni podatkov, ki bi ustrezali prejšnjemu zagonu', | |||
| lastOutput: 'Nazadnje izhod', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Izvozi trenutna shranjena različica', | |||
| exportWarningDesc: 'To bo izvozilo trenutno shranjeno različico vašega delovnega toka. Če imate neshranjene spremembe v urejevalniku, jih najprej shranite z uporabo možnosti izvoza na platnu delovnega toka.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noMatchingInputsFound: 'ไม่พบข้อมูลที่ตรงกันจากการรันครั้งล่าสุด', | |||
| lastOutput: 'ผลลัพธ์สุดท้าย', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'ส่งออกเวอร์ชันที่บันทึกปัจจุบัน', | |||
| exportWarningDesc: 'นี่จะส่งออกเวอร์ชันที่บันทึกไว้ปัจจุบันของเวิร์กโฟลว์ของคุณ หากคุณมีการเปลี่ยนแปลงที่ยังไม่ได้บันทึกในแก้ไข กรุณาบันทึกมันก่อนโดยใช้ตัวเลือกส่งออกในผืนผ้าใบเวิร์กโฟลว์', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1004,6 +1004,10 @@ const translation = { | |||
| copyLastRunError: 'Son çalışma girdilerini kopyalamak başarısız oldu.', | |||
| lastOutput: 'Son Çıktı', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Mevcut Kaydedilmiş Versiyonu Dışa Aktar', | |||
| exportWarningDesc: 'Bu, çalışma akışınızın mevcut kaydedilmiş sürümünü dışa aktaracaktır. Editörde kaydedilmemiş değişiklikleriniz varsa, lütfen önce bunları çalışma akışı alanındaki dışa aktarma seçeneğini kullanarak kaydedin.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noMatchingInputsFound: 'Не знайдено відповідних вхідних даних з останнього запуску', | |||
| lastOutput: 'Останній вихід', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Експортувати поточну збережену версію', | |||
| exportWarningDesc: 'Це експортує поточну збережену версію вашого робочого процесу. Якщо у вас є незбережені зміни в редакторі, будь ласка, спочатку збережіть їх, використовуючи опцію експорту на полотні робочого процесу.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| copyLastRunError: 'Không thể sao chép đầu vào của lần chạy trước', | |||
| lastOutput: 'Đầu ra cuối cùng', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: 'Xuất Phiên Bản Đã Lưu Hiện Tại', | |||
| exportWarningDesc: 'Điều này sẽ xuất phiên bản hiện tại đã được lưu của quy trình làm việc của bạn. Nếu bạn có những thay đổi chưa được lưu trong trình soạn thảo, vui lòng lưu chúng trước bằng cách sử dụng tùy chọn xuất trong bản vẽ quy trình.', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -139,6 +139,10 @@ const translation = { | |||
| export: '导出包含 Secret 值的 DSL', | |||
| }, | |||
| }, | |||
| sidebar: { | |||
| exportWarning: '导出当前已保存版本', | |||
| exportWarningDesc: '这将导出您工作流的当前已保存版本。如果您在编辑器中有未保存的更改,请先使用工作流画布中的导出选项保存它们。', | |||
| }, | |||
| chatVariable: { | |||
| panelTitle: '会话变量', | |||
| panelDescription: '会话变量用于存储 LLM 需要的上下文信息,如用户偏好、对话历史等。它是可读写的。', | |||
| @@ -1003,6 +1003,10 @@ const translation = { | |||
| noLastRunFound: '沒有找到之前的運行', | |||
| lastOutput: '最後的輸出', | |||
| }, | |||
| sidebar: { | |||
| exportWarning: '導出當前保存的版本', | |||
| exportWarningDesc: '這將導出當前保存的工作流程版本。如果您在編輯器中有未保存的更改,請先通過使用工作流程畫布中的導出選項來保存它們。', | |||
| }, | |||
| } | |||
| export default translation | |||
| @@ -1 +0,0 @@ | |||
(() => {
  'use strict'
  // Offline fallback handler (next-pwa): for failed page navigations
  // ("document" destination requests) serve the precached offline page;
  // every other resource type yields a network error response.
  self.fallback = async (request) => {
    if (request.destination === 'document')
      return caches.match('/_offline.html', { ignoreSearch: true })
    return Response.error()
  }
})()
| @@ -0,0 +1,84 @@ | |||
| import { DataType } from '@/app/components/datasets/metadata/types' | |||
| import { act, renderHook } from '@testing-library/react' | |||
| import { QueryClient, QueryClientProvider } from '@tanstack/react-query' | |||
| import { useBatchUpdateDocMetadata } from '@/service/knowledge/use-metadata' | |||
| import { useDocumentListKey } from './use-document' | |||
// Mock the post function to avoid real network requests;
// mutateAsync then resolves immediately with a canned success payload.
jest.mock('@/service/base', () => ({
  post: jest.fn().mockResolvedValue({ success: true }),
}))
// NOTE(review): presumably mirrors the NAME_SPACE constant used for query
// keys in '@/service/knowledge/use-metadata' — confirm they stay in sync.
const NAME_SPACE = 'dataset-metadata'
describe('useBatchUpdateDocMetadata', () => {
  let client: QueryClient

  // A fresh QueryClient per test keeps cache state fully isolated.
  beforeEach(() => {
    client = new QueryClient()
  })

  // Supplies the React Query context the hook under test requires.
  const Wrapper = ({ children }: { children: React.ReactNode }) => (
    <QueryClientProvider client={client}>{children}</QueryClientProvider>
  )

  it('should correctly invalidate dataset and document caches', async () => {
    const { result } = renderHook(() => useBatchUpdateDocMetadata(), { wrapper: Wrapper })

    // Observe cache invalidations triggered by the mutation's success handler.
    const invalidateSpy = jest.spyOn(client, 'invalidateQueries')

    // Payload shape: every document entry carries its own metadata_list array.
    const batchPayload = {
      dataset_id: 'dataset-1',
      metadata_list: [
        {
          document_id: 'doc-1',
          metadata_list: [
            { key: 'title-1', id: '01', name: 'name-1', type: DataType.string, value: 'new title 01' },
          ],
        },
        {
          document_id: 'doc-2',
          metadata_list: [
            { key: 'title-2', id: '02', name: 'name-1', type: DataType.string, value: 'new title 02' },
          ],
        },
      ],
    }

    await act(async () => {
      await result.current.mutateAsync(batchPayload)
    })

    // One dataset-level + one document-list + one useDocumentListKey
    // invalidation, then one per document = 5 in total.
    expect(invalidateSpy).toHaveBeenCalledTimes(5)

    // The first three invalidations happen in a fixed order.
    const orderedKeys = [
      [NAME_SPACE, 'dataset', 'dataset-1'],
      [NAME_SPACE, 'document', 'dataset-1'],
      [...useDocumentListKey, 'dataset-1'],
    ]
    orderedKeys.forEach((queryKey, index) => {
      expect(invalidateSpy).toHaveBeenNthCalledWith(index + 1, { queryKey })
    })

    // Per-document invalidations follow, in no guaranteed order.
    expect(invalidateSpy.mock.calls.slice(3)).toEqual(
      expect.arrayContaining([
        [{ queryKey: [NAME_SPACE, 'document', 'dataset-1', 'doc-1'] }],
        [{ queryKey: [NAME_SPACE, 'document', 'dataset-1', 'doc-2'] }],
      ]),
    )
  })
})
| @@ -119,7 +119,7 @@ export const useBatchUpdateDocMetadata = () => { | |||
| }) | |||
| // meta data in document list | |||
| await queryClient.invalidateQueries({ | |||
| queryKey: [NAME_SPACE, 'dataset', payload.dataset_id], | |||
| queryKey: [NAME_SPACE, 'document', payload.dataset_id], | |||
| }) | |||
| await queryClient.invalidateQueries({ | |||
| queryKey: [...useDocumentListKey, payload.dataset_id], | |||