Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -54,7 +54,6 @@ def initialize_extensions(app: DifyApp):
         ext_otel,
         ext_proxy_fix,
         ext_redis,
-        ext_repositories,
         ext_sentry,
         ext_set_secretkey,
         ext_storage,
@@ -75,7 +74,6 @@ def initialize_extensions(app: DifyApp):
         ext_migrate,
         ext_redis,
         ext_storage,
-        ext_repositories,
         ext_celery,
         ext_login,
         ext_mail,
@@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
-from core.workflow.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
@@ -163,12 +163,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": application_generate_entity.app_config.tenant_id,
-                "app_id": application_generate_entity.app_config.app_id,
-                "session_factory": session_factory,
-            }
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            tenant_id=application_generate_entity.app_config.tenant_id,
+            app_id=application_generate_entity.app_config.app_id,
         )
         return self._generate(
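The same substitution repeats at every generator and service call site below. As a minimal sketch of the pattern (tenant and app IDs are placeholders; the old `RepositoryFactory` API is the one deleted later in this patch):

```python
from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db

session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

# Old pattern (removed in this diff): indirection through a registered factory
# function, with parameters passed as an untyped mapping.
# repository = RepositoryFactory.create_workflow_node_execution_repository(
#     params={"tenant_id": "tenant-id", "app_id": "app-id", "session_factory": session_factory}
# )

# New pattern: construct the concrete implementation directly with typed keyword arguments.
repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=session_factory,
    tenant_id="tenant-id",  # placeholder
    app_id="app-id",  # placeholder
)
```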
@@ -231,12 +229,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": application_generate_entity.app_config.tenant_id,
-                "app_id": application_generate_entity.app_config.app_id,
-                "session_factory": session_factory,
-            }
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            tenant_id=application_generate_entity.app_config.tenant_id,
+            app_id=application_generate_entity.app_config.app_id,
         )
         return self._generate(
@@ -297,12 +293,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": application_generate_entity.app_config.tenant_id,
-                "app_id": application_generate_entity.app_config.app_id,
-                "session_factory": session_factory,
-            }
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            tenant_id=application_generate_entity.app_config.tenant_id,
+            app_id=application_generate_entity.app_config.app_id,
        )
         return self._generate(
@@ -9,7 +9,6 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
-from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
@@ -58,7 +57,7 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
-from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -66,6 +65,7 @@ from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes import NodeType
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_cycle_manager import WorkflowCycleManager
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models import Conversation, EndUser, Message, MessageFile
@@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")
-        self._workflow_cycle_manager = WorkflowCycleManage(
+        self._workflow_cycle_manager = WorkflowCycleManager(
             application_generate_entity=application_generate_entity,
             workflow_system_variables={
                 SystemVariableKey.QUERY: message.query,
@@ -18,13 +18,13 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
 from core.app.apps.workflow.app_runner import WorkflowAppRunner
 from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
-from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from extensions.ext_database import db
 from factories import file_factory
 from models import Account, App, EndUser, Workflow
@@ -138,12 +138,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": application_generate_entity.app_config.tenant_id,
-                "app_id": application_generate_entity.app_config.app_id,
-                "session_factory": session_factory,
-            }
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            tenant_id=application_generate_entity.app_config.tenant_id,
+            app_id=application_generate_entity.app_config.app_id,
         )
         return self._generate(
@@ -264,12 +262,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": application_generate_entity.app_config.tenant_id,
-                "app_id": application_generate_entity.app_config.app_id,
-                "session_factory": session_factory,
-            }
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            tenant_id=application_generate_entity.app_config.tenant_id,
+            app_id=application_generate_entity.app_config.app_id,
         )
         return self._generate(
@@ -329,12 +325,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create workflow node execution repository
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": application_generate_entity.app_config.tenant_id,
-                "app_id": application_generate_entity.app_config.app_id,
-                "session_factory": session_factory,
-            }
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory,
+            tenant_id=application_generate_entity.app_config.tenant_id,
+            app_id=application_generate_entity.app_config.app_id,
         )
         return self._generate(
@@ -9,7 +9,6 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
-from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
     AgentChatAppGenerateEntity,
@@ -45,6 +44,7 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
@@ -0,0 +1 @@
+# Core base package
@@ -0,0 +1,6 @@
+from core.base.tts.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
+
+__all__ = [
+    "AppGeneratorTTSPublisher",
+    "AudioTrunk",
+]
@@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
     UnitEnum,
 )
 from core.ops.utils import filter_none_values
-from core.workflow.repository.repository_factory import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from models.model import EndUser
@@ -113,8 +113,8 @@ class LangFuseDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory, tenant_id=trace_info.tenant_id
         )
         # Get all executions for this workflow run
@@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
     LangSmithRunUpdateModel,
 )
 from core.ops.utils import filter_none_values, generate_dotted_order
-from core.workflow.repository.repository_factory import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
@@ -137,12 +137,8 @@ class LangSmithDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": trace_info.tenant_id,
-                "app_id": trace_info.metadata.get("app_id"),
-                "session_factory": session_factory,
-            },
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
         )
         # Get all executions for this workflow run
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
-from core.workflow.repository.repository_factory import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
@@ -150,12 +150,8 @@ class OpikDataTrace(BaseTraceInstance):
         # through workflow_run_id get all_nodes_execution using repository
         session_factory = sessionmaker(bind=db.engine)
-        workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": trace_info.tenant_id,
-                "app_id": trace_info.metadata.get("app_id"),
-                "session_factory": session_factory,
-            },
+        workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
         )
         # Get all executions for this workflow run
@@ -4,3 +4,9 @@ Repository implementations for data access.
 This package contains concrete implementations of the repository interfaces
 defined in the core.workflow.repository package.
 """
+
+from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
+
+__all__ = [
+    "SQLAlchemyWorkflowNodeExecutionRepository",
+]
@@ -1,87 +0,0 @@
-"""
-Registry for repository implementations.
-This module is responsible for registering factory functions with the repository factory.
-"""
-import logging
-from collections.abc import Mapping
-from typing import Any
-from sqlalchemy.orm import sessionmaker
-from configs import dify_config
-from core.repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
-from core.workflow.repository.repository_factory import RepositoryFactory
-from extensions.ext_database import db
-logger = logging.getLogger(__name__)
-# Storage type constants
-STORAGE_TYPE_RDBMS = "rdbms"
-STORAGE_TYPE_HYBRID = "hybrid"
-def register_repositories() -> None:
-    """
-    Register repository factory functions with the RepositoryFactory.
-    This function reads configuration settings to determine which repository
-    implementations to register.
-    """
-    # Configure WorkflowNodeExecutionRepository factory based on configuration
-    workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
-    # Check storage type and register appropriate implementation
-    if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
-        # Register SQLAlchemy implementation for RDBMS storage
-        logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
-        RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
-    elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
-        # Hybrid storage is not yet implemented
-        raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
-    else:
-        # Unknown storage type
-        raise ValueError(
-            f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
-            f"Supported types: {STORAGE_TYPE_RDBMS}"
-        )
-def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
-    """
-    Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
-    This factory function creates a repository for the RDBMS storage type.
-    Args:
-        params: Parameters for creating the repository, including:
-            - tenant_id: Required. The tenant ID for multi-tenancy.
-            - app_id: Optional. The application ID for filtering.
-            - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
-              a new sessionmaker will be created using the global database engine.
-    Returns:
-        A WorkflowNodeExecutionRepository instance
-    Raises:
-        ValueError: If required parameters are missing
-    """
-    # Extract required parameters
-    tenant_id = params.get("tenant_id")
-    if tenant_id is None:
-        raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
-    # Extract optional parameters
-    app_id = params.get("app_id")
-    # Use the session_factory from params if provided, otherwise create one using the global db engine
-    session_factory = params.get("session_factory")
-    if session_factory is None:
-        # Create a sessionmaker using the same engine as the global db session
-        session_factory = sessionmaker(bind=db.engine)
-    # Create and return the repository
-    return SQLAlchemyWorkflowNodeExecutionRepository(
-        session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
-    )
@@ -10,13 +10,13 @@ from sqlalchemy import UnaryExpression, asc, delete, desc, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
-from core.workflow.repository.workflow_node_execution_repository import OrderConfig
+from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
 from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
 logger = logging.getLogger(__name__)
-class SQLAlchemyWorkflowNodeExecutionRepository:
+class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
     """
     SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
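With the explicit base class in place, call sites can be annotated against the interface and accept the SQLAlchemy implementation without the `cast` the old factory performed. A minimal sketch under that reading (the `clear()` method name is taken from its use later in this diff; IDs are placeholders):

```python
from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db


def purge_node_executions(repository: WorkflowNodeExecutionRepository) -> None:
    # Any implementation of the interface works here; no isinstance checks or casts needed.
    repository.clear()


purge_node_executions(
    SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=sessionmaker(bind=db.engine),
        tenant_id="tenant-id",  # placeholder
        app_id="app-id",  # placeholder
    )
)
```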
@@ -1,9 +0,0 @@
-"""
-WorkflowNodeExecution repository implementations.
-"""
-
-from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
-
-__all__ = [
-    "SQLAlchemyWorkflowNodeExecutionRepository",
-]
@@ -6,10 +6,9 @@ for accessing and manipulating data, regardless of the underlying
 storage mechanism.
 """
-from core.workflow.repository.repository_factory import RepositoryFactory
-from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
 __all__ = [
-    "RepositoryFactory",
+    "OrderConfig",
     "WorkflowNodeExecutionRepository",
 ]
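Taken together, the package exports updated above give call sites everything the old factory used to hide. A sketch of the resulting usage, assuming `OrderConfig`'s `order_by`/`order_direction` fields and the `get_by_workflow_run` signature match how the service-layer call sites in this codebase use them (IDs are placeholders):

```python
from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository import OrderConfig
from extensions.ext_database import db

repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=sessionmaker(bind=db.engine),
    tenant_id="tenant-id",  # placeholder
    app_id="app-id",  # placeholder
)

# Fetch the node executions of one workflow run, newest first.
# (OrderConfig field names are an assumption based on usage elsewhere in the codebase.)
order_config = OrderConfig(order_by=["index"], order_direction="desc")
executions = repository.get_by_workflow_run(
    workflow_run_id="workflow-run-id",  # placeholder
    order_config=order_config,
)
```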
@@ -1,97 +0,0 @@
-"""
-Repository factory for creating repository instances.
-This module provides a simple factory interface for creating repository instances.
-It does not contain any implementation details or dependencies on specific repositories.
-"""
-from collections.abc import Callable, Mapping
-from typing import Any, Literal, Optional, cast
-from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-# Type for factory functions - takes a dict of parameters and returns any repository type
-RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
-# Type for workflow node execution factory function
-WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
-# Repository type literals
-_RepositoryType = Literal["workflow_node_execution"]
-class RepositoryFactory:
-    """
-    Factory class for creating repository instances.
-    This factory delegates the actual repository creation to implementation-specific
-    factory functions that are registered with the factory at runtime.
-    """
-    # Dictionary to store factory functions
-    _factory_functions: dict[str, RepositoryFactoryFunc] = {}
-    @classmethod
-    def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
-        """
-        Register a factory function for a specific repository type.
-        This is a private method and should not be called directly.
-        Args:
-            repository_type: The type of repository (e.g., 'workflow_node_execution')
-            factory_func: A function that takes parameters and returns a repository instance
-        """
-        cls._factory_functions[repository_type] = factory_func
-    @classmethod
-    def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
-        """
-        Create a new repository instance with the provided parameters.
-        This is a private method and should not be called directly.
-        Args:
-            repository_type: The type of repository to create
-            params: A dictionary of parameters to pass to the factory function
-        Returns:
-            A new instance of the requested repository
-        Raises:
-            ValueError: If no factory function is registered for the repository type
-        """
-        if repository_type not in cls._factory_functions:
-            raise ValueError(f"No factory function registered for repository type '{repository_type}'")
-        # Use empty dict if params is None
-        params = params or {}
-        return cls._factory_functions[repository_type](params)
-    @classmethod
-    def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
-        """
-        Register a factory function for the workflow node execution repository.
-        Args:
-            factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
-        """
-        cls._register_factory("workflow_node_execution", factory_func)
-    @classmethod
-    def create_workflow_node_execution_repository(
-        cls, params: Optional[Mapping[str, Any]] = None
-    ) -> WorkflowNodeExecutionRepository:
-        """
-        Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
-        Args:
-            params: A dictionary of parameters to pass to the factory function
-        Returns:
-            A new instance of the WorkflowNodeExecutionRepository
-        Raises:
-            ValueError: If no factory function is registered for the workflow_node_execution repository type
-        """
-        # We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
-        return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
@@ -6,7 +6,6 @@ from typing import Optional, Union
 from sqlalchemy.orm import Session
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
-from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import (
     InvokeFrom,
@@ -52,10 +51,11 @@ from core.app.entities.task_entities import (
     WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
-from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import SystemVariableKey
 from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_cycle_manager import WorkflowCycleManager
 from extensions.ext_database import db
 from models.account import Account
 from models.enums import CreatedByRole
@@ -102,7 +102,7 @@ class WorkflowAppGenerateTaskPipeline:
         else:
             raise ValueError(f"Invalid user type: {type(user)}")
-        self._workflow_cycle_manager = WorkflowCycleManage(
+        self._workflow_cycle_manager = WorkflowCycleManager(
             application_generate_entity=application_generate_entity,
             workflow_system_variables={
                 SystemVariableKey.FILES: application_generate_entity.files,
@@ -69,7 +69,7 @@ from models.workflow import (
 )
-class WorkflowCycleManage:
+class WorkflowCycleManager:
     def __init__(
         self,
         *,
@@ -1,18 +0,0 @@
-"""
-Extension for initializing repositories.
-This extension registers repository implementations with the RepositoryFactory.
-"""
-from core.repositories.repository_registry import register_repositories
-from dify_app import DifyApp
-def init_app(_app: DifyApp) -> None:
-    """
-    Initialize repository implementations.
-    Args:
-        _app: The Flask application instance (unused)
-    """
-    register_repositories()
@@ -2,7 +2,7 @@ import threading
 from typing import Optional
 import contexts
-from core.workflow.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import OrderConfig
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -129,12 +129,8 @@ class WorkflowRunService:
             return []
         # Use the repository to get the node executions
-        repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": app_model.tenant_id,
-                "app_id": app_model.id,
-                "session_factory": db.session.get_bind(),
-            }
+        repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
         )
         # Use the repository to get the node executions with ordering
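Worth noting: unlike the generator call sites earlier in this diff, this service (and the app-deletion task further down) passes `db.engine` rather than a `sessionmaker`. The repository module imports both `Engine` and `sessionmaker`, which suggests its constructor accepts either form; a sketch under that assumption, with IDs as placeholders:

```python
from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db

# Passing a sessionmaker, as the app generators do.
repo_from_factory = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=sessionmaker(bind=db.engine, expire_on_commit=False),
    tenant_id="tenant-id",  # placeholder
    app_id="app-id",  # placeholder
)

# Passing the Engine directly, as this service does; the repository is assumed
# to build its own sessionmaker from it internally.
repo_from_engine = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=db.engine,
    tenant_id="tenant-id",  # placeholder
    app_id="app-id",  # placeholder
)
```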
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.errors import WorkflowNodeRunFailedError
@@ -21,7 +22,6 @@ from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.nodes.event.types import NodeEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-from core.workflow.repository import RepositoryFactory
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
 from extensions.ext_database import db
@@ -285,12 +285,8 @@ class WorkflowService:
         workflow_node_execution.workflow_id = draft_workflow.id
         # Use the repository to save the workflow node execution
-        repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": app_model.tenant_id,
-                "app_id": app_model.id,
-                "session_factory": db.session.get_bind(),
-            }
+        repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
         )
         repository.save(workflow_node_execution)
@@ -7,7 +7,7 @@ from celery import shared_task  # type: ignore
 from sqlalchemy import delete
 from sqlalchemy.exc import SQLAlchemyError
-from core.workflow.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from models.dataset import AppDatasetJoin
 from models.model import (
@@ -189,12 +189,8 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
 def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
     # Create a repository instance for WorkflowNodeExecution
-    repository = RepositoryFactory.create_workflow_node_execution_repository(
-        params={
-            "tenant_id": tenant_id,
-            "app_id": app_id,
-            "session_factory": db.session.get_bind(),
-        }
+    repository = SQLAlchemyWorkflowNodeExecutionRepository(
+        session_factory=db.engine, tenant_id=tenant_id, app_id=app_id
     )
     # Use the clear method to delete all records for this tenant_id and app_id
@@ -0,0 +1,348 @@
+import json
+import time
+from datetime import UTC, datetime
+from unittest.mock import MagicMock, patch
+import pytest
+from sqlalchemy.orm import Session
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
+from core.app.entities.queue_entities import (
+    QueueNodeFailedEvent,
+    QueueNodeStartedEvent,
+    QueueNodeSucceededEvent,
+)
+from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes import NodeType
+from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_cycle_manager import WorkflowCycleManager
+from models.enums import CreatedByRole
+from models.workflow import (
+    Workflow,
+    WorkflowNodeExecution,
+    WorkflowNodeExecutionStatus,
+    WorkflowRun,
+    WorkflowRunStatus,
+)
+@pytest.fixture
+def mock_app_generate_entity():
+    entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
+    entity.inputs = {"query": "test query"}
+    entity.invoke_from = InvokeFrom.WEB_APP
+    # Create app_config as a separate mock
+    app_config = MagicMock()
+    app_config.tenant_id = "test-tenant-id"
+    app_config.app_id = "test-app-id"
+    entity.app_config = app_config
+    return entity
+@pytest.fixture
+def mock_workflow_system_variables():
+    return {
+        SystemVariableKey.QUERY: "test query",
+        SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
+        SystemVariableKey.USER_ID: "test-user-id",
+        SystemVariableKey.APP_ID: "test-app-id",
+        SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
+        SystemVariableKey.WORKFLOW_RUN_ID: "test-workflow-run-id",
+    }
+@pytest.fixture
+def mock_node_execution_repository():
+    repo = MagicMock(spec=WorkflowNodeExecutionRepository)
+    repo.get_by_node_execution_id.return_value = None
+    repo.get_running_executions.return_value = []
+    return repo
+@pytest.fixture
+def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository):
+    return WorkflowCycleManager(
+        application_generate_entity=mock_app_generate_entity,
+        workflow_system_variables=mock_workflow_system_variables,
+        workflow_node_execution_repository=mock_node_execution_repository,
+    )
+@pytest.fixture
+def mock_session():
+    session = MagicMock(spec=Session)
+    return session
+@pytest.fixture
+def mock_workflow():
+    workflow = MagicMock(spec=Workflow)
+    workflow.id = "test-workflow-id"
+    workflow.tenant_id = "test-tenant-id"
+    workflow.app_id = "test-app-id"
+    workflow.type = "chat"
+    workflow.version = "1.0"
+    workflow.graph = json.dumps({"nodes": [], "edges": []})
+    return workflow
+@pytest.fixture
+def mock_workflow_run():
+    workflow_run = MagicMock(spec=WorkflowRun)
+    workflow_run.id = "test-workflow-run-id"
+    workflow_run.tenant_id = "test-tenant-id"
+    workflow_run.app_id = "test-app-id"
+    workflow_run.workflow_id = "test-workflow-id"
+    workflow_run.status = WorkflowRunStatus.RUNNING
+    workflow_run.created_by_role = CreatedByRole.ACCOUNT
+    workflow_run.created_by = "test-user-id"
+    workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+    workflow_run.inputs_dict = {"query": "test query"}
+    workflow_run.outputs_dict = {"answer": "test answer"}
+    return workflow_run
+def test_init(
+    workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository
+):
+    """Test initialization of WorkflowCycleManager"""
+    assert workflow_cycle_manager._workflow_run is None
+    assert workflow_cycle_manager._workflow_node_executions == {}
+    assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
+    assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
+    assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
+def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow):
+    """Test _handle_workflow_run_start method"""
+    # Mock session.scalar to return the workflow and max sequence
+    mock_session.scalar.side_effect = [mock_workflow, 5]
+    # Call the method
+    workflow_run = workflow_cycle_manager._handle_workflow_run_start(
+        session=mock_session,
+        workflow_id="test-workflow-id",
+        user_id="test-user-id",
+        created_by_role=CreatedByRole.ACCOUNT,
+    )
+    # Verify the result
+    assert workflow_run.tenant_id == mock_workflow.tenant_id
+    assert workflow_run.app_id == mock_workflow.app_id
+    assert workflow_run.workflow_id == mock_workflow.id
+    assert workflow_run.sequence_number == 6  # max_sequence + 1
+    assert workflow_run.status == WorkflowRunStatus.RUNNING
+    assert workflow_run.created_by_role == CreatedByRole.ACCOUNT
+    assert workflow_run.created_by == "test-user-id"
+    # Verify session.add was called
+    mock_session.add.assert_called_once_with(workflow_run)
+def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run):
+    """Test _handle_workflow_run_success method"""
+    # Mock _get_workflow_run to return the mock_workflow_run
+    with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
+        # Call the method
+        result = workflow_cycle_manager._handle_workflow_run_success(
+            session=mock_session,
+            workflow_run_id="test-workflow-run-id",
+            start_at=time.perf_counter() - 10,  # 10 seconds ago
+            total_tokens=100,
+            total_steps=5,
+            outputs={"answer": "test answer"},
+        )
+        # Verify the result
+        assert result == mock_workflow_run
+        assert result.status == WorkflowRunStatus.SUCCEEDED
+        assert result.outputs == json.dumps({"answer": "test answer"})
+        assert result.total_tokens == 100
+        assert result.total_steps == 5
+        assert result.finished_at is not None
+def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run):
+    """Test _handle_workflow_run_failed method"""
+    # Mock _get_workflow_run to return the mock_workflow_run
+    with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
+        # Mock get_running_executions to return an empty list
+        workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
+        # Call the method
+        result = workflow_cycle_manager._handle_workflow_run_failed(
+            session=mock_session,
+            workflow_run_id="test-workflow-run-id",
+            start_at=time.perf_counter() - 10,  # 10 seconds ago
+            total_tokens=50,
+            total_steps=3,
+            status=WorkflowRunStatus.FAILED,
+            error="Test error message",
+        )
+        # Verify the result
+        assert result == mock_workflow_run
+        assert result.status == WorkflowRunStatus.FAILED.value
+        assert result.error == "Test error message"
+        assert result.total_tokens == 50
+        assert result.total_steps == 3
+        assert result.finished_at is not None
+def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
+    """Test _handle_node_execution_start method"""
+    # Create a mock event
+    event = MagicMock(spec=QueueNodeStartedEvent)
+    event.node_execution_id = "test-node-execution-id"
+    event.node_id = "test-node-id"
+    event.node_type = NodeType.LLM
+    # Create node_data as a separate mock
+    node_data = MagicMock()
+    node_data.title = "Test Node"
+    event.node_data = node_data
+    event.predecessor_node_id = "test-predecessor-node-id"
+    event.node_run_index = 1
+    event.parallel_mode_run_id = "test-parallel-mode-run-id"
+    event.in_iteration_id = "test-iteration-id"
+    event.in_loop_id = "test-loop-id"
+    # Call the method
+    result = workflow_cycle_manager._handle_node_execution_start(
+        workflow_run=mock_workflow_run,
+        event=event,
+    )
+    # Verify the result
+    assert result.tenant_id == mock_workflow_run.tenant_id
+    assert result.app_id == mock_workflow_run.app_id
+    assert result.workflow_id == mock_workflow_run.workflow_id
+    assert result.workflow_run_id == mock_workflow_run.id
+    assert result.node_execution_id == event.node_execution_id
+    assert result.node_id == event.node_id
+    assert result.node_type == event.node_type.value
+    assert result.title == event.node_data.title
+    assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
+    assert result.created_by_role == mock_workflow_run.created_by_role
+    assert result.created_by == mock_workflow_run.created_by
+    # Verify save was called
+    workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
+    # Verify the node execution was added to the cache
+    assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result
+def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
+    """Test _get_workflow_run method"""
+    # Mock session.scalar to return the workflow run
+    mock_session.scalar.return_value = mock_workflow_run
+    # Call the method
+    result = workflow_cycle_manager._get_workflow_run(
+        session=mock_session,
+        workflow_run_id="test-workflow-run-id",
+    )
+    # Verify the result
+    assert result == mock_workflow_run
+    assert workflow_cycle_manager._workflow_run == mock_workflow_run
+def test_handle_workflow_node_execution_success(workflow_cycle_manager):
+    """Test _handle_workflow_node_execution_success method"""
+    # Create a mock event
+    event = MagicMock(spec=QueueNodeSucceededEvent)
+    event.node_execution_id = "test-node-execution-id"
+    event.inputs = {"input": "test input"}
+    event.process_data = {"process": "test process"}
+    event.outputs = {"output": "test output"}
+    event.execution_metadata = {"metadata": "test metadata"}
+    event.start_at = datetime.now(UTC).replace(tzinfo=None)
+    # Create a mock workflow node execution
+    node_execution = MagicMock(spec=WorkflowNodeExecution)
+    node_execution.node_execution_id = "test-node-execution-id"
+    # Mock _get_workflow_node_execution to return the mock node execution
+    with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
+        # Call the method
+        result = workflow_cycle_manager._handle_workflow_node_execution_success(
+            event=event,
+        )
+        # Verify the result
+        assert result == node_execution
+        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
+        assert result.inputs == json.dumps(event.inputs)
+        assert result.process_data == json.dumps(event.process_data)
+        assert result.outputs == json.dumps(event.outputs)
+        assert result.finished_at is not None
+        assert result.elapsed_time is not None
+        # Verify update was called
+        workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
+def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
+    """Test _handle_workflow_run_partial_success method"""
+    # Mock _get_workflow_run to return the mock_workflow_run
+    with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
+        # Call the method
+        result = workflow_cycle_manager._handle_workflow_run_partial_success(
+            session=mock_session,
+            workflow_run_id="test-workflow-run-id",
+            start_at=time.perf_counter() - 10,  # 10 seconds ago
+            total_tokens=75,
+            total_steps=4,
+            outputs={"partial_answer": "test partial answer"},
+            exceptions_count=2,
+        )
+        # Verify the result
+        assert result == mock_workflow_run
+        assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value
+        assert result.outputs == json.dumps({"partial_answer": "test partial answer"})
+        assert result.total_tokens == 75
+        assert result.total_steps == 4
+        assert result.exceptions_count == 2
+        assert result.finished_at is not None
+def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
+    """Test _handle_workflow_node_execution_failed method"""
+    # Create a mock event
+    event = MagicMock(spec=QueueNodeFailedEvent)
+    event.node_execution_id = "test-node-execution-id"
+    event.inputs = {"input": "test input"}
+    event.process_data = {"process": "test process"}
+    event.outputs = {"output": "test output"}
+    event.execution_metadata = {"metadata": "test metadata"}
+    event.start_at = datetime.now(UTC).replace(tzinfo=None)
+    event.error = "Test error message"
+    # Create a mock workflow node execution
+    node_execution = MagicMock(spec=WorkflowNodeExecution)
+    node_execution.node_execution_id = "test-node-execution-id"
+    # Mock _get_workflow_node_execution to return the mock node execution
+    with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
+        # Call the method
+        result = workflow_cycle_manager._handle_workflow_node_execution_failed(
+            event=event,
+        )
+        # Verify the result
+        assert result == node_execution
+        assert result.status == WorkflowNodeExecutionStatus.FAILED.value
+        assert result.error == "Test error message"
+        assert result.inputs == json.dumps(event.inputs)
+        assert result.process_data == json.dumps(event.process_data)
+        assert result.outputs == json.dumps(event.outputs)
+        assert result.finished_at is not None
+        assert result.elapsed_time is not None
+        assert result.execution_metadata == json.dumps(event.execution_metadata)
+        # Verify update was called
+        workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
@@ -8,7 +8,7 @@ import pytest
 from pytest_mock import MockerFixture
 from sqlalchemy.orm import Session, sessionmaker
-from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import OrderConfig
 from models.workflow import WorkflowNodeExecution
@@ -80,7 +80,7 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
     """Test get_by_node_execution_id method."""
     session_obj, _ = session
     # Set up mock
-    mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select")
+    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
@@ -99,7 +99,7 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
     """Test get_by_workflow_run method."""
     session_obj, _ = session
     # Set up mock
-    mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select")
+    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
@@ -120,7 +120,7 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
     """Test get_running_executions method."""
    session_obj, _ = session
     # Set up mock
-    mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select")
+    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
     mock_stmt = mocker.MagicMock()
     mock_select.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt
@@ -158,7 +158,7 @@ def test_clear(repository, session, mocker: MockerFixture):
     """Test clear method."""
     session_obj, _ = session
     # Set up mock
-    mock_delete = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.delete")
+    mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
     mock_stmt = mocker.MagicMock()
     mock_delete.return_value = mock_stmt
     mock_stmt.where.return_value = mock_stmt