| !.vscode/launch.json.template | !.vscode/launch.json.template | ||||
| !.vscode/README.md | !.vscode/README.md | ||||
| api/.vscode | api/.vscode | ||||
| web/.vscode | |||||
| # vscode Code History Extension | # vscode Code History Extension | ||||
| .history | .history | ||||
| # Next.js build output | # Next.js build output | ||||
| .next/ | .next/ | ||||
| # PWA generated files | |||||
| web/public/sw.js | |||||
| web/public/sw.js.map | |||||
| web/public/workbox-*.js | |||||
| web/public/workbox-*.js.map | |||||
| web/public/fallback-*.js | |||||
| # AI Assistant | # AI Assistant | ||||
| .roo/ | .roo/ | ||||
| api/.env.backup | api/.env.backup |
| from collections.abc import Callable | |||||
| from functools import wraps | from functools import wraps | ||||
| from typing import ParamSpec, TypeVar | |||||
| from flask import request | from flask import request | ||||
| from flask_restx import Resource, reqparse | from flask_restx import Resource, reqparse | ||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from werkzeug.exceptions import NotFound, Unauthorized | from werkzeug.exceptions import NotFound, Unauthorized | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from constants.languages import supported_language | from constants.languages import supported_language | ||||
| from controllers.console import api | from controllers.console import api | ||||
| from models.model import App, InstalledApp, RecommendedApp | from models.model import App, InstalledApp, RecommendedApp | ||||
| def admin_required(view): | |||||
| def admin_required(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| if not dify_config.ADMIN_API_KEY: | if not dify_config.ADMIN_API_KEY: | ||||
| raise Unauthorized("API key is invalid.") | raise Unauthorized("API key is invalid.") | ||||
| custom="max_keys_exceeded", | custom="max_keys_exceeded", | ||||
| ) | ) | ||||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||||
| key = ApiToken.generate_api_key(self.token_prefix or "", 24) | |||||
| api_token = ApiToken() | api_token = ApiToken() | ||||
| setattr(api_token, self.resource_id_field, resource_id) | setattr(api_token, self.resource_id_field, resource_id) | ||||
| api_token.tenant_id = current_user.current_tenant_id | api_token.tenant_id = current_user.current_tenant_id |
| from collections.abc import Callable | |||||
| from functools import wraps | from functools import wraps | ||||
| from typing import cast | |||||
| from typing import Concatenate, ParamSpec, TypeVar, cast | |||||
| import flask_login | import flask_login | ||||
| from flask import jsonify, request | from flask import jsonify, request | ||||
| from .. import api | from .. import api | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| T = TypeVar("T") | |||||
| def oauth_server_client_id_required(view): | |||||
| def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(self: T, *args: P.args, **kwargs: P.kwargs): | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("client_id", type=str, required=True, location="json") | parser.add_argument("client_id", type=str, required=True, location="json") | ||||
| parsed_args = parser.parse_args() | parsed_args = parser.parse_args() | ||||
| if not oauth_provider_app: | if not oauth_provider_app: | ||||
| raise NotFound("client_id is invalid") | raise NotFound("client_id is invalid") | ||||
| kwargs["oauth_provider_app"] = oauth_provider_app | |||||
| return view(*args, **kwargs) | |||||
| return view(self, oauth_provider_app, *args, **kwargs) | |||||
| return decorated | return decorated | ||||
| def oauth_server_access_token_required(view): | |||||
| def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| oauth_provider_app = kwargs.get("oauth_provider_app") | |||||
| if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): | |||||
| def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs): | |||||
| if not isinstance(oauth_provider_app, OAuthProviderApp): | |||||
| raise BadRequest("Invalid oauth_provider_app") | raise BadRequest("Invalid oauth_provider_app") | ||||
| authorization_header = request.headers.get("Authorization") | authorization_header = request.headers.get("Authorization") | ||||
| response.headers["WWW-Authenticate"] = "Bearer" | response.headers["WWW-Authenticate"] = "Bearer" | ||||
| return response | return response | ||||
| kwargs["account"] = account | |||||
| return view(*args, **kwargs) | |||||
| return view(self, oauth_provider_app, account, *args, **kwargs) | |||||
| return decorated | return decorated | ||||
| from flask_login import current_user | |||||
| from flask_restx import Resource, reqparse | from flask_restx import Resource, reqparse | ||||
| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required | from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required | ||||
| from libs.login import login_required | |||||
| from libs.login import current_user, login_required | |||||
| from models.model import Account | |||||
| from services.billing_service import BillingService | from services.billing_service import BillingService | ||||
| parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) | parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) | ||||
| parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) | parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| assert isinstance(current_user, Account) | |||||
| BillingService.is_tenant_owner_or_admin(current_user) | BillingService.is_tenant_owner_or_admin(current_user) | ||||
| assert current_user.current_tenant_id is not None | |||||
| return BillingService.get_subscription( | return BillingService.get_subscription( | ||||
| args["plan"], args["interval"], current_user.email, current_user.current_tenant_id | args["plan"], args["interval"], current_user.email, current_user.current_tenant_id | ||||
| ) | ) | ||||
| @account_initialization_required | @account_initialization_required | ||||
| @only_edition_cloud | @only_edition_cloud | ||||
| def get(self): | def get(self): | ||||
| assert isinstance(current_user, Account) | |||||
| BillingService.is_tenant_owner_or_admin(current_user) | BillingService.is_tenant_owner_or_admin(current_user) | ||||
| assert current_user.current_tenant_id is not None | |||||
| return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) | return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) | ||||
| data_source_info = document.data_source_info_dict | data_source_info = document.data_source_info_dict | ||||
| if document.data_source_type == "upload_file": | if document.data_source_type == "upload_file": | ||||
| if not data_source_info: | |||||
| continue | |||||
| file_id = data_source_info["upload_file_id"] | file_id = data_source_info["upload_file_id"] | ||||
| file_detail = ( | file_detail = ( | ||||
| db.session.query(UploadFile) | db.session.query(UploadFile) | ||||
| extract_settings.append(extract_setting) | extract_settings.append(extract_setting) | ||||
| elif document.data_source_type == "notion_import": | elif document.data_source_type == "notion_import": | ||||
| if not data_source_info: | |||||
| continue | |||||
| extract_setting = ExtractSetting( | extract_setting = ExtractSetting( | ||||
| datasource_type=DatasourceType.NOTION.value, | datasource_type=DatasourceType.NOTION.value, | ||||
| notion_info={ | notion_info={ | ||||
| ) | ) | ||||
| extract_settings.append(extract_setting) | extract_settings.append(extract_setting) | ||||
| elif document.data_source_type == "website_crawl": | elif document.data_source_type == "website_crawl": | ||||
| if not data_source_info: | |||||
| continue | |||||
| extract_setting = ExtractSetting( | extract_setting = ExtractSetting( | ||||
| datasource_type=DatasourceType.WEBSITE.value, | datasource_type=DatasourceType.WEBSITE.value, | ||||
| website_info={ | website_info={ |
| def get(self, installed_app: InstalledApp): | def get(self, installed_app: InstalledApp): | ||||
| """Get app meta""" | """Get app meta""" | ||||
| app_model = installed_app.app | app_model = installed_app.app | ||||
| if not app_model: | |||||
| raise ValueError("App not found") | |||||
| return AppService().get_app_meta(app_model) | return AppService().get_app_meta(app_model) | ||||
| Run workflow | Run workflow | ||||
| """ | """ | ||||
| app_model = installed_app.app | app_model = installed_app.app | ||||
| if not app_model: | |||||
| raise NotWorkflowAppError() | |||||
| app_mode = AppMode.value_of(app_model.mode) | app_mode = AppMode.value_of(app_model.mode) | ||||
| if app_mode != AppMode.WORKFLOW: | if app_mode != AppMode.WORKFLOW: | ||||
| raise NotWorkflowAppError() | raise NotWorkflowAppError() | ||||
| Stop workflow task | Stop workflow task | ||||
| """ | """ | ||||
| app_model = installed_app.app | app_model = installed_app.app | ||||
| if not app_model: | |||||
| raise NotWorkflowAppError() | |||||
| app_mode = AppMode.value_of(app_model.mode) | app_mode = AppMode.value_of(app_model.mode) | ||||
| if app_mode != AppMode.WORKFLOW: | if app_mode != AppMode.WORKFLOW: | ||||
| raise NotWorkflowAppError() | raise NotWorkflowAppError() |
| from collections.abc import Callable | |||||
| from functools import wraps | from functools import wraps | ||||
| from typing import Concatenate, Optional, ParamSpec, TypeVar | |||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from flask_restx import Resource | from flask_restx import Resource | ||||
| from services.enterprise.enterprise_service import EnterpriseService | from services.enterprise.enterprise_service import EnterpriseService | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| T = TypeVar("T") | |||||
| def installed_app_required(view=None): | |||||
| def decorator(view): | |||||
| @wraps(view) | |||||
| def decorated(*args, **kwargs): | |||||
| if not kwargs.get("installed_app_id"): | |||||
| raise ValueError("missing installed_app_id in path parameters") | |||||
| installed_app_id = kwargs.get("installed_app_id") | |||||
| installed_app_id = str(installed_app_id) | |||||
| del kwargs["installed_app_id"] | |||||
| def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||||
| @wraps(view) | |||||
| def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): | |||||
| installed_app = ( | installed_app = ( | ||||
| db.session.query(InstalledApp) | db.session.query(InstalledApp) | ||||
| .where( | .where( | ||||
| return decorator | return decorator | ||||
| def user_allowed_to_access_app(view=None): | |||||
| def decorator(view): | |||||
| def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None): | |||||
| def decorator(view: Callable[Concatenate[InstalledApp, P], R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(installed_app: InstalledApp, *args, **kwargs): | |||||
| def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): | |||||
| feature = FeatureService.get_system_features() | feature = FeatureService.get_system_features() | ||||
| if feature.webapp_auth.enabled: | if feature.webapp_auth.enabled: | ||||
| app_id = installed_app.app_id | app_id = installed_app.app_id |
| from collections.abc import Callable | |||||
| from functools import wraps | from functools import wraps | ||||
| from typing import ParamSpec, TypeVar | |||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import TenantPluginPermission | from models.account import TenantPluginPermission | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| def plugin_permission_required( | def plugin_permission_required( | ||||
| install_required: bool = False, | install_required: bool = False, | ||||
| debug_required: bool = False, | debug_required: bool = False, | ||||
| ): | ): | ||||
| def interceptor(view): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| user = current_user | user = current_user | ||||
| tenant_id = user.current_tenant_id | tenant_id = user.current_tenant_id | ||||
| import json | import json | ||||
| import os | import os | ||||
| import time | import time | ||||
| from collections.abc import Callable | |||||
| from functools import wraps | from functools import wraps | ||||
| from typing import ParamSpec, TypeVar | |||||
| from flask import abort, request | from flask import abort, request | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout | from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| def account_initialization_required(view): | |||||
| def account_initialization_required(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| # check account initialization | # check account initialization | ||||
| account = current_user | account = current_user | ||||
| return decorated | return decorated | ||||
| def only_edition_cloud(view): | |||||
| def only_edition_cloud(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| if dify_config.EDITION != "CLOUD": | if dify_config.EDITION != "CLOUD": | ||||
| abort(404) | abort(404) | ||||
| return decorated | return decorated | ||||
| def only_edition_enterprise(view): | |||||
| def only_edition_enterprise(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| if not dify_config.ENTERPRISE_ENABLED: | if not dify_config.ENTERPRISE_ENABLED: | ||||
| abort(404) | abort(404) | ||||
| return decorated | return decorated | ||||
| def only_edition_self_hosted(view): | |||||
| def only_edition_self_hosted(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| if dify_config.EDITION != "SELF_HOSTED": | if dify_config.EDITION != "SELF_HOSTED": | ||||
| abort(404) | abort(404) | ||||
| return decorated | return decorated | ||||
| def cloud_edition_billing_enabled(view): | |||||
| def cloud_edition_billing_enabled(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if not features.billing.enabled: | if not features.billing.enabled: | ||||
| abort(403, "Billing feature is not enabled.") | abort(403, "Billing feature is not enabled.") | ||||
| def cloud_edition_billing_resource_check(resource: str): | def cloud_edition_billing_resource_check(resource: str): | ||||
| def interceptor(view): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if features.billing.enabled: | if features.billing.enabled: | ||||
| members = features.members | members = features.members | ||||
| def cloud_edition_billing_knowledge_limit_check(resource: str): | def cloud_edition_billing_knowledge_limit_check(resource: str): | ||||
| def interceptor(view): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if features.billing.enabled: | if features.billing.enabled: | ||||
| if resource == "add_segment": | if resource == "add_segment": | ||||
| def cloud_edition_billing_rate_limit_check(resource: str): | def cloud_edition_billing_rate_limit_check(resource: str): | ||||
| def interceptor(view): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| if resource == "knowledge": | if resource == "knowledge": | ||||
| knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) | knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) | ||||
| if knowledge_rate_limit.enabled: | if knowledge_rate_limit.enabled: | ||||
| return interceptor | return interceptor | ||||
| def cloud_utm_record(view): | |||||
| def cloud_utm_record(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| with contextlib.suppress(Exception): | with contextlib.suppress(Exception): | ||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| return decorated | return decorated | ||||
| def setup_required(view): | |||||
| def setup_required(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| # check setup | # check setup | ||||
| if ( | if ( | ||||
| dify_config.EDITION == "SELF_HOSTED" | dify_config.EDITION == "SELF_HOSTED" | ||||
| return decorated | return decorated | ||||
| def enterprise_license_required(view): | |||||
| def enterprise_license_required(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| settings = FeatureService.get_system_features() | settings = FeatureService.get_system_features() | ||||
| if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: | if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: | ||||
| raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") | raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") | ||||
| return decorated | return decorated | ||||
| def email_password_login_enabled(view): | |||||
| def email_password_login_enabled(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| features = FeatureService.get_system_features() | features = FeatureService.get_system_features() | ||||
| if features.enable_email_password_login: | if features.enable_email_password_login: | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) | ||||
| return decorated | return decorated | ||||
| def enable_change_email(view): | |||||
| def enable_change_email(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| features = FeatureService.get_system_features() | features = FeatureService.get_system_features() | ||||
| if features.enable_change_email: | if features.enable_change_email: | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) | ||||
| return decorated | return decorated | ||||
| def is_allow_transfer_owner(view): | |||||
| def is_allow_transfer_owner(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if features.is_allow_transfer_workspace: | if features.is_allow_transfer_workspace: | ||||
| return view(*args, **kwargs) | return view(*args, **kwargs) |
| from datetime import timedelta | from datetime import timedelta | ||||
| from enum import StrEnum, auto | from enum import StrEnum, auto | ||||
| from functools import wraps | from functools import wraps | ||||
| from typing import Optional | |||||
| from typing import Optional, ParamSpec, TypeVar | |||||
| from flask import current_app, request | from flask import current_app, request | ||||
| from flask_login import user_logged_in | from flask_login import user_logged_in | ||||
| from models.model import ApiToken, App, EndUser | from models.model import ApiToken, App, EndUser | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| class WhereisUserArg(StrEnum): | class WhereisUserArg(StrEnum): | ||||
| """ | """ | ||||
| if tenant.status == TenantStatus.ARCHIVE: | if tenant.status == TenantStatus.ARCHIVE: | ||||
| raise Forbidden("The workspace's status is archived.") | raise Forbidden("The workspace's status is archived.") | ||||
| tenant_account_join = ( | |||||
| db.session.query(Tenant, TenantAccountJoin) | |||||
| .where(Tenant.id == api_token.tenant_id) | |||||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||||
| .where(TenantAccountJoin.role.in_(["owner"])) | |||||
| .where(Tenant.status == TenantStatus.NORMAL) | |||||
| .one_or_none() | |||||
| ) # TODO: only owner information is required, so only one is returned. | |||||
| if tenant_account_join: | |||||
| tenant, ta = tenant_account_join | |||||
| account = db.session.query(Account).where(Account.id == ta.account_id).first() | |||||
| # Login admin | |||||
| if account: | |||||
| account.current_tenant = tenant | |||||
| current_app.login_manager._update_request_context_with_user(account) # type: ignore | |||||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore | |||||
| else: | |||||
| raise Unauthorized("Tenant owner account does not exist.") | |||||
| else: | |||||
| raise Unauthorized("Tenant does not exist.") | |||||
| kwargs["app_model"] = app_model | kwargs["app_model"] = app_model | ||||
| if fetch_user_arg: | if fetch_user_arg: | ||||
| def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | def cloud_edition_billing_resource_check(resource: str, api_token_type: str): | ||||
| def interceptor(view): | |||||
| def decorated(*args, **kwargs): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| api_token = validate_and_get_api_token(api_token_type) | api_token = validate_and_get_api_token(api_token_type) | ||||
| features = FeatureService.get_features(api_token.tenant_id) | features = FeatureService.get_features(api_token.tenant_id) | ||||
| def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): | def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): | ||||
| def interceptor(view): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| api_token = validate_and_get_api_token(api_token_type) | api_token = validate_and_get_api_token(api_token_type) | ||||
| features = FeatureService.get_features(api_token.tenant_id) | features = FeatureService.get_features(api_token.tenant_id) | ||||
| if features.billing.enabled: | if features.billing.enabled: | ||||
| def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): | def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): | ||||
| def interceptor(view): | |||||
| def interceptor(view: Callable[P, R]): | |||||
| @wraps(view) | @wraps(view) | ||||
| def decorated(*args, **kwargs): | |||||
| def decorated(*args: P.args, **kwargs: P.kwargs): | |||||
| api_token = validate_and_get_api_token(api_token_type) | api_token = validate_and_get_api_token(api_token_type) | ||||
| if resource == "knowledge": | if resource == "knowledge": |
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from functools import wraps | from functools import wraps | ||||
| from typing import ParamSpec, TypeVar | |||||
| from flask import request | from flask import request | ||||
| from flask_restx import Resource | from flask_restx import Resource | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| from services.webapp_auth_service import WebAppAuthService | from services.webapp_auth_service import WebAppAuthService | ||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| def validate_jwt_token(view=None): | def validate_jwt_token(view=None): | ||||
| def decorator(view): | def decorator(view): |
| raise MessageNotExistsError() | raise MessageNotExistsError() | ||||
| current_app_model_config = app_model.app_model_config | current_app_model_config = app_model.app_model_config | ||||
| if not current_app_model_config: | |||||
| raise MoreLikeThisDisabledError() | |||||
| more_like_this = current_app_model_config.more_like_this_dict | more_like_this = current_app_model_config.more_like_this_dict | ||||
| if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: | if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: |
| messages = list(reversed(thread_messages)) | messages = list(reversed(thread_messages)) | ||||
| curr_message_tokens = 0 | |||||
| prompt_messages: list[PromptMessage] = [] | prompt_messages: list[PromptMessage] = [] | ||||
| for message in messages: | for message in messages: | ||||
| # Process user message with files | # Process user message with files |
| from models.dataset import Dataset | from models.dataset import Dataset | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| from typing import ParamSpec, TypeVar | |||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| class MatrixoneConfig(BaseModel): | class MatrixoneConfig(BaseModel): |
| last_edited_time = self.get_notion_last_edited_time() | last_edited_time = self.get_notion_last_edited_time() | ||||
| data_source_info = document_model.data_source_info_dict | data_source_info = document_model.data_source_info_dict | ||||
| data_source_info["last_edited_time"] = last_edited_time | |||||
| if data_source_info: | |||||
| data_source_info["last_edited_time"] = last_edited_time | |||||
| db.session.query(DocumentModel).filter_by(id=document_model.id).update( | db.session.query(DocumentModel).filter_by(id=document_model.id).update( | ||||
| {DocumentModel.data_source_info: json.dumps(data_source_info)} | {DocumentModel.data_source_info: json.dumps(data_source_info)} |
| import json | import json | ||||
| from typing import Any, Optional | |||||
| from typing import Any, Optional, Self | |||||
| from core.mcp.types import Tool as RemoteMCPTool | from core.mcp.types import Tool as RemoteMCPTool | ||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| return ToolProviderType.MCP | return ToolProviderType.MCP | ||||
| @classmethod | @classmethod | ||||
| def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": | |||||
| def from_db(cls, db_provider: MCPToolProvider) -> Self: | |||||
| """ | """ | ||||
| from db provider | from db provider | ||||
| """ | """ |
| if provider is None: | if provider is None: | ||||
| raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") | raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") | ||||
| controller = MCPToolProviderController._from_db(provider) | |||||
| controller = MCPToolProviderController.from_db(provider) | |||||
| return controller | return controller | ||||
| tenant_id: str, | tenant_id: str, | ||||
| provider_type: ToolProviderType, | provider_type: ToolProviderType, | ||||
| provider_id: str, | provider_id: str, | ||||
| ) -> Union[str, dict]: | |||||
| ) -> Union[str, dict[str, Any]]: | |||||
| """ | """ | ||||
| get the tool icon | get the tool icon | ||||
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | from typing import TYPE_CHECKING, Any, Optional, Union, cast | ||||
| from core.variables import ArrayVariable, IntegerVariable, NoneVariable | |||||
| from core.variables import IntegerVariable, NoneSegment | |||||
| from core.variables.segments import ArrayAnySegment, ArraySegment | from core.variables.segments import ArrayAnySegment, ArraySegment | ||||
| from core.workflow.entities import VariablePool | from core.workflow.entities import VariablePool | ||||
| from core.workflow.enums import ( | from core.workflow.enums import ( | ||||
| if not variable: | if not variable: | ||||
| raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") | raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") | ||||
| if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): | |||||
| if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): | |||||
| raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") | raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") | ||||
| if isinstance(variable, NoneVariable) or len(variable.value) == 0: | |||||
| if isinstance(variable, NoneSegment) or len(variable.value) == 0: | |||||
| # Try our best to preserve the type informat. | # Try our best to preserve the type informat. | ||||
| if isinstance(variable, ArraySegment): | if isinstance(variable, ArraySegment): | ||||
| output = variable.model_copy(update={"value": []}) | output = variable.model_copy(update={"value": []}) |
| ) | ) | ||||
| from .prompts import ( | from .prompts import ( | ||||
| CHAT_EXAMPLE, | CHAT_EXAMPLE, | ||||
| CHAT_GENERATE_JSON_PROMPT, | |||||
| CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, | CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, | ||||
| COMPLETION_GENERATE_JSON_PROMPT, | COMPLETION_GENERATE_JSON_PROMPT, | ||||
| FUNCTION_CALLING_EXTRACTOR_EXAMPLE, | FUNCTION_CALLING_EXTRACTOR_EXAMPLE, | ||||
| if model_mode == ModelMode.CHAT: | if model_mode == ModelMode.CHAT: | ||||
| system_prompt_messages = ChatModelMessage( | system_prompt_messages = ChatModelMessage( | ||||
| role=PromptMessageRole.SYSTEM, | role=PromptMessageRole.SYSTEM, | ||||
| text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), | |||||
| text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction), | |||||
| ) | ) | ||||
| user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) | user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) | ||||
| return [system_prompt_messages, user_prompt_message] | return [system_prompt_messages, user_prompt_message] |
| from collections.abc import Callable | |||||
| from functools import wraps | from functools import wraps | ||||
| from typing import Union, cast | from typing import Union, cast | ||||
| #: A proxy for the current user. If no user is logged in, this will be an | #: A proxy for the current user. If no user is logged in, this will be an | ||||
| #: anonymous user | #: anonymous user | ||||
| current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) | current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) | ||||
| from typing import ParamSpec, TypeVar | |||||
| P = ParamSpec("P") | |||||
| R = TypeVar("R") | |||||
| def login_required(func): | |||||
| def login_required(func: Callable[P, R]): | |||||
| """ | """ | ||||
| If you decorate a view with this, it will ensure that the current user is | If you decorate a view with this, it will ensure that the current user is | ||||
| logged in and authenticated before calling the actual view. (If they are | logged in and authenticated before calling the actual view. (If they are | ||||
| """ | """ | ||||
| @wraps(func) | @wraps(func) | ||||
| def decorated_view(*args, **kwargs): | |||||
| def decorated_view(*args: P.args, **kwargs: P.kwargs): | |||||
| if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: | if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: | ||||
| pass | pass | ||||
| elif current_user is not None and not current_user.is_authenticated: | elif current_user is not None and not current_user.is_authenticated: | ||||
| return current_app.login_manager.unauthorized() # type: ignore | return current_app.login_manager.unauthorized() # type: ignore | ||||
| # flask 1.x compatibility | |||||
| # current_app.ensure_sync is only available in Flask >= 2.0 | |||||
| if callable(getattr(current_app, "ensure_sync", None)): | |||||
| return current_app.ensure_sync(func)(*args, **kwargs) | |||||
| return func(*args, **kwargs) | |||||
| return current_app.ensure_sync(func)(*args, **kwargs) | |||||
| return decorated_view | return decorated_view | ||||
| import enum | import enum | ||||
| import json | import json | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from typing import Optional | |||||
| from typing import Any, Optional | |||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from flask_login import UserMixin | |||||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||||
| from sqlalchemy import DateTime, String, func, select | from sqlalchemy import DateTime, String, func, select | ||||
| from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | ||||
| ) | ) | ||||
| @property | @property | ||||
| def custom_config_dict(self): | |||||
| def custom_config_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.custom_config) if self.custom_config else {} | return json.loads(self.custom_config) if self.custom_config else {} | ||||
| @custom_config_dict.setter | @custom_config_dict.setter | ||||
| def custom_config_dict(self, value: dict): | |||||
| def custom_config_dict(self, value: dict[str, Any]) -> None: | |||||
| self.custom_config = json.dumps(value) | self.custom_config = json.dumps(value) | ||||
| "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | ||||
| } | } | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "dataset_id": self.dataset_id, | "dataset_id": self.dataset_id, | ||||
| } | } | ||||
| @property | @property | ||||
| def rules_dict(self): | |||||
| def rules_dict(self) -> dict[str, Any] | None: | |||||
| try: | try: | ||||
| return json.loads(self.rules) if self.rules else None | return json.loads(self.rules) if self.rules else None | ||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| return status | return status | ||||
| @property | @property | ||||
| def data_source_info_dict(self): | |||||
| def data_source_info_dict(self) -> dict[str, Any] | None: | |||||
| if self.data_source_info: | if self.data_source_info: | ||||
| try: | try: | ||||
| data_source_info_dict = json.loads(self.data_source_info) | |||||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| data_source_info_dict = {} | data_source_info_dict = {} | ||||
| return None | return None | ||||
| @property | @property | ||||
| def data_source_detail_dict(self): | |||||
| def data_source_detail_dict(self) -> dict[str, Any]: | |||||
| if self.data_source_info: | if self.data_source_info: | ||||
| if self.data_source_type == "upload_file": | if self.data_source_type == "upload_file": | ||||
| data_source_info_dict = json.loads(self.data_source_info) | |||||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||||
| file_detail = ( | file_detail = ( | ||||
| db.session.query(UploadFile) | db.session.query(UploadFile) | ||||
| .where(UploadFile.id == data_source_info_dict["upload_file_id"]) | .where(UploadFile.id == data_source_info_dict["upload_file_id"]) | ||||
| } | } | ||||
| } | } | ||||
| elif self.data_source_type in {"notion_import", "website_crawl"}: | elif self.data_source_type in {"notion_import", "website_crawl"}: | ||||
| return json.loads(self.data_source_info) | |||||
| result: dict[str, Any] = json.loads(self.data_source_info) | |||||
| return result | |||||
| return {} | return {} | ||||
| @property | @property | ||||
| return self.updated_at | return self.updated_at | ||||
| @property | @property | ||||
| def doc_metadata_details(self): | |||||
| def doc_metadata_details(self) -> list[dict[str, Any]] | None: | |||||
| if self.doc_metadata: | if self.doc_metadata: | ||||
| document_metadatas = ( | document_metadatas = ( | ||||
| db.session.query(DatasetMetadata) | db.session.query(DatasetMetadata) | ||||
| ) | ) | ||||
| .all() | .all() | ||||
| ) | ) | ||||
| metadata_list = [] | |||||
| metadata_list: list[dict[str, Any]] = [] | |||||
| for metadata in document_metadatas: | for metadata in document_metadatas: | ||||
| metadata_dict = { | |||||
| metadata_dict: dict[str, Any] = { | |||||
| "id": metadata.id, | "id": metadata.id, | ||||
| "name": metadata.name, | "name": metadata.name, | ||||
| "type": metadata.type, | "type": metadata.type, | ||||
| return None | return None | ||||
| @property | @property | ||||
| def process_rule_dict(self): | |||||
| if self.dataset_process_rule_id: | |||||
| def process_rule_dict(self) -> dict[str, Any] | None: | |||||
| if self.dataset_process_rule_id and self.dataset_process_rule: | |||||
| return self.dataset_process_rule.to_dict() | return self.dataset_process_rule.to_dict() | ||||
| return None | return None | ||||
| def get_built_in_fields(self): | |||||
| built_in_fields = [] | |||||
| def get_built_in_fields(self) -> list[dict[str, Any]]: | |||||
| built_in_fields: list[dict[str, Any]] = [] | |||||
| built_in_fields.append( | built_in_fields.append( | ||||
| { | { | ||||
| "id": "built-in", | "id": "built-in", | ||||
| ) | ) | ||||
| return built_in_fields | return built_in_fields | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "tenant_id": self.tenant_id, | "tenant_id": self.tenant_id, | ||||
| "data_source_info_dict": self.data_source_info_dict, | "data_source_info_dict": self.data_source_info_dict, | ||||
| "average_segment_length": self.average_segment_length, | "average_segment_length": self.average_segment_length, | ||||
| "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | ||||
| "dataset": self.dataset.to_dict() if self.dataset else None, | |||||
| "dataset": None, # Dataset class doesn't have a to_dict method | |||||
| "segment_count": self.segment_count, | "segment_count": self.segment_count, | ||||
| "hit_count": self.hit_count, | "hit_count": self.hit_count, | ||||
| } | } | ||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, data: dict): | |||||
| def from_dict(cls, data: dict[str, Any]): | |||||
| return cls( | return cls( | ||||
| id=data.get("id"), | id=data.get("id"), | ||||
| tenant_id=data.get("tenant_id"), | tenant_id=data.get("tenant_id"), | ||||
| ) | ) | ||||
| @property | @property | ||||
| def child_chunks(self): | |||||
| process_rule = self.document.dataset_process_rule | |||||
| if process_rule.mode == "hierarchical": | |||||
| rules = Rule(**process_rule.rules_dict) | |||||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| else: | |||||
| return [] | |||||
| else: | |||||
| def child_chunks(self) -> list[Any]: | |||||
| if not self.document: | |||||
| return [] | return [] | ||||
| def get_child_chunks(self): | |||||
| process_rule = self.document.dataset_process_rule | process_rule = self.document.dataset_process_rule | ||||
| if process_rule.mode == "hierarchical": | |||||
| rules = Rule(**process_rule.rules_dict) | |||||
| if rules.parent_mode: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| else: | |||||
| return [] | |||||
| else: | |||||
| if process_rule and process_rule.mode == "hierarchical": | |||||
| rules_dict = process_rule.rules_dict | |||||
| if rules_dict: | |||||
| rules = Rule(**rules_dict) | |||||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| return [] | |||||
| def get_child_chunks(self) -> list[Any]: | |||||
| if not self.document: | |||||
| return [] | return [] | ||||
| process_rule = self.document.dataset_process_rule | |||||
| if process_rule and process_rule.mode == "hierarchical": | |||||
| rules_dict = process_rule.rules_dict | |||||
| if rules_dict: | |||||
| rules = Rule(**rules_dict) | |||||
| if rules.parent_mode: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| return [] | |||||
| @property | @property | ||||
| def sign_content(self): | |||||
| def sign_content(self) -> str: | |||||
| return self.get_sign_content() | return self.get_sign_content() | ||||
| def get_sign_content(self): | |||||
| signed_urls = [] | |||||
| def get_sign_content(self) -> str: | |||||
| signed_urls: list[tuple[int, int, str]] = [] | |||||
| text = self.content | text = self.content | ||||
| # For data before v0.10.0 | # For data before v0.10.0 | ||||
| ) | ) | ||||
| @property | @property | ||||
| def keyword_table_dict(self): | |||||
| def keyword_table_dict(self) -> dict[str, set[Any]] | None: | |||||
| class SetDecoder(json.JSONDecoder): | class SetDecoder(json.JSONDecoder): | ||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(object_hook=self.object_hook, *args, **kwargs) | |||||
| def object_hook(self, dct): | |||||
| if isinstance(dct, dict): | |||||
| for keyword, node_idxs in dct.items(): | |||||
| if isinstance(node_idxs, list): | |||||
| dct[keyword] = set(node_idxs) | |||||
| return dct | |||||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||||
| def object_hook(dct: Any) -> Any: | |||||
| if isinstance(dct, dict): | |||||
| result: dict[str, Any] = {} | |||||
| items = cast(dict[str, Any], dct).items() | |||||
| for keyword, node_idxs in items: | |||||
| if isinstance(node_idxs, list): | |||||
| result[keyword] = set(cast(list[Any], node_idxs)) | |||||
| else: | |||||
| result[keyword] = node_idxs | |||||
| return result | |||||
| return dct | |||||
| super().__init__(object_hook=object_hook, *args, **kwargs) | |||||
| # get dataset | # get dataset | ||||
| dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | ||||
| updated_by = mapped_column(StringUUID, nullable=True) | updated_by = mapped_column(StringUUID, nullable=True) | ||||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "tenant_id": self.tenant_id, | "tenant_id": self.tenant_id, | ||||
| } | } | ||||
| @property | @property | ||||
| def settings_dict(self): | |||||
| def settings_dict(self) -> dict[str, Any] | None: | |||||
| try: | try: | ||||
| return json.loads(self.settings) if self.settings else None | return json.loads(self.settings) if self.settings else None | ||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| return None | return None | ||||
| @property | @property | ||||
| def dataset_bindings(self): | |||||
| def dataset_bindings(self) -> list[dict[str, Any]]: | |||||
| external_knowledge_bindings = ( | external_knowledge_bindings = ( | ||||
| db.session.query(ExternalKnowledgeBindings) | db.session.query(ExternalKnowledgeBindings) | ||||
| .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | ||||
| ) | ) | ||||
| dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | ||||
| datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | ||||
| dataset_bindings = [] | |||||
| dataset_bindings: list[dict[str, Any]] = [] | |||||
| for dataset in datasets: | for dataset in datasets: | ||||
| dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from flask import request | from flask import request | ||||
| from flask_login import UserMixin | |||||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | ||||
| from sqlalchemy.orm import Mapped, Session, mapped_column | from sqlalchemy.orm import Mapped, Session, mapped_column | ||||
| from core.file import helpers as file_helpers | from core.file import helpers as file_helpers | ||||
| from core.tools.signature import sign_tool_file | from core.tools.signature import sign_tool_file | ||||
| from core.workflow.enums import WorkflowExecutionStatus | from core.workflow.enums import WorkflowExecutionStatus | ||||
| from libs.helper import generate_string | |||||
| from libs.helper import generate_string # type: ignore[import-not-found] | |||||
| from .account import Account, Tenant | from .account import Account, Tenant | ||||
| from .base import Base | from .base import Base | ||||
| use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | ||||
| @property | @property | ||||
| def desc_or_prompt(self): | |||||
| def desc_or_prompt(self) -> str: | |||||
| if self.description: | if self.description: | ||||
| return self.description | return self.description | ||||
| else: | else: | ||||
| return "" | return "" | ||||
| @property | @property | ||||
| def site(self): | |||||
| def site(self) -> Optional["Site"]: | |||||
| site = db.session.query(Site).where(Site.app_id == self.id).first() | site = db.session.query(Site).where(Site.app_id == self.id).first() | ||||
| return site | return site | ||||
| @property | @property | ||||
| def app_model_config(self): | |||||
| def app_model_config(self) -> Optional["AppModelConfig"]: | |||||
| if self.app_model_config_id: | if self.app_model_config_id: | ||||
| return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() | return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() | ||||
| return None | return None | ||||
| @property | @property | ||||
| def api_base_url(self): | |||||
| def api_base_url(self) -> str: | |||||
| return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | ||||
| @property | @property | ||||
| def tenant(self): | |||||
| def tenant(self) -> Optional[Tenant]: | |||||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | ||||
| return tenant | return tenant | ||||
| return str(self.mode) | return str(self.mode) | ||||
| @property | @property | ||||
| def deleted_tools(self) -> list: | |||||
| from core.tools.entities.tool_entities import ToolProviderType | |||||
| from core.tools.tool_manager import ToolManager | |||||
| def deleted_tools(self) -> list[dict[str, str]]: | |||||
| from core.tools.tool_manager import ToolManager, ToolProviderType | |||||
| from services.plugin.plugin_service import PluginService | from services.plugin.plugin_service import PluginService | ||||
| # get agent mode tools | # get agent mode tools | ||||
| provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) | provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) | ||||
| } | } | ||||
| deleted_tools = [] | |||||
| deleted_tools: list[dict[str, str]] = [] | |||||
| for tool in tools: | for tool in tools: | ||||
| keys = list(tool.keys()) | keys = list(tool.keys()) | ||||
| return deleted_tools | return deleted_tools | ||||
| @property | @property | ||||
| def tags(self): | |||||
| def tags(self) -> list["Tag"]: | |||||
| tags = ( | tags = ( | ||||
| db.session.query(Tag) | db.session.query(Tag) | ||||
| .join(TagBinding, Tag.id == TagBinding.tag_id) | .join(TagBinding, Tag.id == TagBinding.tag_id) | ||||
| return tags or [] | return tags or [] | ||||
| @property | @property | ||||
| def author_name(self): | |||||
| def author_name(self) -> Optional[str]: | |||||
| if self.created_by: | if self.created_by: | ||||
| account = db.session.query(Account).where(Account.id == self.created_by).first() | account = db.session.query(Account).where(Account.id == self.created_by).first() | ||||
| if account: | if account: | ||||
| file_upload = mapped_column(sa.Text) | file_upload = mapped_column(sa.Text) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| app = db.session.query(App).where(App.id == self.app_id).first() | app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| return app | return app | ||||
| @property | @property | ||||
| def model_dict(self): | |||||
| def model_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.model) if self.model else {} | return json.loads(self.model) if self.model else {} | ||||
| @property | @property | ||||
| def suggested_questions_list(self): | |||||
| def suggested_questions_list(self) -> list[str]: | |||||
| return json.loads(self.suggested_questions) if self.suggested_questions else [] | return json.loads(self.suggested_questions) if self.suggested_questions else [] | ||||
| @property | @property | ||||
| def suggested_questions_after_answer_dict(self): | |||||
| def suggested_questions_after_answer_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.suggested_questions_after_answer) | json.loads(self.suggested_questions_after_answer) | ||||
| if self.suggested_questions_after_answer | if self.suggested_questions_after_answer | ||||
| ) | ) | ||||
| @property | @property | ||||
| def speech_to_text_dict(self): | |||||
| def speech_to_text_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} | return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} | ||||
| @property | @property | ||||
| def text_to_speech_dict(self): | |||||
| def text_to_speech_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} | return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} | ||||
| @property | @property | ||||
| def retriever_resource_dict(self): | |||||
| def retriever_resource_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} | return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} | ||||
| @property | @property | ||||
| def annotation_reply_dict(self): | |||||
| def annotation_reply_dict(self) -> dict[str, Any]: | |||||
| annotation_setting = ( | annotation_setting = ( | ||||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() | db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() | ||||
| ) | ) | ||||
| return {"enabled": False} | return {"enabled": False} | ||||
| @property | @property | ||||
| def more_like_this_dict(self): | |||||
| def more_like_this_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | ||||
| @property | @property | ||||
| def sensitive_word_avoidance_dict(self): | |||||
| def sensitive_word_avoidance_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.sensitive_word_avoidance) | json.loads(self.sensitive_word_avoidance) | ||||
| if self.sensitive_word_avoidance | if self.sensitive_word_avoidance | ||||
| ) | ) | ||||
| @property | @property | ||||
| def external_data_tools_list(self) -> list[dict]: | |||||
| def external_data_tools_list(self) -> list[dict[str, Any]]: | |||||
| return json.loads(self.external_data_tools) if self.external_data_tools else [] | return json.loads(self.external_data_tools) if self.external_data_tools else [] | ||||
| @property | @property | ||||
| def user_input_form_list(self): | |||||
| def user_input_form_list(self) -> list[dict[str, Any]]: | |||||
| return json.loads(self.user_input_form) if self.user_input_form else [] | return json.loads(self.user_input_form) if self.user_input_form else [] | ||||
| @property | @property | ||||
| def agent_mode_dict(self): | |||||
| def agent_mode_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.agent_mode) | json.loads(self.agent_mode) | ||||
| if self.agent_mode | if self.agent_mode | ||||
| ) | ) | ||||
| @property | @property | ||||
| def chat_prompt_config_dict(self): | |||||
| def chat_prompt_config_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | ||||
| @property | @property | ||||
| def completion_prompt_config_dict(self): | |||||
| def completion_prompt_config_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | ||||
| @property | @property | ||||
| def dataset_configs_dict(self): | |||||
| def dataset_configs_dict(self) -> dict[str, Any]: | |||||
| if self.dataset_configs: | if self.dataset_configs: | ||||
| dataset_configs: dict = json.loads(self.dataset_configs) | |||||
| dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) | |||||
| if "retrieval_model" not in dataset_configs: | if "retrieval_model" not in dataset_configs: | ||||
| return {"retrieval_model": "single"} | return {"retrieval_model": "single"} | ||||
| else: | else: | ||||
| } | } | ||||
| @property | @property | ||||
| def file_upload_dict(self): | |||||
| def file_upload_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.file_upload) | json.loads(self.file_upload) | ||||
| if self.file_upload | if self.file_upload | ||||
| } | } | ||||
| ) | ) | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "opening_statement": self.opening_statement, | "opening_statement": self.opening_statement, | ||||
| "suggested_questions": self.suggested_questions_list, | "suggested_questions": self.suggested_questions_list, | ||||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| app = db.session.query(App).where(App.id == self.app_id).first() | app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| return app | return app | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| app = db.session.query(App).where(App.id == self.app_id).first() | app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| return app | return app | ||||
| @property | @property | ||||
| def tenant(self): | |||||
| def tenant(self) -> Optional[Tenant]: | |||||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | ||||
| return tenant | return tenant | ||||
| mode: Mapped[str] = mapped_column(String(255)) | mode: Mapped[str] = mapped_column(String(255)) | ||||
| name: Mapped[str] = mapped_column(String(255), nullable=False) | name: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| summary = mapped_column(sa.Text) | summary = mapped_column(sa.Text) | ||||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||||
| introduction = mapped_column(sa.Text) | introduction = mapped_column(sa.Text) | ||||
| system_instruction = mapped_column(sa.Text) | system_instruction = mapped_column(sa.Text) | ||||
| system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | ||||
| is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | ||||
| @property | @property | ||||
| def inputs(self): | |||||
| def inputs(self) -> dict[str, Any]: | |||||
| inputs = self._inputs.copy() | inputs = self._inputs.copy() | ||||
| # Convert file mapping to File object | # Convert file mapping to File object | ||||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | ||||
| from factories import file_factory | from factories import file_factory | ||||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value["tool_file_id"] = value["related_id"] | |||||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value["upload_file_id"] = value["related_id"] | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||||
| elif isinstance(value, list) and all( | |||||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||||
| if ( | |||||
| isinstance(value, dict) | |||||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| ): | ): | ||||
| inputs[key] = [] | |||||
| for item in value: | |||||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item["tool_file_id"] = item["related_id"] | |||||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| item["upload_file_id"] = item["related_id"] | |||||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||||
| value_dict = cast(dict[str, Any], value) | |||||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||||
| elif isinstance(value, list): | |||||
| value_list = cast(list[Any], value) | |||||
| if all( | |||||
| isinstance(item, dict) | |||||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| for item in value_list | |||||
| ): | |||||
| file_list: list[File] = [] | |||||
| for item in value_list: | |||||
| if not isinstance(item, dict): | |||||
| continue | |||||
| item_dict = cast(dict[str, Any], item) | |||||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||||
| elif item_dict["transfer_method"] in [ | |||||
| FileTransferMethod.LOCAL_FILE, | |||||
| FileTransferMethod.REMOTE_URL, | |||||
| ]: | |||||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||||
| inputs[key] = file_list | |||||
| return inputs | return inputs | ||||
| for k, v in inputs.items(): | for k, v in inputs.items(): | ||||
| if isinstance(v, File): | if isinstance(v, File): | ||||
| inputs[k] = v.model_dump() | inputs[k] = v.model_dump() | ||||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||||
| inputs[k] = [item.model_dump() for item in v] | |||||
| elif isinstance(v, list): | |||||
| v_list = cast(list[Any], v) | |||||
| if all(isinstance(item, File) for item in v_list): | |||||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||||
| self._inputs = inputs | self._inputs = inputs | ||||
| @property | @property | ||||
| ) | ) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| with Session(db.engine, expire_on_commit=False) as session: | with Session(db.engine, expire_on_commit=False) as session: | ||||
| return session.query(App).where(App.id == self.app_id).first() | return session.query(App).where(App.id == self.app_id).first() | ||||
| return None | return None | ||||
| @property | @property | ||||
| def from_account_name(self): | |||||
| def from_account_name(self) -> Optional[str]: | |||||
| if self.from_account_id: | if self.from_account_id: | ||||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | account = db.session.query(Account).where(Account.id == self.from_account_id).first() | ||||
| if account: | if account: | ||||
| return None | return None | ||||
| @property | @property | ||||
| def in_debug_mode(self): | |||||
| def in_debug_mode(self) -> bool: | |||||
| return self.override_model_configs is not None | return self.override_model_configs is not None | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "app_id": self.app_id, | "app_id": self.app_id, | ||||
| model_id = mapped_column(String(255), nullable=True) | model_id = mapped_column(String(255), nullable=True) | ||||
| override_model_configs = mapped_column(sa.Text) | override_model_configs = mapped_column(sa.Text) | ||||
| conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) | conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) | ||||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||||
| query: Mapped[str] = mapped_column(sa.Text, nullable=False) | query: Mapped[str] = mapped_column(sa.Text, nullable=False) | ||||
| message = mapped_column(sa.JSON, nullable=False) | message = mapped_column(sa.JSON, nullable=False) | ||||
| message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | ||||
| workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) | workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) | ||||
| @property | @property | ||||
| def inputs(self): | |||||
| def inputs(self) -> dict[str, Any]: | |||||
| inputs = self._inputs.copy() | inputs = self._inputs.copy() | ||||
| for key, value in inputs.items(): | for key, value in inputs.items(): | ||||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | ||||
| from factories import file_factory | from factories import file_factory | ||||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value["tool_file_id"] = value["related_id"] | |||||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value["upload_file_id"] = value["related_id"] | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||||
| elif isinstance(value, list) and all( | |||||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||||
| if ( | |||||
| isinstance(value, dict) | |||||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| ): | ): | ||||
| inputs[key] = [] | |||||
| for item in value: | |||||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item["tool_file_id"] = item["related_id"] | |||||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| item["upload_file_id"] = item["related_id"] | |||||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||||
| value_dict = cast(dict[str, Any], value) | |||||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||||
| elif isinstance(value, list): | |||||
| value_list = cast(list[Any], value) | |||||
| if all( | |||||
| isinstance(item, dict) | |||||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| for item in value_list | |||||
| ): | |||||
| file_list: list[File] = [] | |||||
| for item in value_list: | |||||
| if not isinstance(item, dict): | |||||
| continue | |||||
| item_dict = cast(dict[str, Any], item) | |||||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||||
| elif item_dict["transfer_method"] in [ | |||||
| FileTransferMethod.LOCAL_FILE, | |||||
| FileTransferMethod.REMOTE_URL, | |||||
| ]: | |||||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||||
| inputs[key] = file_list | |||||
| return inputs | return inputs | ||||
| @inputs.setter | @inputs.setter | ||||
| for k, v in inputs.items(): | for k, v in inputs.items(): | ||||
| if isinstance(v, File): | if isinstance(v, File): | ||||
| inputs[k] = v.model_dump() | inputs[k] = v.model_dump() | ||||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||||
| inputs[k] = [item.model_dump() for item in v] | |||||
| elif isinstance(v, list): | |||||
| v_list = cast(list[Any], v) | |||||
| if all(isinstance(item, File) for item in v_list): | |||||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||||
| self._inputs = inputs | self._inputs = inputs | ||||
| @property | @property | ||||
| return None | return None | ||||
| @property | @property | ||||
| def in_debug_mode(self): | |||||
| def in_debug_mode(self) -> bool: | |||||
| return self.override_model_configs is not None | return self.override_model_configs is not None | ||||
| @property | @property | ||||
| def message_metadata_dict(self): | |||||
| def message_metadata_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.message_metadata) if self.message_metadata else {} | return json.loads(self.message_metadata) if self.message_metadata else {} | ||||
| @property | @property | ||||
| def agent_thoughts(self): | |||||
| def agent_thoughts(self) -> list["MessageAgentThought"]: | |||||
| return ( | return ( | ||||
| db.session.query(MessageAgentThought) | db.session.query(MessageAgentThought) | ||||
| .where(MessageAgentThought.message_id == self.id) | .where(MessageAgentThought.message_id == self.id) | ||||
| ) | ) | ||||
| @property | @property | ||||
| def retriever_resources(self): | |||||
| def retriever_resources(self) -> Any | list[Any]: | |||||
| return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] | return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] | ||||
| @property | @property | ||||
| def message_files(self): | |||||
| def message_files(self) -> list[dict[str, Any]]: | |||||
| from factories import file_factory | from factories import file_factory | ||||
| message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | ||||
| if not current_app: | if not current_app: | ||||
| raise ValueError(f"App {self.app_id} not found") | raise ValueError(f"App {self.app_id} not found") | ||||
| files = [] | |||||
| files: list[File] = [] | |||||
| for message_file in message_files: | for message_file in message_files: | ||||
| if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: | if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: | ||||
| if message_file.upload_file_id is None: | if message_file.upload_file_id is None: | ||||
| ) | ) | ||||
| files.append(file) | files.append(file) | ||||
| result = [ | |||||
| result: list[dict[str, Any]] = [ | |||||
| {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} | {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} | ||||
| for (file, message_file) in zip(files, message_files) | for (file, message_file) in zip(files, message_files) | ||||
| ] | ] | ||||
| return None | return None | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "app_id": self.app_id, | "app_id": self.app_id, | ||||
| } | } | ||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, data: dict): | |||||
| def from_dict(cls, data: dict[str, Any]) -> "Message": | |||||
| return cls( | return cls( | ||||
| id=data["id"], | id=data["id"], | ||||
| app_id=data["app_id"], | app_id=data["app_id"], | ||||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | account = db.session.query(Account).where(Account.id == self.from_account_id).first() | ||||
| return account | return account | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": str(self.id), | "id": str(self.id), | ||||
| "app_id": str(self.app_id), | "app_id": str(self.app_id), | ||||
| type: Mapped[str] = mapped_column(String(255), nullable=False) | type: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| external_user_id = mapped_column(String(255), nullable=True) | external_user_id = mapped_column(String(255), nullable=True) | ||||
| name = mapped_column(String(255)) | name = mapped_column(String(255)) | ||||
| is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | |||||
| _is_anonymous: Mapped[bool] = mapped_column( | |||||
| "is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true") | |||||
| ) | |||||
| @property | |||||
| def is_anonymous(self) -> Literal[False]: | |||||
| return False | |||||
| @is_anonymous.setter | |||||
| def is_anonymous(self, value: bool) -> None: | |||||
| self._is_anonymous = value | |||||
| session_id: Mapped[str] = mapped_column() | session_id: Mapped[str] = mapped_column() | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @staticmethod | @staticmethod | ||||
| def generate_server_code(n): | |||||
| def generate_server_code(n: int) -> str: | |||||
| while True: | while True: | ||||
| result = generate_string(n) | result = generate_string(n) | ||||
| while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: | while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: | ||||
| self._custom_disclaimer = value | self._custom_disclaimer = value | ||||
| @staticmethod | @staticmethod | ||||
| def generate_code(n): | |||||
| def generate_code(n: int) -> str: | |||||
| while True: | while True: | ||||
| result = generate_string(n) | result = generate_string(n) | ||||
| while db.session.query(Site).where(Site.code == result).count() > 0: | while db.session.query(Site).where(Site.code == result).count() > 0: | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @staticmethod | @staticmethod | ||||
| def generate_api_key(prefix, n): | |||||
| def generate_api_key(prefix: str, n: int) -> str: | |||||
| while True: | while True: | ||||
| result = prefix + generate_string(n) | result = prefix + generate_string(n) | ||||
| if db.session.scalar(select(exists().where(ApiToken.token == result))): | if db.session.scalar(select(exists().where(ApiToken.token == result))): | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | ||||
| @property | @property | ||||
| def files(self): | |||||
| def files(self) -> list[Any]: | |||||
| if self.message_files: | if self.message_files: | ||||
| return cast(list[Any], json.loads(self.message_files)) | return cast(list[Any], json.loads(self.message_files)) | ||||
| else: | else: | ||||
| return self.tool.split(";") if self.tool else [] | return self.tool.split(";") if self.tool else [] | ||||
| @property | @property | ||||
| def tool_labels(self): | |||||
| def tool_labels(self) -> dict[str, Any]: | |||||
| try: | try: | ||||
| if self.tool_labels_str: | if self.tool_labels_str: | ||||
| return cast(dict, json.loads(self.tool_labels_str)) | |||||
| return cast(dict[str, Any], json.loads(self.tool_labels_str)) | |||||
| else: | else: | ||||
| return {} | return {} | ||||
| except Exception: | except Exception: | ||||
| return {} | return {} | ||||
| @property | @property | ||||
| def tool_meta(self): | |||||
| def tool_meta(self) -> dict[str, Any]: | |||||
| try: | try: | ||||
| if self.tool_meta_str: | if self.tool_meta_str: | ||||
| return cast(dict, json.loads(self.tool_meta_str)) | |||||
| return cast(dict[str, Any], json.loads(self.tool_meta_str)) | |||||
| else: | else: | ||||
| return {} | return {} | ||||
| except Exception: | except Exception: | ||||
| return {} | return {} | ||||
| @property | @property | ||||
| def tool_inputs_dict(self): | |||||
| def tool_inputs_dict(self) -> dict[str, Any]: | |||||
| tools = self.tools | tools = self.tools | ||||
| try: | try: | ||||
| if self.tool_input: | if self.tool_input: | ||||
| data = json.loads(self.tool_input) | data = json.loads(self.tool_input) | ||||
| result = {} | |||||
| result: dict[str, Any] = {} | |||||
| for tool in tools: | for tool in tools: | ||||
| if tool in data: | if tool in data: | ||||
| result[tool] = data[tool] | result[tool] = data[tool] | ||||
| return {} | return {} | ||||
| @property | @property | ||||
| def tool_outputs_dict(self): | |||||
| def tool_outputs_dict(self) -> dict[str, Any]: | |||||
| tools = self.tools | tools = self.tools | ||||
| try: | try: | ||||
| if self.observation: | if self.observation: | ||||
| data = json.loads(self.observation) | data = json.loads(self.observation) | ||||
| result = {} | |||||
| result: dict[str, Any] = {} | |||||
| for tool in tools: | for tool in tools: | ||||
| if tool in data: | if tool in data: | ||||
| result[tool] = data[tool] | result[tool] = data[tool] | ||||
| is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | ||||
| @property | @property | ||||
| def tracing_config_dict(self): | |||||
| def tracing_config_dict(self) -> dict[str, Any]: | |||||
| return self.tracing_config or {} | return self.tracing_config or {} | ||||
| @property | @property | ||||
| def tracing_config_str(self): | |||||
| def tracing_config_str(self) -> str: | |||||
| return json.dumps(self.tracing_config_dict) | return json.dumps(self.tracing_config_dict) | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "app_id": self.app_id, | "app_id": self.app_id, |
| SYSTEM = "system" | SYSTEM = "system" | ||||
| @staticmethod | @staticmethod | ||||
| def value_of(value): | |||||
| def value_of(value: str) -> "ProviderType": | |||||
| for member in ProviderType: | for member in ProviderType: | ||||
| if member.value == value: | if member.value == value: | ||||
| return member | return member | ||||
| """hosted trial quota""" | """hosted trial quota""" | ||||
| @staticmethod | @staticmethod | ||||
| def value_of(value): | |||||
| def value_of(value: str) -> "ProviderQuotaType": | |||||
| for member in ProviderQuotaType: | for member in ProviderQuotaType: | ||||
| if member.value == value: | if member.value == value: | ||||
| return member | return member |
| import json | import json | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from typing import TYPE_CHECKING, Optional, cast | |||||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) | encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) | ||||
| @property | @property | ||||
| def oauth_params(self): | |||||
| return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) | |||||
| def oauth_params(self) -> dict[str, Any]: | |||||
| return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}")) | |||||
| class BuiltinToolProvider(Base): | class BuiltinToolProvider(Base): | ||||
| expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) | expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) | ||||
| @property | @property | ||||
| def credentials(self): | |||||
| return cast(dict, json.loads(self.encrypted_credentials)) | |||||
| def credentials(self) -> dict[str, Any]: | |||||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) | |||||
| class ApiToolProvider(Base): | class ApiToolProvider(Base): | ||||
| return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | ||||
| @property | @property | ||||
| def credentials(self): | |||||
| return dict(json.loads(self.credentials_str)) | |||||
| def credentials(self) -> dict[str, Any]: | |||||
| return dict[str, Any](json.loads(self.credentials_str)) | |||||
| @property | @property | ||||
| def user(self) -> Account | None: | def user(self) -> Account | None: | ||||
| return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | ||||
| @property | @property | ||||
| def credentials(self): | |||||
| def credentials(self) -> dict[str, Any]: | |||||
| try: | try: | ||||
| return cast(dict, json.loads(self.encrypted_credentials)) or {} | |||||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} | |||||
| except Exception: | except Exception: | ||||
| return {} | return {} | ||||
| return mask_url(self.decrypted_server_url) | return mask_url(self.decrypted_server_url) | ||||
| @property | @property | ||||
| def decrypted_credentials(self): | |||||
| def decrypted_credentials(self) -> dict[str, Any]: | |||||
| from core.helper.provider_cache import NoOpProviderCredentialCache | from core.helper.provider_cache import NoOpProviderCredentialCache | ||||
| from core.tools.mcp_tool.provider import MCPToolProviderController | from core.tools.mcp_tool.provider import MCPToolProviderController | ||||
| from core.tools.utils.encryption import create_provider_encrypter | from core.tools.utils.encryption import create_provider_encrypter | ||||
| provider_controller = MCPToolProviderController._from_db(self) | |||||
| provider_controller = MCPToolProviderController.from_db(self) | |||||
| encrypter, _ = create_provider_encrypter( | encrypter, _ = create_provider_encrypter( | ||||
| tenant_id=self.tenant_id, | tenant_id=self.tenant_id, | ||||
| cache=NoOpProviderCredentialCache(), | cache=NoOpProviderCredentialCache(), | ||||
| ) | ) | ||||
| return encrypter.decrypt(self.credentials) # type: ignore | |||||
| return encrypter.decrypt(self.credentials) | |||||
| class ToolModelInvoke(Base): | class ToolModelInvoke(Base): |
| import enum | import enum | ||||
| from typing import Generic, TypeVar | |||||
| import uuid | |||||
| from typing import Any, Generic, TypeVar | |||||
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator | from sqlalchemy import CHAR, VARCHAR, TypeDecorator | ||||
| from sqlalchemy.dialects.postgresql import UUID | from sqlalchemy.dialects.postgresql import UUID | ||||
| from sqlalchemy.engine.interfaces import Dialect | |||||
| from sqlalchemy.sql.type_api import TypeEngine | |||||
| class StringUUID(TypeDecorator): | |||||
| class StringUUID(TypeDecorator[uuid.UUID | str | None]): | |||||
| impl = CHAR | impl = CHAR | ||||
| cache_ok = True | cache_ok = True | ||||
| def process_bind_param(self, value, dialect): | |||||
| def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| elif dialect.name == "postgresql": | elif dialect.name == "postgresql": | ||||
| return str(value) | return str(value) | ||||
| else: | else: | ||||
| return value.hex | |||||
| if isinstance(value, uuid.UUID): | |||||
| return value.hex | |||||
| return value | |||||
| def load_dialect_impl(self, dialect): | |||||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||||
| if dialect.name == "postgresql": | if dialect.name == "postgresql": | ||||
| return dialect.type_descriptor(UUID()) | return dialect.type_descriptor(UUID()) | ||||
| else: | else: | ||||
| return dialect.type_descriptor(CHAR(36)) | return dialect.type_descriptor(CHAR(36)) | ||||
| def process_result_value(self, value, dialect): | |||||
| def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| return str(value) | return str(value) | ||||
| _E = TypeVar("_E", bound=enum.StrEnum) | _E = TypeVar("_E", bound=enum.StrEnum) | ||||
| class EnumText(TypeDecorator, Generic[_E]): | |||||
| class EnumText(TypeDecorator[_E | None], Generic[_E]): | |||||
| impl = VARCHAR | impl = VARCHAR | ||||
| cache_ok = True | cache_ok = True | ||||
| # leave some rooms for future longer enum values. | # leave some rooms for future longer enum values. | ||||
| self._length = max(max_enum_value_len, 20) | self._length = max(max_enum_value_len, 20) | ||||
| def process_bind_param(self, value: _E | str | None, dialect): | |||||
| def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| if isinstance(value, self._enum_class): | if isinstance(value, self._enum_class): | ||||
| return value.value | return value.value | ||||
| elif isinstance(value, str): | |||||
| self._enum_class(value) | |||||
| return value | |||||
| else: | |||||
| raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") | |||||
| # Since _E is bound to StrEnum which inherits from str, at this point value must be str | |||||
| self._enum_class(value) | |||||
| return value | |||||
| def load_dialect_impl(self, dialect): | |||||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||||
| return dialect.type_descriptor(VARCHAR(self._length)) | return dialect.type_descriptor(VARCHAR(self._length)) | ||||
| def process_result_value(self, value, dialect) -> _E | None: | |||||
| def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| if not isinstance(value, str): | |||||
| raise TypeError(f"expected str, got {type(value)}") | |||||
| # Type annotation guarantees value is str at this point | |||||
| return self._enum_class(value) | return self._enum_class(value) | ||||
| def compare_values(self, x, y): | |||||
| def compare_values(self, x: _E | None, y: _E | None) -> bool: | |||||
| if x is None or y is None: | if x is None or y is None: | ||||
| return x is y | return x is y | ||||
| return x == y | return x == y |
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| raise WorkflowDataError("nodes not found in workflow graph") | raise WorkflowDataError("nodes not found in workflow graph") | ||||
| try: | try: | ||||
| node_config = next(filter(lambda node: node["id"] == node_id, nodes)) | |||||
| node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) | |||||
| except StopIteration: | except StopIteration: | ||||
| raise NodeNotFoundError(node_id) | raise NodeNotFoundError(node_id) | ||||
| assert isinstance(node_config, dict) | assert isinstance(node_config, dict) | ||||
| def features_dict(self) -> dict[str, Any]: | def features_dict(self) -> dict[str, Any]: | ||||
| return json.loads(self.features) if self.features else {} | return json.loads(self.features) if self.features else {} | ||||
| def user_input_form(self, to_old_structure: bool = False): | |||||
| def user_input_form(self, to_old_structure: bool = False) -> list[Any]: | |||||
| # get start node from graph | # get start node from graph | ||||
| if not self.graph: | if not self.graph: | ||||
| return [] | return [] | ||||
| variables: list[Any] = start_node.get("data", {}).get("variables", []) | variables: list[Any] = start_node.get("data", {}).get("variables", []) | ||||
| if to_old_structure: | if to_old_structure: | ||||
| old_structure_variables = [] | |||||
| old_structure_variables: list[dict[str, Any]] = [] | |||||
| for variable in variables: | for variable in variables: | ||||
| old_structure_variables.append({variable["type"]: variable}) | old_structure_variables.append({variable["type"]: variable}) | ||||
| @property | @property | ||||
| def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | ||||
| # TODO: find some way to init `self._environment_variables` when instance created. | |||||
| if self._environment_variables is None: | |||||
| self._environment_variables = "{}" | |||||
| # _environment_variables is guaranteed to be non-None due to server_default="{}" | |||||
| # Use workflow.tenant_id to avoid relying on request user in background threads | # Use workflow.tenant_id to avoid relying on request user in background threads | ||||
| tenant_id = self.tenant_id | tenant_id = self.tenant_id | ||||
| ] | ] | ||||
| # decrypt secret variables value | # decrypt secret variables value | ||||
| def decrypt_func(var): | |||||
| def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: | |||||
| if isinstance(var, SecretVariable): | if isinstance(var, SecretVariable): | ||||
| return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | ||||
| elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): | elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): | ||||
| return var | return var | ||||
| else: | else: | ||||
| raise AssertionError("this statement should be unreachable.") | |||||
| # Other variable types are not supported for environment variables | |||||
| raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") | |||||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( | |||||
| map(decrypt_func, results) | |||||
| ) | |||||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [ | |||||
| decrypt_func(var) for var in results | |||||
| ] | |||||
| return decrypted_results | return decrypted_results | ||||
| @environment_variables.setter | @environment_variables.setter | ||||
| value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | ||||
| # encrypt secret variables value | # encrypt secret variables value | ||||
| def encrypt_func(var): | |||||
| def encrypt_func(var: Variable) -> Variable: | |||||
| if isinstance(var, SecretVariable): | if isinstance(var, SecretVariable): | ||||
| return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | ||||
| else: | else: | ||||
| @property | @property | ||||
| def conversation_variables(self) -> Sequence[Variable]: | def conversation_variables(self) -> Sequence[Variable]: | ||||
| # TODO: find some way to init `self._conversation_variables` when instance created. | |||||
| if self._conversation_variables is None: | |||||
| self._conversation_variables = "{}" | |||||
| # _conversation_variables is guaranteed to be non-None due to server_default="{}" | |||||
| variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | ||||
| results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] | results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] | ||||
| } | } | ||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, data: dict) -> "WorkflowRun": | |||||
| def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": | |||||
| return cls( | return cls( | ||||
| id=data.get("id"), | id=data.get("id"), | ||||
| tenant_id=data.get("tenant_id"), | tenant_id=data.get("tenant_id"), | ||||
| __tablename__ = "workflow_node_executions" | __tablename__ = "workflow_node_executions" | ||||
| @declared_attr | @declared_attr | ||||
| def __table_args__(cls): # noqa | |||||
| @classmethod | |||||
| def __table_args__(cls) -> Any: | |||||
| return ( | return ( | ||||
| PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | ||||
| Index( | Index( | ||||
| # MyPy may flag the following line because it doesn't recognize that | # MyPy may flag the following line because it doesn't recognize that | ||||
| # the `declared_attr` decorator passes the receiving class as the first | # the `declared_attr` decorator passes the receiving class as the first | ||||
| # argument to this method, allowing us to reference class attributes. | # argument to this method, allowing us to reference class attributes. | ||||
| cls.created_at.desc(), # type: ignore | |||||
| cls.created_at.desc(), | |||||
| ), | ), | ||||
| ) | ) | ||||
| return json.loads(self.execution_metadata) if self.execution_metadata else {} | return json.loads(self.execution_metadata) if self.execution_metadata else {} | ||||
| @property | @property | ||||
| def extras(self): | |||||
| def extras(self) -> dict[str, Any]: | |||||
| from core.tools.tool_manager import ToolManager | from core.tools.tool_manager import ToolManager | ||||
| extras = {} | |||||
| extras: dict[str, Any] = {} | |||||
| if self.execution_metadata_dict: | if self.execution_metadata_dict: | ||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | ||||
| tool_info = self.execution_metadata_dict["tool_info"] | |||||
| tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] | |||||
| extras["icon"] = ToolManager.get_tool_icon( | extras["icon"] = ToolManager.get_tool_icon( | ||||
| tenant_id=self.tenant_id, | tenant_id=self.tenant_id, | ||||
| provider_type=tool_info["provider_type"], | provider_type=tool_info["provider_type"], | ||||
| # making this attribute harder to access from outside the class. | # making this attribute harder to access from outside the class. | ||||
| __value: Segment | None | __value: Segment | None | ||||
| def __init__(self, *args, **kwargs): | |||||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||||
| """ | """ | ||||
| The constructor of `WorkflowDraftVariable` is not intended for | The constructor of `WorkflowDraftVariable` is not intended for | ||||
| direct use outside this file. Its solo purpose is setup private state | direct use outside this file. Its solo purpose is setup private state | ||||
| self.__value = None | self.__value = None | ||||
| def get_selector(self) -> list[str]: | def get_selector(self) -> list[str]: | ||||
| selector = json.loads(self.selector) | |||||
| selector: Any = json.loads(self.selector) | |||||
| if not isinstance(selector, list): | if not isinstance(selector, list): | ||||
| logger.error( | logger.error( | ||||
| "invalid selector loaded from database, type=%s, value=%s", | "invalid selector loaded from database, type=%s, value=%s", | ||||
| type(selector), | |||||
| type(selector).__name__, | |||||
| self.selector, | self.selector, | ||||
| ) | ) | ||||
| raise ValueError("invalid selector.") | raise ValueError("invalid selector.") | ||||
| return selector | |||||
| return cast(list[str], selector) | |||||
| def _set_selector(self, value: list[str]): | def _set_selector(self, value: list[str]): | ||||
| self.selector = json.dumps(value) | self.selector = json.dumps(value) | ||||
| # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. | # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. | ||||
| if isinstance(value, dict): | if isinstance(value, dict): | ||||
| if not maybe_file_object(value): | if not maybe_file_object(value): | ||||
| return value | |||||
| return cast(Any, value) | |||||
| return File.model_validate(value) | return File.model_validate(value) | ||||
| elif isinstance(value, list) and value: | elif isinstance(value, list) and value: | ||||
| first = value[0] | |||||
| value_list = cast(list[Any], value) | |||||
| first: Any = value_list[0] | |||||
| if not maybe_file_object(first): | if not maybe_file_object(first): | ||||
| return value | |||||
| return [File.model_validate(i) for i in value] | |||||
| return cast(Any, value) | |||||
| file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] | |||||
| return cast(Any, file_list) | |||||
| else: | else: | ||||
| return value | |||||
| return cast(Any, value) | |||||
| @classmethod | @classmethod | ||||
| def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: | def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: |
| "tests/", | "tests/", | ||||
| "migrations/", | "migrations/", | ||||
| ".venv/", | ".venv/", | ||||
| "models/", | |||||
| "core/", | "core/", | ||||
| "controllers/", | "controllers/", | ||||
| "tasks/", | "tasks/", |
| import threading | import threading | ||||
| from typing import Optional | |||||
| from typing import Any, Optional | |||||
| import pytz | import pytz | ||||
| from flask_login import current_user | |||||
| import contexts | import contexts | ||||
| from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager | from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager | ||||
| from core.plugin.impl.exc import PluginDaemonClientSideError | from core.plugin.impl.exc import PluginDaemonClientSideError | ||||
| from core.tools.tool_manager import ToolManager | from core.tools.tool_manager import ToolManager | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.login import current_user | |||||
| from models.account import Account | from models.account import Account | ||||
| from models.model import App, Conversation, EndUser, Message, MessageAgentThought | from models.model import App, Conversation, EndUser, Message, MessageAgentThought | ||||
| executor = executor.name | executor = executor.name | ||||
| else: | else: | ||||
| executor = "Unknown" | executor = "Unknown" | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.timezone is not None | |||||
| timezone = pytz.timezone(current_user.timezone) | timezone = pytz.timezone(current_user.timezone) | ||||
| app_model_config = app_model.app_model_config | app_model_config = app_model.app_model_config | ||||
| if not app_model_config: | if not app_model_config: | ||||
| raise ValueError("App model config not found") | raise ValueError("App model config not found") | ||||
| result = { | |||||
| result: dict[str, Any] = { | |||||
| "meta": { | "meta": { | ||||
| "status": "success", | "status": "success", | ||||
| "executor": executor, | "executor": executor, |
| from typing import Optional | from typing import Optional | ||||
| import pandas as pd | import pandas as pd | ||||
| from flask_login import current_user | |||||
| from sqlalchemy import or_, select | from sqlalchemy import or_, select | ||||
| from werkzeug.datastructures import FileStorage | from werkzeug.datastructures import FileStorage | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from libs.datetime_utils import naive_utc_now | from libs.datetime_utils import naive_utc_now | ||||
| from libs.login import current_user | |||||
| from models.account import Account | |||||
| from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation | from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task | from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task | ||||
| @classmethod | @classmethod | ||||
| def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: | def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| db.session.commit() | db.session.commit() | ||||
| # if annotation reply is enabled , add annotation to index | # if annotation reply is enabled , add annotation to index | ||||
| annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() | annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() | ||||
| assert current_user.current_tenant_id is not None | |||||
| if annotation_setting: | if annotation_setting: | ||||
| add_annotation_to_index_task.delay( | add_annotation_to_index_task.delay( | ||||
| annotation.id, | annotation.id, | ||||
| enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | ||||
| # send batch add segments task | # send batch add segments task | ||||
| redis_client.setnx(enable_app_annotation_job_key, "waiting") | redis_client.setnx(enable_app_annotation_job_key, "waiting") | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| enable_annotation_reply_task.delay( | enable_annotation_reply_task.delay( | ||||
| str(job_id), | str(job_id), | ||||
| app_id, | app_id, | ||||
| @classmethod | @classmethod | ||||
| def disable_app_annotation(cls, app_id: str): | def disable_app_annotation(cls, app_id: str): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" | disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" | ||||
| cache_result = redis_client.get(disable_app_annotation_key) | cache_result = redis_client.get(disable_app_annotation_key) | ||||
| if cache_result is not None: | if cache_result is not None: | ||||
| @classmethod | @classmethod | ||||
| def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): | def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def export_annotation_list_by_app_id(cls, app_id: str): | def export_annotation_list_by_app_id(cls, app_id: str): | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: | def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): | def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def delete_app_annotation(cls, app_id: str, annotation_id: str): | def delete_app_annotation(cls, app_id: str, annotation_id: str): | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): | def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def batch_import_app_annotations(cls, app_id, file: FileStorage): | def batch_import_app_annotations(cls, app_id, file: FileStorage): | ||||
| # get app info | # get app info | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | ||||
| @classmethod | @classmethod | ||||
| def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): | def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| # get app info | # get app info | ||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| @classmethod | @classmethod | ||||
| def get_app_annotation_setting_by_app_id(cls, app_id: str): | def get_app_annotation_setting_by_app_id(cls, app_id: str): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| # get app info | # get app info | ||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| @classmethod | @classmethod | ||||
| def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): | def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| # get app info | # get app info | ||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| @classmethod | @classmethod | ||||
| def clear_all_annotations(cls, app_id: str): | def clear_all_annotations(cls, app_id: str): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| app = ( | app = ( | ||||
| db.session.query(App) | db.session.query(App) | ||||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") |
| import logging | import logging | ||||
| from typing import Optional, TypedDict, cast | from typing import Optional, TypedDict, cast | ||||
| from flask_login import current_user | |||||
| from flask_sqlalchemy.pagination import Pagination | from flask_sqlalchemy.pagination import Pagination | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from events.app_event import app_was_created | from events.app_event import app_was_created | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.datetime_utils import naive_utc_now | from libs.datetime_utils import naive_utc_now | ||||
| from libs.login import current_user | |||||
| from models.account import Account | from models.account import Account | ||||
| from models.model import App, AppMode, AppModelConfig, Site | from models.model import App, AppMode, AppModelConfig, Site | ||||
| from models.tools import ApiToolProvider | from models.tools import ApiToolProvider | ||||
| """ | """ | ||||
| Get App | Get App | ||||
| """ | """ | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| # get original app model config | # get original app model config | ||||
| if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | ||||
| model_config = app.app_model_config | model_config = app.app_model_config | ||||
| if not model_config: | |||||
| return app | |||||
| agent_mode = model_config.agent_mode_dict | agent_mode = model_config.agent_mode_dict | ||||
| # decrypt agent tool parameters if it's secret-input | # decrypt agent tool parameters if it's secret-input | ||||
| for tool in agent_mode.get("tools") or []: | for tool in agent_mode.get("tools") or []: | ||||
| pass | pass | ||||
| # override agent mode | # override agent mode | ||||
| model_config.agent_mode = json.dumps(agent_mode) | |||||
| if model_config: | |||||
| model_config.agent_mode = json.dumps(agent_mode) | |||||
| class ModifiedApp(App): | class ModifiedApp(App): | ||||
| """ | """ | ||||
| :param args: request args | :param args: request args | ||||
| :return: App instance | :return: App instance | ||||
| """ | """ | ||||
| assert current_user is not None | |||||
| app.name = args["name"] | app.name = args["name"] | ||||
| app.description = args["description"] | app.description = args["description"] | ||||
| app.icon_type = args["icon_type"] | app.icon_type = args["icon_type"] | ||||
| :param name: new name | :param name: new name | ||||
| :return: App instance | :return: App instance | ||||
| """ | """ | ||||
| assert current_user is not None | |||||
| app.name = name | app.name = name | ||||
| app.updated_by = current_user.id | app.updated_by = current_user.id | ||||
| app.updated_at = naive_utc_now() | app.updated_at = naive_utc_now() | ||||
| :param icon_background: new icon_background | :param icon_background: new icon_background | ||||
| :return: App instance | :return: App instance | ||||
| """ | """ | ||||
| assert current_user is not None | |||||
| app.icon = icon | app.icon = icon | ||||
| app.icon_background = icon_background | app.icon_background = icon_background | ||||
| app.updated_by = current_user.id | app.updated_by = current_user.id | ||||
| """ | """ | ||||
| if enable_site == app.enable_site: | if enable_site == app.enable_site: | ||||
| return app | return app | ||||
| assert current_user is not None | |||||
| app.enable_site = enable_site | app.enable_site = enable_site | ||||
| app.updated_by = current_user.id | app.updated_by = current_user.id | ||||
| app.updated_at = naive_utc_now() | app.updated_at = naive_utc_now() | ||||
| """ | """ | ||||
| if enable_api == app.enable_api: | if enable_api == app.enable_api: | ||||
| return app | return app | ||||
| assert current_user is not None | |||||
| app.enable_api = enable_api | app.enable_api = enable_api | ||||
| app.updated_by = current_user.id | app.updated_by = current_user.id |
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.enums import MessageStatus | from models.enums import MessageStatus | ||||
| from models.model import App, AppMode, AppModelConfig, Message | |||||
| from models.model import App, AppMode, Message | |||||
| from services.errors.audio import ( | from services.errors.audio import ( | ||||
| AudioTooLargeServiceError, | AudioTooLargeServiceError, | ||||
| NoAudioUploadedServiceError, | NoAudioUploadedServiceError, | ||||
| if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | ||||
| raise ValueError("Speech to text is not enabled") | raise ValueError("Speech to text is not enabled") | ||||
| else: | else: | ||||
| app_model_config: AppModelConfig = app_model.app_model_config | |||||
| app_model_config = app_model.app_model_config | |||||
| if not app_model_config: | |||||
| raise ValueError("Speech to text is not enabled") | |||||
| if not app_model_config.speech_to_text_dict["enabled"]: | if not app_model_config.speech_to_text_dict["enabled"]: | ||||
| raise ValueError("Speech to text is not enabled") | raise ValueError("Speech to text is not enabled") |
| return response.json() | return response.json() | ||||
| @staticmethod | @staticmethod | ||||
| def is_tenant_owner_or_admin(current_user): | |||||
| def is_tenant_owner_or_admin(current_user: Account): | |||||
| tenant_id = current_user.current_tenant_id | tenant_id = current_user.current_tenant_id | ||||
| join: Optional[TenantAccountJoin] = ( | join: Optional[TenantAccountJoin] = ( |
| from collections import Counter | from collections import Counter | ||||
| from typing import Any, Literal, Optional | from typing import Any, Literal, Optional | ||||
| from flask_login import current_user | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy import exists, func, select | from sqlalchemy import exists, func, select | ||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from libs import helper | from libs import helper | ||||
| from libs.datetime_utils import naive_utc_now | from libs.datetime_utils import naive_utc_now | ||||
| from libs.login import current_user | |||||
| from models.account import Account, TenantAccountRole | from models.account import Account, TenantAccountRole | ||||
| from models.dataset import ( | from models.dataset import ( | ||||
| AppDatasetJoin, | AppDatasetJoin, | ||||
| data: Update data dictionary | data: Update data dictionary | ||||
| filtered_data: Filtered update data to modify | filtered_data: Filtered update data to modify | ||||
| """ | """ | ||||
| # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None | |||||
| try: | try: | ||||
| model_manager = ModelManager() | model_manager = ModelManager() | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| embedding_model = model_manager.get_model_instance( | embedding_model = model_manager.get_model_instance( | ||||
| tenant_id=current_user.current_tenant_id, | tenant_id=current_user.current_tenant_id, | ||||
| provider=data["embedding_model_provider"], | provider=data["embedding_model_provider"], | ||||
| data: Update data dictionary | data: Update data dictionary | ||||
| filtered_data: Filtered update data to modify | filtered_data: Filtered update data to modify | ||||
| """ | """ | ||||
| # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None | |||||
| model_manager = ModelManager() | model_manager = ModelManager() | ||||
| try: | try: | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| embedding_model = model_manager.get_model_instance( | embedding_model = model_manager.get_model_instance( | ||||
| tenant_id=current_user.current_tenant_id, | tenant_id=current_user.current_tenant_id, | ||||
| provider=data["embedding_model_provider"], | provider=data["embedding_model_provider"], | ||||
| @staticmethod | @staticmethod | ||||
| def get_dataset_auto_disable_logs(dataset_id: str): | def get_dataset_auto_disable_logs(dataset_id: str): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if not features.billing.enabled or features.billing.subscription.plan == "sandbox": | if not features.billing.enabled or features.billing.subscription.plan == "sandbox": | ||||
| return { | return { | ||||
| @staticmethod | @staticmethod | ||||
| def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: | def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: | ||||
| assert isinstance(current_user, Account) | |||||
| documents = ( | documents = ( | ||||
| db.session.query(Document) | db.session.query(Document) | ||||
| .where( | .where( | ||||
| file_ids = [ | file_ids = [ | ||||
| document.data_source_info_dict["upload_file_id"] | document.data_source_info_dict["upload_file_id"] | ||||
| for document in documents | for document in documents | ||||
| if document.data_source_type == "upload_file" | |||||
| if document.data_source_type == "upload_file" and document.data_source_info_dict | |||||
| ] | ] | ||||
| batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | ||||
| @staticmethod | @staticmethod | ||||
| def rename_document(dataset_id: str, document_id: str, name: str) -> Document: | def rename_document(dataset_id: str, document_id: str, name: str) -> Document: | ||||
| assert isinstance(current_user, Account) | |||||
| dataset = DatasetService.get_dataset(dataset_id) | dataset = DatasetService.get_dataset(dataset_id) | ||||
| if not dataset: | if not dataset: | ||||
| raise ValueError("Dataset not found.") | raise ValueError("Dataset not found.") | ||||
| if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: | if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: | ||||
| raise DocumentIndexingError() | raise DocumentIndexingError() | ||||
| # update document to be paused | # update document to be paused | ||||
| assert current_user is not None | |||||
| document.is_paused = True | document.is_paused = True | ||||
| document.paused_by = current_user.id | document.paused_by = current_user.id | ||||
| document.paused_at = naive_utc_now() | document.paused_at = naive_utc_now() | ||||
| # sync document indexing | # sync document indexing | ||||
| document.indexing_status = "waiting" | document.indexing_status = "waiting" | ||||
| data_source_info = document.data_source_info_dict | data_source_info = document.data_source_info_dict | ||||
| data_source_info["mode"] = "scrape" | |||||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||||
| if data_source_info: | |||||
| data_source_info["mode"] = "scrape" | |||||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| # check doc_form | # check doc_form | ||||
| DatasetService.check_doc_form(dataset, knowledge_config.doc_form) | DatasetService.check_doc_form(dataset, knowledge_config.doc_form) | ||||
| # check document limit | # check document limit | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if features.billing.enabled: | if features.billing.enabled: | ||||
| @staticmethod | @staticmethod | ||||
| def get_tenant_documents_count(): | def get_tenant_documents_count(): | ||||
| assert isinstance(current_user, Account) | |||||
| documents_count = ( | documents_count = ( | ||||
| db.session.query(Document) | db.session.query(Document) | ||||
| .where( | .where( | ||||
| dataset_process_rule: Optional[DatasetProcessRule] = None, | dataset_process_rule: Optional[DatasetProcessRule] = None, | ||||
| created_from: str = "web", | created_from: str = "web", | ||||
| ): | ): | ||||
| assert isinstance(current_user, Account) | |||||
| DatasetService.check_dataset_model_setting(dataset) | DatasetService.check_dataset_model_setting(dataset) | ||||
| document = DocumentService.get_document(dataset.id, document_data.original_document_id) | document = DocumentService.get_document(dataset.id, document_data.original_document_id) | ||||
| if document is None: | if document is None: | ||||
| data_source_binding = ( | data_source_binding = ( | ||||
| db.session.query(DataSourceOauthBinding) | db.session.query(DataSourceOauthBinding) | ||||
| .where( | .where( | ||||
| db.and_( | |||||
| sa.and_( | |||||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | ||||
| DataSourceOauthBinding.provider == "notion", | DataSourceOauthBinding.provider == "notion", | ||||
| DataSourceOauthBinding.disabled == False, | DataSourceOauthBinding.disabled == False, | ||||
| @staticmethod | @staticmethod | ||||
| def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): | def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| features = FeatureService.get_features(current_user.current_tenant_id) | features = FeatureService.get_features(current_user.current_tenant_id) | ||||
| if features.billing.enabled: | if features.billing.enabled: | ||||
| @classmethod | @classmethod | ||||
| def create_segment(cls, args: dict, document: Document, dataset: Dataset): | def create_segment(cls, args: dict, document: Document, dataset: Dataset): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| content = args["content"] | content = args["content"] | ||||
| doc_id = str(uuid.uuid4()) | doc_id = str(uuid.uuid4()) | ||||
| segment_hash = helper.generate_text_hash(content) | segment_hash = helper.generate_text_hash(content) | ||||
| @classmethod | @classmethod | ||||
| def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): | def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| lock_name = f"multi_add_segment_lock_document_id_{document.id}" | lock_name = f"multi_add_segment_lock_document_id_{document.id}" | ||||
| increment_word_count = 0 | increment_word_count = 0 | ||||
| with redis_client.lock(lock_name, timeout=600): | with redis_client.lock(lock_name, timeout=600): | ||||
| @classmethod | @classmethod | ||||
| def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): | def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| indexing_cache_key = f"segment_{segment.id}_indexing" | indexing_cache_key = f"segment_{segment.id}_indexing" | ||||
| cache_result = redis_client.get(indexing_cache_key) | cache_result = redis_client.get(indexing_cache_key) | ||||
| if cache_result is not None: | if cache_result is not None: | ||||
| @classmethod | @classmethod | ||||
| def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): | def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): | ||||
| assert isinstance(current_user, Account) | |||||
| segments = ( | segments = ( | ||||
| db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) | db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) | ||||
| .where( | .where( | ||||
| def update_segments_status( | def update_segments_status( | ||||
| cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document | cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document | ||||
| ): | ): | ||||
| assert current_user is not None | |||||
| # Check if segment_ids is not empty to avoid WHERE false condition | # Check if segment_ids is not empty to avoid WHERE false condition | ||||
| if not segment_ids or len(segment_ids) == 0: | if not segment_ids or len(segment_ids) == 0: | ||||
| return | return | ||||
| def create_child_chunk( | def create_child_chunk( | ||||
| cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset | cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset | ||||
| ) -> ChildChunk: | ) -> ChildChunk: | ||||
| assert isinstance(current_user, Account) | |||||
| lock_name = f"add_child_lock_{segment.id}" | lock_name = f"add_child_lock_{segment.id}" | ||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| index_node_id = str(uuid.uuid4()) | index_node_id = str(uuid.uuid4()) | ||||
| document: Document, | document: Document, | ||||
| dataset: Dataset, | dataset: Dataset, | ||||
| ) -> list[ChildChunk]: | ) -> list[ChildChunk]: | ||||
| assert isinstance(current_user, Account) | |||||
| child_chunks = ( | child_chunks = ( | ||||
| db.session.query(ChildChunk) | db.session.query(ChildChunk) | ||||
| .where( | .where( | ||||
| document: Document, | document: Document, | ||||
| dataset: Dataset, | dataset: Dataset, | ||||
| ) -> ChildChunk: | ) -> ChildChunk: | ||||
| assert current_user is not None | |||||
| try: | try: | ||||
| child_chunk.content = content | child_chunk.content = content | ||||
| child_chunk.word_count = len(content) | child_chunk.word_count = len(content) | ||||
| def get_child_chunks( | def get_child_chunks( | ||||
| cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None | cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None | ||||
| ): | ): | ||||
| assert isinstance(current_user, Account) | |||||
| query = ( | query = ( | ||||
| select(ChildChunk) | select(ChildChunk) | ||||
| .filter_by( | .filter_by( |
| ) | ) | ||||
| if external_knowledge_api is None: | if external_knowledge_api is None: | ||||
| raise ValueError("api template not found") | raise ValueError("api template not found") | ||||
| if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: | |||||
| args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||||
| settings = args.get("settings") | |||||
| if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict: | |||||
| settings["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||||
| external_knowledge_api.name = args.get("name") | external_knowledge_api.name = args.get("name") | ||||
| external_knowledge_api.description = args.get("description", "") | external_knowledge_api.description = args.get("description", "") |
| import uuid | import uuid | ||||
| from typing import Any, Literal, Union | from typing import Any, Literal, Union | ||||
| from flask_login import current_user | |||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from libs.datetime_utils import naive_utc_now | from libs.datetime_utils import naive_utc_now | ||||
| from libs.helper import extract_tenant_id | from libs.helper import extract_tenant_id | ||||
| from libs.login import current_user | |||||
| from models.account import Account | from models.account import Account | ||||
| from models.enums import CreatorUserRole | from models.enums import CreatorUserRole | ||||
| from models.model import EndUser, UploadFile | from models.model import EndUser, UploadFile | ||||
| @staticmethod | @staticmethod | ||||
| def upload_text(text: str, text_name: str) -> UploadFile: | def upload_text(text: str, text_name: str) -> UploadFile: | ||||
| assert isinstance(current_user, Account) | |||||
| assert current_user.current_tenant_id is not None | |||||
| if len(text_name) > 200: | if len(text_name) > 200: | ||||
| text_name = text_name[:200] | text_name = text_name[:200] | ||||
| # user uuid as file name | # user uuid as file name |
| def update_mcp_provider_credentials( | def update_mcp_provider_credentials( | ||||
| cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False | cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False | ||||
| ): | ): | ||||
| provider_controller = MCPToolProviderController._from_db(mcp_provider) | |||||
| provider_controller = MCPToolProviderController.from_db(mcp_provider) | |||||
| tool_configuration = ProviderConfigEncrypter( | tool_configuration = ProviderConfigEncrypter( | ||||
| tenant_id=mcp_provider.tenant_id, | tenant_id=mcp_provider.tenant_id, | ||||
| config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] | config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] |
| import json | import json | ||||
| from unittest.mock import MagicMock, patch | |||||
| from unittest.mock import MagicMock, create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| from core.plugin.impl.exc import PluginDaemonClientSideError | from core.plugin.impl.exc import PluginDaemonClientSideError | ||||
| from models.account import Account | |||||
| from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought | from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought | ||||
| from services.account_service import AccountService, TenantService | from services.account_service import AccountService, TenantService | ||||
| from services.agent_service import AgentService | from services.agent_service import AgentService | ||||
| patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, | patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, | ||||
| patch("services.agent_service.ToolManager") as mock_tool_manager, | patch("services.agent_service.ToolManager") as mock_tool_manager, | ||||
| patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, | patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, | ||||
| patch("services.agent_service.current_user") as mock_current_user, | |||||
| patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, | |||||
| patch("services.app_service.FeatureService") as mock_feature_service, | patch("services.app_service.FeatureService") as mock_feature_service, | ||||
| patch("services.app_service.EnterpriseService") as mock_enterprise_service, | patch("services.app_service.EnterpriseService") as mock_enterprise_service, | ||||
| patch("services.app_service.ModelManager") as mock_model_manager, | patch("services.app_service.ModelManager") as mock_model_manager, |
| from unittest.mock import patch | |||||
| from unittest.mock import create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from models.account import Account | |||||
| from models.model import MessageAnnotation | from models.model import MessageAnnotation | ||||
| from services.annotation_service import AppAnnotationService | from services.annotation_service import AppAnnotationService | ||||
| from services.app_service import AppService | from services.app_service import AppService | ||||
| patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, | patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, | ||||
| patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, | patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, | ||||
| patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, | patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, | ||||
| patch("services.annotation_service.current_user") as mock_current_user, | |||||
| patch( | |||||
| "services.annotation_service.current_user", create_autospec(Account, instance=True) | |||||
| ) as mock_current_user, | |||||
| ): | ): | ||||
| # Setup default mock returns | # Setup default mock returns | ||||
| mock_account_feature_service.get_features.return_value.billing.enabled = False | mock_account_feature_service.get_features.return_value.billing.enabled = False |
| from unittest.mock import patch | |||||
| from unittest.mock import create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| from constants.model_template import default_app_templates | from constants.model_template import default_app_templates | ||||
| from models.account import Account | |||||
| from models.model import App, Site | from models.model import App, Site | ||||
| from services.account_service import AccountService, TenantService | from services.account_service import AccountService, TenantService | ||||
| from services.app_service import AppService | from services.app_service import AppService | ||||
| app_service = AppService() | app_service = AppService() | ||||
| created_app = app_service.create_app(tenant.id, app_args, account) | created_app = app_service.create_app(tenant.id, app_args, account) | ||||
| # Get app using the service | |||||
| retrieved_app = app_service.get_app(created_app) | |||||
| # Get app using the service - needs current_user mock | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.id = account.id | |||||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| retrieved_app = app_service.get_app(created_app) | |||||
| # Verify retrieved app matches created app | # Verify retrieved app matches created app | ||||
| assert retrieved_app.id == created_app.id | assert retrieved_app.id == created_app.id | ||||
| "use_icon_as_answer_icon": True, | "use_icon_as_answer_icon": True, | ||||
| } | } | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.id = account.id | |||||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app(app, update_args) | updated_app = app_service.update_app(app, update_args) | ||||
| # Verify updated fields | # Verify updated fields | ||||
| # Update app name | # Update app name | ||||
| new_name = "New App Name" | new_name = "New App Name" | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.id = account.id | |||||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app_name(app, new_name) | updated_app = app_service.update_app_name(app, new_name) | ||||
| assert updated_app.name == new_name | assert updated_app.name == new_name | ||||
| # Update app icon | # Update app icon | ||||
| new_icon = "🌟" | new_icon = "🌟" | ||||
| new_icon_background = "#FFD93D" | new_icon_background = "#FFD93D" | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.id = account.id | |||||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) | updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) | ||||
| assert updated_app.icon == new_icon | assert updated_app.icon == new_icon | ||||
| original_site_status = app.enable_site | original_site_status = app.enable_site | ||||
| # Update site status to disabled | # Update site status to disabled | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.id = account.id | |||||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app_site_status(app, False) | updated_app = app_service.update_app_site_status(app, False) | ||||
| assert updated_app.enable_site is False | assert updated_app.enable_site is False | ||||
| assert updated_app.updated_by == account.id | assert updated_app.updated_by == account.id | ||||
| # Update site status back to enabled | # Update site status back to enabled | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app_site_status(updated_app, True) | updated_app = app_service.update_app_site_status(updated_app, True) | ||||
| assert updated_app.enable_site is True | assert updated_app.enable_site is True | ||||
| assert updated_app.updated_by == account.id | assert updated_app.updated_by == account.id | ||||
| original_api_status = app.enable_api | original_api_status = app.enable_api | ||||
| # Update API status to disabled | # Update API status to disabled | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.id = account.id | |||||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app_api_status(app, False) | updated_app = app_service.update_app_api_status(app, False) | ||||
| assert updated_app.enable_api is False | assert updated_app.enable_api is False | ||||
| assert updated_app.updated_by == account.id | assert updated_app.updated_by == account.id | ||||
| # Update API status back to enabled | # Update API status back to enabled | ||||
| with patch("flask_login.utils._get_user", return_value=account): | |||||
| with patch("services.app_service.current_user", mock_current_user): | |||||
| updated_app = app_service.update_app_api_status(updated_app, True) | updated_app = app_service.update_app_api_status(updated_app, True) | ||||
| assert updated_app.enable_api is True | assert updated_app.enable_api is True | ||||
| assert updated_app.updated_by == account.id | assert updated_app.updated_by == account.id |
| import hashlib | import hashlib | ||||
| from io import BytesIO | from io import BytesIO | ||||
| from unittest.mock import patch | |||||
| from unittest.mock import create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| text = "This is a test text content" | text = "This is a test text content" | ||||
| text_name = "test_text.txt" | text_name = "test_text.txt" | ||||
| # Mock current_user | |||||
| with patch("services.file_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||||
| mock_current_user.id = str(fake.uuid4()) | |||||
| # Mock current_user using create_autospec | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||||
| mock_current_user.id = str(fake.uuid4()) | |||||
| with patch("services.file_service.current_user", mock_current_user): | |||||
| upload_file = FileService.upload_text(text=text, text_name=text_name) | upload_file = FileService.upload_text(text=text, text_name=text_name) | ||||
| assert upload_file is not None | assert upload_file is not None | ||||
| text = "test content" | text = "test content" | ||||
| long_name = "a" * 250 # Longer than 200 characters | long_name = "a" * 250 # Longer than 200 characters | ||||
| # Mock current_user | |||||
| with patch("services.file_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||||
| mock_current_user.id = str(fake.uuid4()) | |||||
| # Mock current_user using create_autospec | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||||
| mock_current_user.id = str(fake.uuid4()) | |||||
| with patch("services.file_service.current_user", mock_current_user): | |||||
| upload_file = FileService.upload_text(text=text, text_name=long_name) | upload_file = FileService.upload_text(text=text, text_name=long_name) | ||||
| # Verify name was truncated | # Verify name was truncated | ||||
| text = "" | text = "" | ||||
| text_name = "empty.txt" | text_name = "empty.txt" | ||||
| # Mock current_user | |||||
| with patch("services.file_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||||
| mock_current_user.id = str(fake.uuid4()) | |||||
| # Mock current_user using create_autospec | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||||
| mock_current_user.id = str(fake.uuid4()) | |||||
| with patch("services.file_service.current_user", mock_current_user): | |||||
| upload_file = FileService.upload_text(text=text, text_name=text_name) | upload_file = FileService.upload_text(text=text, text_name=text_name) | ||||
| assert upload_file is not None | assert upload_file is not None |
| from unittest.mock import patch | |||||
| from unittest.mock import create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| def mock_external_service_dependencies(self): | def mock_external_service_dependencies(self): | ||||
| """Mock setup for external service dependencies.""" | """Mock setup for external service dependencies.""" | ||||
| with ( | with ( | ||||
| patch("services.metadata_service.current_user") as mock_current_user, | |||||
| patch( | |||||
| "services.metadata_service.current_user", create_autospec(Account, instance=True) | |||||
| ) as mock_current_user, | |||||
| patch("services.metadata_service.redis_client") as mock_redis_client, | patch("services.metadata_service.redis_client") as mock_redis_client, | ||||
| patch("services.dataset_service.DocumentService") as mock_document_service, | patch("services.dataset_service.DocumentService") as mock_document_service, | ||||
| ): | ): |
| from unittest.mock import patch | |||||
| from unittest.mock import create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| def mock_external_service_dependencies(self): | def mock_external_service_dependencies(self): | ||||
| """Mock setup for external service dependencies.""" | """Mock setup for external service dependencies.""" | ||||
| with ( | with ( | ||||
| patch("services.tag_service.current_user") as mock_current_user, | |||||
| patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, | |||||
| ): | ): | ||||
| # Setup default mock returns | # Setup default mock returns | ||||
| mock_current_user.current_tenant_id = "test-tenant-id" | mock_current_user.current_tenant_id = "test-tenant-id" |
| from datetime import datetime | from datetime import datetime | ||||
| from unittest.mock import MagicMock, patch | |||||
| from unittest.mock import MagicMock, create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| fake = Faker() | fake = Faker() | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlApiRequest( | api_request = WebsiteCrawlApiRequest( | ||||
| provider="firecrawl", | provider="firecrawl", | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlApiRequest( | api_request = WebsiteCrawlApiRequest( | ||||
| provider="watercrawl", | provider="watercrawl", | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request for single page crawling | # Create API request for single page crawling | ||||
| api_request = WebsiteCrawlApiRequest( | api_request = WebsiteCrawlApiRequest( | ||||
| provider="jinareader", | provider="jinareader", | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request with invalid provider | # Create API request with invalid provider | ||||
| api_request = WebsiteCrawlApiRequest( | api_request = WebsiteCrawlApiRequest( | ||||
| provider="invalid_provider", | provider="invalid_provider", | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") | api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") | api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") | api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request with invalid provider | # Create API request with invalid provider | ||||
| api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") | api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Mock missing credentials | # Mock missing credentials | ||||
| mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None | mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Mock missing API key in config | # Mock missing API key in config | ||||
| mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { | mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { | ||||
| "config": {"base_url": "https://api.example.com"} | "config": {"base_url": "https://api.example.com"} | ||||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request for sub-page crawling | # Create API request for sub-page crawling | ||||
| api_request = WebsiteCrawlApiRequest( | api_request = WebsiteCrawlApiRequest( | ||||
| provider="jinareader", | provider="jinareader", | ||||
| mock_external_service_dependencies["requests"].get.return_value = mock_failed_response | mock_external_service_dependencies["requests"].get.return_value = mock_failed_response | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlApiRequest( | api_request = WebsiteCrawlApiRequest( | ||||
| provider="jinareader", | provider="jinareader", | ||||
| mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance | mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance | ||||
| # Mock current_user for the test | # Mock current_user for the test | ||||
| with patch("services.website_service.current_user") as mock_current_user: | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| mock_current_user = create_autospec(Account, instance=True) | |||||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||||
| with patch("services.website_service.current_user", mock_current_user): | |||||
| # Create API request | # Create API request | ||||
| api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") | api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") | ||||
| TestCase( | TestCase( | ||||
| name="session insert with invalid type", | name="session insert with invalid type", | ||||
| action=lambda s: _session_insert_with_value(s, 1), | action=lambda s: _session_insert_with_value(s, 1), | ||||
| exc_type=TypeError, | |||||
| exc_type=ValueError, | |||||
| ), | ), | ||||
| TestCase( | TestCase( | ||||
| name="insert with invalid value", | name="insert with invalid value", | ||||
| TestCase( | TestCase( | ||||
| name="insert with invalid type", | name="insert with invalid type", | ||||
| action=lambda s: _insert_with_user(s, 1), | action=lambda s: _insert_with_user(s, 1), | ||||
| exc_type=TypeError, | |||||
| exc_type=ValueError, | |||||
| ), | ), | ||||
| ] | ] | ||||
| for idx, c in enumerate(cases, 1): | for idx, c in enumerate(cases, 1): |
| from typing import Any, Optional | from typing import Any, Optional | ||||
| # Mock redis_client before importing dataset_service | # Mock redis_client before importing dataset_service | ||||
| from unittest.mock import Mock, patch | |||||
| from unittest.mock import Mock, create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from models.account import Account | |||||
| from models.dataset import Dataset, ExternalKnowledgeBindings | from models.dataset import Dataset, ExternalKnowledgeBindings | ||||
| from services.dataset_service import DatasetService | from services.dataset_service import DatasetService | ||||
| from services.errors.account import NoPermissionError | from services.errors.account import NoPermissionError | ||||
| @staticmethod | @staticmethod | ||||
| def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: | def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: | ||||
| """Create a mock current user.""" | """Create a mock current user.""" | ||||
| current_user = Mock() | |||||
| current_user = create_autospec(Account, instance=True) | |||||
| current_user.current_tenant_id = tenant_id | current_user.current_tenant_id = tenant_id | ||||
| return current_user | return current_user | ||||
| "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" | "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" | ||||
| ) as mock_get_binding, | ) as mock_get_binding, | ||||
| patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, | patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, | ||||
| patch("services.dataset_service.current_user") as mock_current_user, | |||||
| patch( | |||||
| "services.dataset_service.current_user", create_autospec(Account, instance=True) | |||||
| ) as mock_current_user, | |||||
| ): | ): | ||||
| mock_current_user.current_tenant_id = "tenant-123" | mock_current_user.current_tenant_id = "tenant-123" | ||||
| yield { | yield { |
| from unittest.mock import Mock, patch | |||||
| from unittest.mock import Mock, create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from flask_restx import reqparse | from flask_restx import reqparse | ||||
| from werkzeug.exceptions import BadRequest | from werkzeug.exceptions import BadRequest | ||||
| from models.account import Account | |||||
| from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | ||||
| from services.metadata_service import MetadataService | from services.metadata_service import MetadataService | ||||
| mock_metadata_args.name = None | mock_metadata_args.name = None | ||||
| mock_metadata_args.type = "string" | mock_metadata_args.type = "string" | ||||
| with patch("services.metadata_service.current_user") as mock_user: | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| mock_user = create_autospec(Account, instance=True) | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| with patch("services.metadata_service.current_user", mock_user): | |||||
| # Should crash with TypeError | # Should crash with TypeError | ||||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | ||||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | MetadataService.create_metadata("dataset-123", mock_metadata_args) | ||||
| # Test update method as well | # Test update method as well | ||||
| with patch("services.metadata_service.current_user") as mock_user: | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| mock_user = create_autospec(Account, instance=True) | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| with patch("services.metadata_service.current_user", mock_user): | |||||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | ||||
| MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | ||||
| from unittest.mock import Mock, patch | |||||
| from unittest.mock import Mock, create_autospec, patch | |||||
| import pytest | import pytest | ||||
| from flask_restx import reqparse | from flask_restx import reqparse | ||||
| from models.account import Account | |||||
| from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | ||||
| from services.metadata_service import MetadataService | from services.metadata_service import MetadataService | ||||
| mock_metadata_args.name = None # This will cause len() to crash | mock_metadata_args.name = None # This will cause len() to crash | ||||
| mock_metadata_args.type = "string" | mock_metadata_args.type = "string" | ||||
| with patch("services.metadata_service.current_user") as mock_user: | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| mock_user = create_autospec(Account, instance=True) | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| with patch("services.metadata_service.current_user", mock_user): | |||||
| # This should crash with TypeError when calling len(None) | # This should crash with TypeError when calling len(None) | ||||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | ||||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | MetadataService.create_metadata("dataset-123", mock_metadata_args) | ||||
| def test_metadata_service_update_with_none_name_crashes(self): | def test_metadata_service_update_with_none_name_crashes(self): | ||||
| """Test that MetadataService.update_metadata_name crashes when name is None.""" | """Test that MetadataService.update_metadata_name crashes when name is None.""" | ||||
| with patch("services.metadata_service.current_user") as mock_user: | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| mock_user = create_autospec(Account, instance=True) | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| with patch("services.metadata_service.current_user", mock_user): | |||||
| # This should crash with TypeError when calling len(None) | # This should crash with TypeError when calling len(None) | ||||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | ||||
| MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | ||||
| mock_metadata_args.name = None # From args["name"] | mock_metadata_args.name = None # From args["name"] | ||||
| mock_metadata_args.type = None # From args["type"] | mock_metadata_args.type = None # From args["type"] | ||||
| with patch("services.metadata_service.current_user") as mock_user: | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| mock_user = create_autospec(Account, instance=True) | |||||
| mock_user.current_tenant_id = "tenant-123" | |||||
| mock_user.id = "user-456" | |||||
| with patch("services.metadata_service.current_user", mock_user): | |||||
| # Step 4: Service layer crashes on len(None) | # Step 4: Service layer crashes on len(None) | ||||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | ||||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | MetadataService.create_metadata("dataset-123", mock_metadata_args) |
| const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false) | const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false) | ||||
| const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false) | const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false) | ||||
| const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([]) | const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([]) | ||||
| const [showExportWarning, setShowExportWarning] = useState(false) | |||||
| const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ | const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ | ||||
| name, | name, | ||||
| onExport() | onExport() | ||||
| return | return | ||||
| } | } | ||||
| setShowExportWarning(true) | |||||
| } | |||||
| const handleConfirmExport = async () => { | |||||
| if (!appDetail) | |||||
| return | |||||
| setShowExportWarning(false) | |||||
| try { | try { | ||||
| const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) | const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) | ||||
| const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') | const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') | ||||
| onClose={() => setSecretEnvList([])} | onClose={() => setSecretEnvList([])} | ||||
| /> | /> | ||||
| )} | )} | ||||
| {showExportWarning && ( | |||||
| <Confirm | |||||
| type="info" | |||||
| isShow={showExportWarning} | |||||
| title={t('workflow.sidebar.exportWarning')} | |||||
| content={t('workflow.sidebar.exportWarningDesc')} | |||||
| onConfirm={handleConfirmExport} | |||||
| onCancel={() => setShowExportWarning(false)} | |||||
| /> | |||||
| )} | |||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } |
| size?: 'xs' | 's' | 'm' | 'l' | 'xl' | size?: 'xs' | 's' | 'm' | 'l' | 'xl' | ||||
| state?: ActionButtonState | state?: ActionButtonState | ||||
| styleCss?: CSSProperties | styleCss?: CSSProperties | ||||
| ref?: React.Ref<HTMLButtonElement> | |||||
| } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof actionButtonVariants> | } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof actionButtonVariants> | ||||
| function getActionButtonState(state: ActionButtonState) { | function getActionButtonState(state: ActionButtonState) { | ||||
| } | } | ||||
| } | } | ||||
| const ActionButton = React.forwardRef<HTMLButtonElement, ActionButtonProps>( | |||||
| ({ className, size, state = ActionButtonState.Default, styleCss, children, ...props }, ref) => { | |||||
| return ( | |||||
| <button | |||||
| type='button' | |||||
| className={classNames( | |||||
| actionButtonVariants({ className, size }), | |||||
| getActionButtonState(state), | |||||
| )} | |||||
| ref={ref} | |||||
| style={styleCss} | |||||
| {...props} | |||||
| > | |||||
| {children} | |||||
| </button> | |||||
| ) | |||||
| }, | |||||
| ) | |||||
| const ActionButton = ({ className, size, state = ActionButtonState.Default, styleCss, children, ref, ...props }: ActionButtonProps) => { | |||||
| return ( | |||||
| <button | |||||
| type='button' | |||||
| className={classNames( | |||||
| actionButtonVariants({ className, size }), | |||||
| getActionButtonState(state), | |||||
| )} | |||||
| ref={ref} | |||||
| style={styleCss} | |||||
| {...props} | |||||
| > | |||||
| {children} | |||||
| </button> | |||||
| ) | |||||
| } | |||||
| ActionButton.displayName = 'ActionButton' | ActionButton.displayName = 'ActionButton' | ||||
| export default ActionButton | export default ActionButton |
| loading?: boolean | loading?: boolean | ||||
| styleCss?: CSSProperties | styleCss?: CSSProperties | ||||
| spinnerClassName?: string | spinnerClassName?: string | ||||
| ref?: React.Ref<HTMLButtonElement> | |||||
| } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof buttonVariants> | } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof buttonVariants> | ||||
| const Button = React.forwardRef<HTMLButtonElement, ButtonProps>( | |||||
| ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ...props }, ref) => { | |||||
| return ( | |||||
| <button | |||||
| type='button' | |||||
| className={classNames( | |||||
| buttonVariants({ variant, size, className }), | |||||
| destructive && 'btn-destructive', | |||||
| )} | |||||
| ref={ref} | |||||
| style={styleCss} | |||||
| {...props} | |||||
| > | |||||
| {children} | |||||
| {loading && <Spinner loading={loading} className={classNames('!ml-1 !h-3 !w-3 !border-2 !text-white', spinnerClassName)} />} | |||||
| </button> | |||||
| ) | |||||
| }, | |||||
| ) | |||||
| const Button = ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ref, ...props }: ButtonProps) => { | |||||
| return ( | |||||
| <button | |||||
| type='button' | |||||
| className={classNames( | |||||
| buttonVariants({ variant, size, className }), | |||||
| destructive && 'btn-destructive', | |||||
| )} | |||||
| ref={ref} | |||||
| style={styleCss} | |||||
| {...props} | |||||
| > | |||||
| {children} | |||||
| {loading && <Spinner loading={loading} className={classNames('!ml-1 !h-3 !w-3 !border-2 !text-white', spinnerClassName)} />} | |||||
| </button> | |||||
| ) | |||||
| } | |||||
| Button.displayName = 'Button' | Button.displayName = 'Button' | ||||
| export default Button | export default Button |
| wrapperClassName?: string | wrapperClassName?: string | ||||
| styleCss?: CSSProperties | styleCss?: CSSProperties | ||||
| unit?: string | unit?: string | ||||
| ref?: React.Ref<HTMLInputElement> | |||||
| } & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants> | } & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants> | ||||
| const Input = React.forwardRef<HTMLInputElement, InputProps>(({ | |||||
| const Input = ({ | |||||
| size, | size, | ||||
| disabled, | disabled, | ||||
| destructive, | destructive, | ||||
| placeholder, | placeholder, | ||||
| onChange = noop, | onChange = noop, | ||||
| unit, | unit, | ||||
| ref, | |||||
| ...props | ...props | ||||
| }, ref) => { | |||||
| }: InputProps) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| return ( | return ( | ||||
| <div className={cn('relative w-full', wrapperClassName)}> | <div className={cn('relative w-full', wrapperClassName)}> | ||||
| } | } | ||||
| </div> | </div> | ||||
| ) | ) | ||||
| }) | |||||
| } | |||||
| Input.displayName = 'Input' | Input.displayName = 'Input' | ||||
| return isMermaidInitialized | return isMermaidInitialized | ||||
| } | } | ||||
| const Flowchart = React.forwardRef((props: { | |||||
| type FlowchartProps = { | |||||
| PrimitiveCode: string | PrimitiveCode: string | ||||
| theme?: 'light' | 'dark' | theme?: 'light' | 'dark' | ||||
| }, ref) => { | |||||
| ref?: React.Ref<HTMLDivElement> | |||||
| } | |||||
| const Flowchart = (props: FlowchartProps) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const [svgString, setSvgString] = useState<string | null>(null) | const [svgString, setSvgString] = useState<string | null>(null) | ||||
| const [look, setLook] = useState<'classic' | 'handDrawn'>('classic') | const [look, setLook] = useState<'classic' | 'handDrawn'>('classic') | ||||
| } | } | ||||
| return ( | return ( | ||||
| <div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}> | |||||
| <div ref={props.ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}> | |||||
| <div className={themeClasses.segmented}> | <div className={themeClasses.segmented}> | ||||
| <div className="msh-segmented-group"> | <div className="msh-segmented-group"> | ||||
| <label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1"> | <label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1"> | ||||
| )} | )} | ||||
| </div> | </div> | ||||
| ) | ) | ||||
| }) | |||||
| } | |||||
| Flowchart.displayName = 'Flowchart' | Flowchart.displayName = 'Flowchart' | ||||
| disabled?: boolean | disabled?: boolean | ||||
| destructive?: boolean | destructive?: boolean | ||||
| styleCss?: CSSProperties | styleCss?: CSSProperties | ||||
| ref?: React.Ref<HTMLTextAreaElement> | |||||
| } & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants> | } & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants> | ||||
| const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>( | |||||
| ({ className, value, onChange, disabled, size, destructive, styleCss, ...props }, ref) => { | |||||
| return ( | |||||
| <textarea | |||||
| ref={ref} | |||||
| style={styleCss} | |||||
| className={cn( | |||||
| 'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs', | |||||
| textareaVariants({ size }), | |||||
| disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled', | |||||
| destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive', | |||||
| className, | |||||
| )} | |||||
| value={value ?? ''} | |||||
| onChange={onChange} | |||||
| disabled={disabled} | |||||
| {...props} | |||||
| > | |||||
| </textarea> | |||||
| ) | |||||
| }, | |||||
| ) | |||||
| const Textarea = ({ className, value, onChange, disabled, size, destructive, styleCss, ref, ...props }: TextareaProps) => { | |||||
| return ( | |||||
| <textarea | |||||
| ref={ref} | |||||
| style={styleCss} | |||||
| className={cn( | |||||
| 'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs', | |||||
| textareaVariants({ size }), | |||||
| disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled', | |||||
| destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive', | |||||
| className, | |||||
| )} | |||||
| value={value ?? ''} | |||||
| onChange={onChange} | |||||
| disabled={disabled} | |||||
| {...props} | |||||
| > | |||||
| </textarea> | |||||
| ) | |||||
| } | |||||
| Textarea.displayName = 'Textarea' | Textarea.displayName = 'Textarea' | ||||
| export default Textarea | export default Textarea |
| import type { ComponentProps, FC, ReactNode } from 'react' | import type { ComponentProps, FC, ReactNode } from 'react' | ||||
| import { forwardRef } from 'react' | |||||
| import classNames from '@/utils/classnames' | import classNames from '@/utils/classnames' | ||||
| export type PreviewContainerProps = ComponentProps<'div'> & { | export type PreviewContainerProps = ComponentProps<'div'> & { | ||||
| header: ReactNode | header: ReactNode | ||||
| mainClassName?: string | mainClassName?: string | ||||
| ref?: React.Ref<HTMLDivElement> | |||||
| } | } | ||||
| export const PreviewContainer: FC<PreviewContainerProps> = forwardRef((props, ref) => { | |||||
| const { children, className, header, mainClassName, ...rest } = props | |||||
| export const PreviewContainer: FC<PreviewContainerProps> = (props) => { | |||||
| const { children, className, header, mainClassName, ref, ...rest } = props | |||||
| return <div className={className}> | return <div className={className}> | ||||
| <div | <div | ||||
| {...rest} | {...rest} | ||||
| </main> | </main> | ||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| }) | |||||
| } | |||||
| PreviewContainer.displayName = 'PreviewContainer' | PreviewContainer.displayName = 'PreviewContainer' |
| --- | --- | ||||
| <Heading | |||||
| url='/files/:file_id/preview' | |||||
| method='GET' | |||||
| title='File Preview' | |||||
| name='#file-preview' | |||||
| /> | |||||
| <Row> | |||||
| <Col> | |||||
| Preview or download uploaded files. This endpoint allows you to access files that have been previously uploaded via the File Upload API. | |||||
| <i>Files can only be accessed if they belong to messages within the requesting application.</i> | |||||
| ### Path Parameters | |||||
| - `file_id` (string) Required | |||||
| The unique identifier of the file to preview, obtained from the File Upload API response. | |||||
| ### Query Parameters | |||||
| - `as_attachment` (boolean) Optional | |||||
| Whether to force download the file as an attachment. Default is `false` (preview in browser). | |||||
| ### Response | |||||
| Returns the file content with appropriate headers for browser display or download. | |||||
| - `Content-Type` Set based on file mime type | |||||
| - `Content-Length` File size in bytes (if available) | |||||
| - `Content-Disposition` Set to "attachment" if `as_attachment=true` | |||||
| - `Cache-Control` Caching headers for performance | |||||
| - `Accept-Ranges` Set to "bytes" for audio/video files | |||||
| ### Errors | |||||
| - 400, `invalid_param`, abnormal parameter input | |||||
| - 403, `file_access_denied`, file access denied or file does not belong to current application | |||||
| - 404, `file_not_found`, file not found or has been deleted | |||||
| - 500, internal server error | |||||
| </Col> | |||||
| <Col sticky> | |||||
| ### Request Example | |||||
| <CodeGroup | |||||
| title="Request" | |||||
| tag="GET" | |||||
| label="/files/:file_id/preview" | |||||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\ | |||||
| --header 'Authorization: Bearer {api_key}'`} | |||||
| /> | |||||
| ### Download as Attachment | |||||
| <CodeGroup | |||||
| title="Download Request" | |||||
| tag="GET" | |||||
| label="/files/:file_id/preview?as_attachment=true" | |||||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\ | |||||
| --header 'Authorization: Bearer {api_key}' \\ | |||||
| --output downloaded_file.png`} | |||||
| /> | |||||
| ### Response Headers Example | |||||
| <CodeGroup title="Response Headers"> | |||||
| ```http {{ title: 'Headers - Image Preview' }} | |||||
| Content-Type: image/png | |||||
| Content-Length: 1024 | |||||
| Cache-Control: public, max-age=3600 | |||||
| ``` | |||||
| </CodeGroup> | |||||
| ### Download Response Headers | |||||
| <CodeGroup title="Download Response Headers"> | |||||
| ```http {{ title: 'Headers - File Download' }} | |||||
| Content-Type: image/png | |||||
| Content-Length: 1024 | |||||
| Content-Disposition: attachment; filename*=UTF-8''example.png | |||||
| Cache-Control: public, max-age=3600 | |||||
| ``` | |||||
| </CodeGroup> | |||||
| </Col> | |||||
| </Row> | |||||
| --- | |||||
| <Heading | <Heading | ||||
| url='/workflows/logs' | url='/workflows/logs' | ||||
| method='GET' | method='GET' |
| --- | --- | ||||
| <Heading | |||||
| url='/files/:file_id/preview' | |||||
| method='GET' | |||||
| title='ファイルプレビュー' | |||||
| name='#file-preview' | |||||
| /> | |||||
| <Row> | |||||
| <Col> | |||||
| アップロードされたファイルをプレビューまたはダウンロードします。このエンドポイントを使用すると、以前にファイルアップロード API でアップロードされたファイルにアクセスできます。 | |||||
| <i>ファイルは、リクエストしているアプリケーションのメッセージ範囲内にある場合のみアクセス可能です。</i> | |||||
| ### パスパラメータ | |||||
| - `file_id` (string) 必須 | |||||
| プレビューするファイルの一意識別子。ファイルアップロード API レスポンスから取得します。 | |||||
| ### クエリパラメータ | |||||
| - `as_attachment` (boolean) オプション | |||||
| ファイルを添付ファイルとして強制ダウンロードするかどうか。デフォルトは `false`(ブラウザでプレビュー)。 | |||||
| ### レスポンス | |||||
| ブラウザ表示またはダウンロード用の適切なヘッダー付きでファイル内容を返します。 | |||||
| - `Content-Type` ファイル MIME タイプに基づいて設定 | |||||
| - `Content-Length` ファイルサイズ(バイト、利用可能な場合) | |||||
| - `Content-Disposition` `as_attachment=true` の場合は "attachment" に設定 | |||||
| - `Cache-Control` パフォーマンス向上のためのキャッシュヘッダー | |||||
| - `Accept-Ranges` 音声/動画ファイルの場合は "bytes" に設定 | |||||
| ### エラー | |||||
| - 400, `invalid_param`, パラメータ入力異常 | |||||
| - 403, `file_access_denied`, ファイルアクセス拒否またはファイルが現在のアプリケーションに属していません | |||||
| - 404, `file_not_found`, ファイルが見つからないか削除されています | |||||
| - 500, サーバー内部エラー | |||||
| </Col> | |||||
| <Col sticky> | |||||
| ### リクエスト例 | |||||
| <CodeGroup | |||||
| title="Request" | |||||
| tag="GET" | |||||
| label="/files/:file_id/preview" | |||||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\ | |||||
| --header 'Authorization: Bearer {api_key}'`} | |||||
| /> | |||||
| ### 添付ファイルとしてダウンロード | |||||
| <CodeGroup | |||||
| title="Download Request" | |||||
| tag="GET" | |||||
| label="/files/:file_id/preview?as_attachment=true" | |||||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\ | |||||
| --header 'Authorization: Bearer {api_key}' \\ | |||||
| --output downloaded_file.png`} | |||||
| /> | |||||
| ### レスポンスヘッダー例 | |||||
| <CodeGroup title="Response Headers"> | |||||
| ```http {{ title: 'ヘッダー - 画像プレビュー' }} | |||||
| Content-Type: image/png | |||||
| Content-Length: 1024 | |||||
| Cache-Control: public, max-age=3600 | |||||
| ``` | |||||
| </CodeGroup> | |||||
| ### ダウンロードレスポンスヘッダー | |||||
| <CodeGroup title="Download Response Headers"> | |||||
| ```http {{ title: 'ヘッダー - ファイルダウンロード' }} | |||||
| Content-Type: image/png | |||||
| Content-Length: 1024 | |||||
| Content-Disposition: attachment; filename*=UTF-8''example.png | |||||
| Cache-Control: public, max-age=3600 | |||||
| ``` | |||||
| </CodeGroup> | |||||
| </Col> | |||||
| </Row> | |||||
| --- | |||||
| <Heading | <Heading | ||||
| url='/workflows/logs' | url='/workflows/logs' | ||||
| method='GET' | method='GET' |
| </Row> | </Row> | ||||
| --- | --- | ||||
| <Heading | |||||
| url='/files/:file_id/preview' | |||||
| method='GET' | |||||
| title='文件预览' | |||||
| name='#file-preview' | |||||
| /> | |||||
| <Row> | |||||
| <Col> | |||||
| 预览或下载已上传的文件。此端点允许您访问先前通过文件上传 API 上传的文件。 | |||||
| <i>文件只能在属于请求应用程序的消息范围内访问。</i> | |||||
| ### 路径参数 | |||||
| - `file_id` (string) 必需 | |||||
| 要预览的文件的唯一标识符,从文件上传 API 响应中获得。 | |||||
| ### 查询参数 | |||||
| - `as_attachment` (boolean) 可选 | |||||
| 是否强制将文件作为附件下载。默认为 `false`(在浏览器中预览)。 | |||||
| ### 响应 | |||||
| 返回带有适当浏览器显示或下载标头的文件内容。 | |||||
| - `Content-Type` 根据文件 MIME 类型设置 | |||||
| - `Content-Length` 文件大小(以字节为单位,如果可用) | |||||
| - `Content-Disposition` 如果 `as_attachment=true` 则设置为 "attachment" | |||||
| - `Cache-Control` 用于性能的缓存标头 | |||||
| - `Accept-Ranges` 对于音频/视频文件设置为 "bytes" | |||||
| ### 错误 | |||||
| - 400, `invalid_param`, 参数输入异常 | |||||
| - 403, `file_access_denied`, 文件访问被拒绝或文件不属于当前应用程序 | |||||
| - 404, `file_not_found`, 文件未找到或已被删除 | |||||
| - 500, 服务内部错误 | |||||
| </Col> | |||||
| <Col sticky> | |||||
| ### 请求示例 | |||||
| <CodeGroup | |||||
| title="Request" | |||||
| tag="GET" | |||||
| label="/files/:file_id/preview" | |||||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\ | |||||
| --header 'Authorization: Bearer {api_key}'`} | |||||
| /> | |||||
| ### 作为附件下载 | |||||
| <CodeGroup | |||||
| title="Request" | |||||
| tag="GET" | |||||
| label="/files/:file_id/preview?as_attachment=true" | |||||
| targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\ | |||||
| --header 'Authorization: Bearer {api_key}' \\ | |||||
| --output downloaded_file.png`} | |||||
| /> | |||||
| ### 响应标头示例 | |||||
| <CodeGroup title="Response Headers"> | |||||
| ```http {{ title: 'Headers - 图片预览' }} | |||||
| Content-Type: image/png | |||||
| Content-Length: 1024 | |||||
| Cache-Control: public, max-age=3600 | |||||
| ``` | |||||
| </CodeGroup> | |||||
| ### 文件下载响应标头 | |||||
| <CodeGroup title="Download Response Headers"> | |||||
| ```http {{ title: 'Headers - 文件下载' }} | |||||
| Content-Type: image/png | |||||
| Content-Length: 1024 | |||||
| Content-Disposition: attachment; filename*=UTF-8''example.png | |||||
| Cache-Control: public, max-age=3600 | |||||
| ``` | |||||
| </CodeGroup> | |||||
| </Col> | |||||
| </Row> | |||||
| --- | |||||
| <Heading | <Heading | ||||
| url='/workflows/logs' | url='/workflows/logs' | ||||
| method='GET' | method='GET' |
| 'use client' | 'use client' | ||||
| import type { ForwardRefRenderFunction } from 'react' | |||||
| import { useImperativeHandle } from 'react' | import { useImperativeHandle } from 'react' | ||||
| import React, { useCallback, useEffect, useMemo, useState } from 'react' | import React, { useCallback, useEffect, useMemo, useState } from 'react' | ||||
| import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types' | import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types' | ||||
| onDeSelectAll: () => void | onDeSelectAll: () => void | ||||
| onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void | onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void | ||||
| isFromMarketPlace?: boolean | isFromMarketPlace?: boolean | ||||
| ref?: React.Ref<ExposeRefs> | |||||
| } | } | ||||
| export type ExposeRefs = { | export type ExposeRefs = { | ||||
| deSelectAllPlugins: () => void | deSelectAllPlugins: () => void | ||||
| } | } | ||||
| const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({ | |||||
| const InstallByDSLList = ({ | |||||
| allPlugins, | allPlugins, | ||||
| selectedPlugins, | selectedPlugins, | ||||
| onSelect, | onSelect, | ||||
| onDeSelectAll, | onDeSelectAll, | ||||
| onLoadedAllPlugin, | onLoadedAllPlugin, | ||||
| isFromMarketPlace, | isFromMarketPlace, | ||||
| }, ref) => { | |||||
| ref, | |||||
| }: Props) => { | |||||
| const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) | const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) | ||||
| // DSL has id, to get plugin info to show more info | // DSL has id, to get plugin info to show more info | ||||
| const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => { | const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => { | ||||
| </> | </> | ||||
| ) | ) | ||||
| } | } | ||||
| export default React.forwardRef(InstallByDSLList) | |||||
| export default InstallByDSLList |
| }, [showSearchParams, handleActivePluginTypeChange]) | }, [showSearchParams, handleActivePluginTypeChange]) | ||||
| useEffect(() => { | useEffect(() => { | ||||
| window.addEventListener('popstate', () => { | |||||
| handlePopState() | |||||
| }) | |||||
| window.addEventListener('popstate', handlePopState) | |||||
| return () => { | return () => { | ||||
| window.removeEventListener('popstate', handlePopState) | window.removeEventListener('popstate', handlePopState) | ||||
| } | } |
| 'use client' | 'use client' | ||||
| import React, { forwardRef, useEffect, useImperativeHandle, useMemo, useRef } from 'react' | |||||
| import React, { useEffect, useImperativeHandle, useMemo, useRef } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll' | import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll' | ||||
| import Item from './item' | import Item from './item' | ||||
| tags: string[] | tags: string[] | ||||
| toolContentClassName?: string | toolContentClassName?: string | ||||
| disableMaxWidth?: boolean | disableMaxWidth?: boolean | ||||
| ref?: React.Ref<ListRef> | |||||
| } | } | ||||
| export type ListRef = { handleScroll: () => void } | export type ListRef = { handleScroll: () => void } | ||||
| const List = forwardRef<ListRef, ListProps>(({ | |||||
| const List = ({ | |||||
| wrapElemRef, | wrapElemRef, | ||||
| searchText, | searchText, | ||||
| tags, | tags, | ||||
| list, | list, | ||||
| toolContentClassName, | toolContentClassName, | ||||
| disableMaxWidth = false, | disableMaxWidth = false, | ||||
| }, ref) => { | |||||
| ref, | |||||
| }: ListProps) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const hasFilter = !searchText | const hasFilter = !searchText | ||||
| const hasRes = list.length > 0 | const hasRes = list.length > 0 | ||||
| </div> | </div> | ||||
| </> | </> | ||||
| ) | ) | ||||
| }) | |||||
| } | |||||
| List.displayName = 'List' | List.displayName = 'List' | ||||
| noLastRunFound: 'Kein vorheriger Lauf gefunden', | noLastRunFound: 'Kein vorheriger Lauf gefunden', | ||||
| lastOutput: 'Letzte Ausgabe', | lastOutput: 'Letzte Ausgabe', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Aktuelle gespeicherte Version exportieren', | |||||
| exportWarningDesc: 'Dies wird die derzeit gespeicherte Version Ihres Workflows exportieren. Wenn Sie ungespeicherte Änderungen im Editor haben, speichern Sie diese bitte zuerst, indem Sie die Exportoption im Workflow-Canvas verwenden.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| export: 'Export DSL with secret values ', | export: 'Export DSL with secret values ', | ||||
| }, | }, | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Export Current Saved Version', | |||||
| exportWarningDesc: 'This will export the current saved version of your workflow. If you have unsaved changes in the editor, please save them first by using the export option in the workflow canvas.', | |||||
| }, | |||||
| chatVariable: { | chatVariable: { | ||||
| panelTitle: 'Conversation Variables', | panelTitle: 'Conversation Variables', | ||||
| panelDescription: 'Conversation Variables are used to store interactive information that LLM needs to remember, including conversation history, uploaded files, user preferences. They are read-write. ', | panelDescription: 'Conversation Variables are used to store interactive information that LLM needs to remember, including conversation history, uploaded files, user preferences. They are read-write. ', |
| noMatchingInputsFound: 'No se encontraron entradas coincidentes de la última ejecución.', | noMatchingInputsFound: 'No se encontraron entradas coincidentes de la última ejecución.', | ||||
| lastOutput: 'Última salida', | lastOutput: 'Última salida', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Exportar la versión guardada actual', | |||||
| exportWarningDesc: 'Esto exportará la versión guardada actual de tu flujo de trabajo. Si tienes cambios no guardados en el editor, guárdalos primero utilizando la opción de exportar en el lienzo del flujo de trabajo.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'نتوانستم ورودیهای آخرین اجرای را کپی کنم', | copyLastRunError: 'نتوانستم ورودیهای آخرین اجرای را کپی کنم', | ||||
| lastOutput: 'آخرین خروجی', | lastOutput: 'آخرین خروجی', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'صادرات نسخه ذخیره شده فعلی', | |||||
| exportWarningDesc: 'این نسخه فعلی ذخیره شده از کار خود را صادر خواهد کرد. اگر تغییرات غیرذخیره شدهای در ویرایشگر دارید، لطفاً ابتدا از گزینه صادرات در بوم کار برای ذخیره آنها استفاده کنید.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'Échec de la copie des entrées de la dernière exécution', | copyLastRunError: 'Échec de la copie des entrées de la dernière exécution', | ||||
| lastOutput: 'Dernière sortie', | lastOutput: 'Dernière sortie', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Exporter la version enregistrée actuelle', | |||||
| exportWarningDesc: 'Cela exportera la version actuelle enregistrée de votre flux de travail. Si vous avez des modifications non enregistrées dans l\'éditeur, veuillez d\'abord les enregistrer en utilisant l\'option d\'exportation dans le canevas du flux de travail.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'अंतिम रन इनपुट को कॉपी करने में विफल', | copyLastRunError: 'अंतिम रन इनपुट को कॉपी करने में विफल', | ||||
| lastOutput: 'अंतिम आउटपुट', | lastOutput: 'अंतिम आउटपुट', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'वर्तमान सहेजी गई संस्करण निर्यात करें', | |||||
| exportWarningDesc: 'यह आपके कार्यप्रवाह का वर्तमान सहेजा हुआ संस्करण निर्यात करेगा। यदि आपके संपादक में कोई असहेजा किए गए परिवर्तन हैं, तो कृपया पहले उन्हें सहेजें, कार्यप्रवाह कैनवास में निर्यात विकल्प का उपयोग करके।', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| lastOutput: 'Keluaran Terakhir', | lastOutput: 'Keluaran Terakhir', | ||||
| noLastRunFound: 'Tidak ada eksekusi sebelumnya ditemukan', | noLastRunFound: 'Tidak ada eksekusi sebelumnya ditemukan', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Ekspor Versi Tersimpan Saat Ini', | |||||
| exportWarningDesc: 'Ini akan mengekspor versi terkini dari alur kerja Anda yang telah disimpan. Jika Anda memiliki perubahan yang belum disimpan di editor, harap simpan terlebih dahulu dengan menggunakan opsi ekspor di kanvas alur kerja.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| noLastRunFound: 'Nessuna esecuzione precedente trovata', | noLastRunFound: 'Nessuna esecuzione precedente trovata', | ||||
| lastOutput: 'Ultimo output', | lastOutput: 'Ultimo output', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Esporta la versione salvata corrente', | |||||
| exportWarningDesc: 'Questo esporterà l\'attuale versione salvata del tuo flusso di lavoro. Se hai modifiche non salvate nell\'editor, ti preghiamo di salvarle prima utilizzando l\'opzione di esportazione nel canvas del flusso di lavoro.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| export: 'シークレット値付きでエクスポート', | export: 'シークレット値付きでエクスポート', | ||||
| }, | }, | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: '現在保存されているバージョンをエクスポート', | |||||
| exportWarningDesc: 'これは現在保存されているワークフローのバージョンをエクスポートします。エディターで未保存の変更がある場合は、まずワークフローキャンバスのエクスポートオプションを使用して保存してください。', | |||||
| }, | |||||
| chatVariable: { | chatVariable: { | ||||
| panelTitle: '会話変数', | panelTitle: '会話変数', | ||||
| panelDescription: '対話情報を保存・管理(会話履歴/ファイル/ユーザー設定など)。書き換えができます。', | panelDescription: '対話情報を保存・管理(会話履歴/ファイル/ユーザー設定など)。書き換えができます。', |
| copyLastRunError: '마지막 실행 입력을 복사하는 데 실패했습니다.', | copyLastRunError: '마지막 실행 입력을 복사하는 데 실패했습니다.', | ||||
| lastOutput: '마지막 출력', | lastOutput: '마지막 출력', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: '현재 저장된 버전 내보내기', | |||||
| exportWarningDesc: '이 작업은 현재 저장된 워크플로우 버전을 내보냅니다. 편집기에서 저장되지 않은 변경 사항이 있는 경우, 먼저 워크플로우 캔버스의 내보내기 옵션을 사용하여 저장해 주세요.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'Nie udało się skopiować danych wejściowych z ostatniego uruchomienia', | copyLastRunError: 'Nie udało się skopiować danych wejściowych z ostatniego uruchomienia', | ||||
| lastOutput: 'Ostatni wynik', | lastOutput: 'Ostatni wynik', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Eksportuj obecną zapisaną wersję', | |||||
| exportWarningDesc: 'To wyeksportuje aktualnie zapisaną wersję twojego przepływu pracy. Jeśli masz niezapisane zmiany w edytorze, najpierw je zapisz, korzystając z opcji eksportu w kanwie przepływu pracy.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRun: 'Copiar Última Execução', | copyLastRun: 'Copiar Última Execução', | ||||
| lastOutput: 'Última Saída', | lastOutput: 'Última Saída', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Exportar a versão salva atual', | |||||
| exportWarningDesc: 'Isto irá exportar a versão atual salva do seu fluxo de trabalho. Se você tiver alterações não salvas no editor, por favor, salve-as primeiro utilizando a opção de exportação na tela do fluxo de trabalho.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'Nu s-au putut copia ultimele intrări de rulare', | copyLastRunError: 'Nu s-au putut copia ultimele intrări de rulare', | ||||
| lastOutput: 'Ultimul rezultat', | lastOutput: 'Ultimul rezultat', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Exportați versiunea salvată curentă', | |||||
| exportWarningDesc: 'Aceasta va exporta versiunea curent salvată a fluxului dumneavoastră de lucru. Dacă aveți modificări nesalvate în editor, vă rugăm să le salvați mai întâi utilizând opțiunea de export din canvasul fluxului de lucru.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| noMatchingInputsFound: 'Не найдено соответствующих входных данных из последнего запуска.', | noMatchingInputsFound: 'Не найдено соответствующих входных данных из последнего запуска.', | ||||
| lastOutput: 'Последний вывод', | lastOutput: 'Последний вывод', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Экспортировать текущую сохранённую версию', | |||||
| exportWarningDesc: 'Это экспортирует текущую сохранённую версию вашего рабочего процесса. Если у вас есть несохранённые изменения в редакторе, сначала сохраните их с помощью опции экспорта на полотне рабочего процесса.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| noMatchingInputsFound: 'Ni podatkov, ki bi ustrezali prejšnjemu zagonu', | noMatchingInputsFound: 'Ni podatkov, ki bi ustrezali prejšnjemu zagonu', | ||||
| lastOutput: 'Nazadnje izhod', | lastOutput: 'Nazadnje izhod', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Izvozi trenutno shranjeno različico', | |||||
| exportWarningDesc: 'To bo izvozilo trenutno shranjeno različico vašega delovnega toka. Če imate neshranjene spremembe v urejevalniku, jih najprej shranite z uporabo možnosti izvoza na platnu delovnega toka.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| noMatchingInputsFound: 'ไม่พบข้อมูลที่ตรงกันจากการรันครั้งล่าสุด', | noMatchingInputsFound: 'ไม่พบข้อมูลที่ตรงกันจากการรันครั้งล่าสุด', | ||||
| lastOutput: 'ผลลัพธ์สุดท้าย', | lastOutput: 'ผลลัพธ์สุดท้าย', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'ส่งออกเวอร์ชันที่บันทึกปัจจุบัน', | |||||
| exportWarningDesc: 'นี่จะส่งออกเวอร์ชันที่บันทึกไว้ปัจจุบันของเวิร์กโฟลว์ของคุณ หากคุณมีการเปลี่ยนแปลงที่ยังไม่ได้บันทึกในตัวแก้ไข กรุณาบันทึกมันก่อนโดยใช้ตัวเลือกส่งออกในผืนผ้าใบเวิร์กโฟลว์', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'Son çalışma girdilerini kopyalamak başarısız oldu.', | copyLastRunError: 'Son çalışma girdilerini kopyalamak başarısız oldu.', | ||||
| lastOutput: 'Son Çıktı', | lastOutput: 'Son Çıktı', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Mevcut Kaydedilmiş Versiyonu Dışa Aktar', | |||||
| exportWarningDesc: 'Bu, çalışma akışınızın mevcut kaydedilmiş sürümünü dışa aktaracaktır. Editörde kaydedilmemiş değişiklikleriniz varsa, lütfen önce bunları çalışma akışı alanındaki dışa aktarma seçeneğini kullanarak kaydedin.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| noMatchingInputsFound: 'Не знайдено відповідних вхідних даних з останнього запуску', | noMatchingInputsFound: 'Не знайдено відповідних вхідних даних з останнього запуску', | ||||
| lastOutput: 'Останній вихід', | lastOutput: 'Останній вихід', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Експортувати поточну збережену версію', | |||||
| exportWarningDesc: 'Це експортує поточну збережену версію вашого робочого процесу. Якщо у вас є незбережені зміни в редакторі, будь ласка, спочатку збережіть їх, використовуючи опцію експорту на полотні робочого процесу.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| copyLastRunError: 'Không thể sao chép đầu vào của lần chạy trước', | copyLastRunError: 'Không thể sao chép đầu vào của lần chạy trước', | ||||
| lastOutput: 'Đầu ra cuối cùng', | lastOutput: 'Đầu ra cuối cùng', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: 'Xuất Phiên Bản Đã Lưu Hiện Tại', | |||||
| exportWarningDesc: 'Điều này sẽ xuất phiên bản hiện tại đã được lưu của quy trình làm việc của bạn. Nếu bạn có những thay đổi chưa được lưu trong trình soạn thảo, vui lòng lưu chúng trước bằng cách sử dụng tùy chọn xuất trong bản vẽ quy trình.', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| export: '导出包含 Secret 值的 DSL', | export: '导出包含 Secret 值的 DSL', | ||||
| }, | }, | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: '导出当前已保存版本', | |||||
| exportWarningDesc: '这将导出您工作流的当前已保存版本。如果您在编辑器中有未保存的更改,请先使用工作流画布中的导出选项保存它们。', | |||||
| }, | |||||
| chatVariable: { | chatVariable: { | ||||
| panelTitle: '会话变量', | panelTitle: '会话变量', | ||||
| panelDescription: '会话变量用于存储 LLM 需要的上下文信息,如用户偏好、对话历史等。它是可读写的。', | panelDescription: '会话变量用于存储 LLM 需要的上下文信息,如用户偏好、对话历史等。它是可读写的。', |
| noLastRunFound: '沒有找到之前的運行', | noLastRunFound: '沒有找到之前的運行', | ||||
| lastOutput: '最後的輸出', | lastOutput: '最後的輸出', | ||||
| }, | }, | ||||
| sidebar: { | |||||
| exportWarning: '導出當前保存的版本', | |||||
| exportWarningDesc: '這將導出當前保存的工作流程版本。如果您在編輯器中有未保存的更改,請先通過使用工作流程畫布中的導出選項來保存它們。', | |||||
| }, | |||||
| } | } | ||||
| export default translation | export default translation |
| (()=>{"use strict";self.fallback=async e=>"document"===e.destination?caches.match("/_offline.html",{ignoreSearch:!0}):Response.error()})(); |
| import { DataType } from '@/app/components/datasets/metadata/types' | |||||
| import { act, renderHook } from '@testing-library/react' | |||||
| import { QueryClient, QueryClientProvider } from '@tanstack/react-query' | |||||
| import { useBatchUpdateDocMetadata } from '@/service/knowledge/use-metadata' | |||||
| import { useDocumentListKey } from './use-document' | |||||
| // Mock the post function to avoid real network requests | |||||
| jest.mock('@/service/base', () => ({ | |||||
| post: jest.fn().mockResolvedValue({ success: true }), | |||||
| })) | |||||
| const NAME_SPACE = 'dataset-metadata' | |||||
| describe('useBatchUpdateDocMetadata', () => { | |||||
| let queryClient: QueryClient | |||||
| beforeEach(() => { | |||||
| // Create a fresh QueryClient before each test | |||||
| queryClient = new QueryClient() | |||||
| }) | |||||
| // Wrapper for React Query context | |||||
| const wrapper = ({ children }: { children: React.ReactNode }) => ( | |||||
| <QueryClientProvider client={queryClient}>{children}</QueryClientProvider> | |||||
| ) | |||||
| it('should correctly invalidate dataset and document caches', async () => { | |||||
| const { result } = renderHook(() => useBatchUpdateDocMetadata(), { wrapper }) | |||||
| // Spy on queryClient.invalidateQueries | |||||
| const invalidateSpy = jest.spyOn(queryClient, 'invalidateQueries') | |||||
| // Correct payload type: each document has its own metadata_list array | |||||
| const payload = { | |||||
| dataset_id: 'dataset-1', | |||||
| metadata_list: [ | |||||
| { | |||||
| document_id: 'doc-1', | |||||
| metadata_list: [ | |||||
| { key: 'title-1', id: '01', name: 'name-1', type: DataType.string, value: 'new title 01' }, | |||||
| ], | |||||
| }, | |||||
| { | |||||
| document_id: 'doc-2', | |||||
| metadata_list: [ | |||||
| { key: 'title-2', id: '02', name: 'name-1', type: DataType.string, value: 'new title 02' }, | |||||
| ], | |||||
| }, | |||||
| ], | |||||
| } | |||||
| // Execute the mutation | |||||
| await act(async () => { | |||||
| await result.current.mutateAsync(payload) | |||||
| }) | |||||
| // Expect invalidateQueries to have been called exactly 5 times | |||||
| expect(invalidateSpy).toHaveBeenCalledTimes(5) | |||||
| // Dataset cache invalidation | |||||
| expect(invalidateSpy).toHaveBeenNthCalledWith(1, { | |||||
| queryKey: [NAME_SPACE, 'dataset', 'dataset-1'], | |||||
| }) | |||||
| // Document list cache invalidation | |||||
| expect(invalidateSpy).toHaveBeenNthCalledWith(2, { | |||||
| queryKey: [NAME_SPACE, 'document', 'dataset-1'], | |||||
| }) | |||||
| // useDocumentListKey cache invalidation | |||||
| expect(invalidateSpy).toHaveBeenNthCalledWith(3, { | |||||
| queryKey: [...useDocumentListKey, 'dataset-1'], | |||||
| }) | |||||
| // Single document cache invalidation | |||||
| expect(invalidateSpy.mock.calls.slice(3)).toEqual( | |||||
| expect.arrayContaining([ | |||||
| [{ queryKey: [NAME_SPACE, 'document', 'dataset-1', 'doc-1'] }], | |||||
| [{ queryKey: [NAME_SPACE, 'document', 'dataset-1', 'doc-2'] }], | |||||
| ]), | |||||
| ) | |||||
| }) | |||||
| }) |
| }) | }) | ||||
| // meta data in document list | // meta data in document list | ||||
| await queryClient.invalidateQueries({ | await queryClient.invalidateQueries({ | ||||
| queryKey: [NAME_SPACE, 'dataset', payload.dataset_id], | |||||
| queryKey: [NAME_SPACE, 'document', payload.dataset_id], | |||||
| }) | }) | ||||
| await queryClient.invalidateQueries({ | await queryClient.invalidateQueries({ | ||||
| queryKey: [...useDocumentListKey, payload.dataset_id], | queryKey: [...useDocumentListKey, payload.dataset_id], |