Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.0
| ext_otel, | ext_otel, | ||||
| ext_proxy_fix, | ext_proxy_fix, | ||||
| ext_redis, | ext_redis, | ||||
| ext_repositories, | |||||
| ext_sentry, | ext_sentry, | ||||
| ext_set_secretkey, | ext_set_secretkey, | ||||
| ext_storage, | ext_storage, | ||||
| ext_migrate, | ext_migrate, | ||||
| ext_redis, | ext_redis, | ||||
| ext_storage, | ext_storage, | ||||
| ext_repositories, | |||||
| ext_celery, | ext_celery, | ||||
| ext_login, | ext_login, | ||||
| ext_mail, | ext_mail, |
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | from core.model_runtime.errors.invoke import InvokeAuthorizationError | ||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | ||||
| from core.workflow.repository import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": application_generate_entity.app_config.tenant_id, | |||||
| "app_id": application_generate_entity.app_config.app_id, | |||||
| "session_factory": session_factory, | |||||
| } | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| app_id=application_generate_entity.app_config.app_id, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": application_generate_entity.app_config.tenant_id, | |||||
| "app_id": application_generate_entity.app_config.app_id, | |||||
| "session_factory": session_factory, | |||||
| } | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| app_id=application_generate_entity.app_config.app_id, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": application_generate_entity.app_config.tenant_id, | |||||
| "app_id": application_generate_entity.app_config.app_id, | |||||
| "session_factory": session_factory, | |||||
| } | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| app_id=application_generate_entity.app_config.app_id, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( |
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | ||||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | ||||
| from core.app.entities.app_invoke_entities import ( | from core.app.entities.app_invoke_entities import ( | ||||
| AdvancedChatAppGenerateEntity, | AdvancedChatAppGenerateEntity, | ||||
| ) | ) | ||||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | ||||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | ||||
| from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage | |||||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||||
| from core.model_runtime.entities.llm_entities import LLMUsage | from core.model_runtime.entities.llm_entities import LLMUsage | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | ||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||||
| from events.message_event import message_was_created | from events.message_event import message_was_created | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models import Conversation, EndUser, Message, MessageFile | from models import Conversation, EndUser, Message, MessageFile | ||||
| else: | else: | ||||
| raise NotImplementedError(f"User type not supported: {type(user)}") | raise NotImplementedError(f"User type not supported: {type(user)}") | ||||
| self._workflow_cycle_manager = WorkflowCycleManage( | |||||
| self._workflow_cycle_manager = WorkflowCycleManager( | |||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, | ||||
| workflow_system_variables={ | workflow_system_variables={ | ||||
| SystemVariableKey.QUERY: message.query, | SystemVariableKey.QUERY: message.query, |
| from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager | from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager | ||||
| from core.app.apps.workflow.app_runner import WorkflowAppRunner | from core.app.apps.workflow.app_runner import WorkflowAppRunner | ||||
| from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter | from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter | ||||
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | ||||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | ||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | from core.model_runtime.errors.invoke import InvokeAuthorizationError | ||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.workflow.repository import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from models import Account, App, EndUser, Workflow | from models import Account, App, EndUser, Workflow | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": application_generate_entity.app_config.tenant_id, | |||||
| "app_id": application_generate_entity.app_config.app_id, | |||||
| "session_factory": session_factory, | |||||
| } | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| app_id=application_generate_entity.app_config.app_id, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": application_generate_entity.app_config.tenant_id, | |||||
| "app_id": application_generate_entity.app_config.app_id, | |||||
| "session_factory": session_factory, | |||||
| } | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| app_id=application_generate_entity.app_config.app_id, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( | ||||
| # Create workflow node execution repository | # Create workflow node execution repository | ||||
| session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": application_generate_entity.app_config.tenant_id, | |||||
| "app_id": application_generate_entity.app_config.app_id, | |||||
| "session_factory": session_factory, | |||||
| } | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, | |||||
| tenant_id=application_generate_entity.app_config.tenant_id, | |||||
| app_id=application_generate_entity.app_config.app_id, | |||||
| ) | ) | ||||
| return self._generate( | return self._generate( |
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | ||||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | ||||
| from core.app.entities.app_invoke_entities import ( | from core.app.entities.app_invoke_entities import ( | ||||
| AgentChatAppGenerateEntity, | AgentChatAppGenerateEntity, | ||||
| ) | ) | ||||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | ||||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | ||||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||||
| from core.model_manager import ModelInstance | from core.model_manager import ModelInstance | ||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | ||||
| from core.model_runtime.entities.message_entities import ( | from core.model_runtime.entities.message_entities import ( |
| # Core base package |
| from core.base.tts.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||||
| __all__ = [ | |||||
| "AppGeneratorTTSPublisher", | |||||
| "AudioTrunk", | |||||
| ] |
| UnitEnum, | UnitEnum, | ||||
| ) | ) | ||||
| from core.ops.utils import filter_none_values | from core.ops.utils import filter_none_values | ||||
| from core.workflow.repository.repository_factory import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser | from models.model import EndUser | ||||
| # through workflow_run_id get all_nodes_execution using repository | # through workflow_run_id get all_nodes_execution using repository | ||||
| session_factory = sessionmaker(bind=db.engine) | session_factory = sessionmaker(bind=db.engine) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory}, | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, tenant_id=trace_info.tenant_id | |||||
| ) | ) | ||||
| # Get all executions for this workflow run | # Get all executions for this workflow run |
| LangSmithRunUpdateModel, | LangSmithRunUpdateModel, | ||||
| ) | ) | ||||
| from core.ops.utils import filter_none_values, generate_dotted_order | from core.ops.utils import filter_none_values, generate_dotted_order | ||||
| from core.workflow.repository.repository_factory import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser, MessageFile | from models.model import EndUser, MessageFile | ||||
| # through workflow_run_id get all_nodes_execution using repository | # through workflow_run_id get all_nodes_execution using repository | ||||
| session_factory = sessionmaker(bind=db.engine) | session_factory = sessionmaker(bind=db.engine) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": trace_info.tenant_id, | |||||
| "app_id": trace_info.metadata.get("app_id"), | |||||
| "session_factory": session_factory, | |||||
| }, | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") | |||||
| ) | ) | ||||
| # Get all executions for this workflow run | # Get all executions for this workflow run |
| TraceTaskName, | TraceTaskName, | ||||
| WorkflowTraceInfo, | WorkflowTraceInfo, | ||||
| ) | ) | ||||
| from core.workflow.repository.repository_factory import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import EndUser, MessageFile | from models.model import EndUser, MessageFile | ||||
| # through workflow_run_id get all_nodes_execution using repository | # through workflow_run_id get all_nodes_execution using repository | ||||
| session_factory = sessionmaker(bind=db.engine) | session_factory = sessionmaker(bind=db.engine) | ||||
| workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": trace_info.tenant_id, | |||||
| "app_id": trace_info.metadata.get("app_id"), | |||||
| "session_factory": session_factory, | |||||
| }, | |||||
| workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") | |||||
| ) | ) | ||||
| # Get all executions for this workflow run | # Get all executions for this workflow run |
| This package contains concrete implementations of the repository interfaces | This package contains concrete implementations of the repository interfaces | ||||
| defined in the core.workflow.repository package. | defined in the core.workflow.repository package. | ||||
| """ | """ | ||||
| from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| __all__ = [ | |||||
| "SQLAlchemyWorkflowNodeExecutionRepository", | |||||
| ] |
| """ | |||||
| Registry for repository implementations. | |||||
| This module is responsible for registering factory functions with the repository factory. | |||||
| """ | |||||
| import logging | |||||
| from collections.abc import Mapping | |||||
| from typing import Any | |||||
| from sqlalchemy.orm import sessionmaker | |||||
| from configs import dify_config | |||||
| from core.repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.workflow.repository.repository_factory import RepositoryFactory | |||||
| from extensions.ext_database import db | |||||
| logger = logging.getLogger(__name__) | |||||
| # Storage type constants | |||||
| STORAGE_TYPE_RDBMS = "rdbms" | |||||
| STORAGE_TYPE_HYBRID = "hybrid" | |||||
| def register_repositories() -> None: | |||||
| """ | |||||
| Register repository factory functions with the RepositoryFactory. | |||||
| This function reads configuration settings to determine which repository | |||||
| implementations to register. | |||||
| """ | |||||
| # Configure WorkflowNodeExecutionRepository factory based on configuration | |||||
| workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE | |||||
| # Check storage type and register appropriate implementation | |||||
| if workflow_node_execution_storage == STORAGE_TYPE_RDBMS: | |||||
| # Register SQLAlchemy implementation for RDBMS storage | |||||
| logger.info("Registering WorkflowNodeExecution repository with RDBMS storage") | |||||
| RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository) | |||||
| elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID: | |||||
| # Hybrid storage is not yet implemented | |||||
| raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented") | |||||
| else: | |||||
| # Unknown storage type | |||||
| raise ValueError( | |||||
| f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. " | |||||
| f"Supported types: {STORAGE_TYPE_RDBMS}" | |||||
| ) | |||||
| def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository: | |||||
| """ | |||||
| Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation. | |||||
| This factory function creates a repository for the RDBMS storage type. | |||||
| Args: | |||||
| params: Parameters for creating the repository, including: | |||||
| - tenant_id: Required. The tenant ID for multi-tenancy. | |||||
| - app_id: Optional. The application ID for filtering. | |||||
| - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided, | |||||
| a new sessionmaker will be created using the global database engine. | |||||
| Returns: | |||||
| A WorkflowNodeExecutionRepository instance | |||||
| Raises: | |||||
| ValueError: If required parameters are missing | |||||
| """ | |||||
| # Extract required parameters | |||||
| tenant_id = params.get("tenant_id") | |||||
| if tenant_id is None: | |||||
| raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage") | |||||
| # Extract optional parameters | |||||
| app_id = params.get("app_id") | |||||
| # Use the session_factory from params if provided, otherwise create one using the global db engine | |||||
| session_factory = params.get("session_factory") | |||||
| if session_factory is None: | |||||
| # Create a sessionmaker using the same engine as the global db session | |||||
| session_factory = sessionmaker(bind=db.engine) | |||||
| # Create and return the repository | |||||
| return SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=session_factory, tenant_id=tenant_id, app_id=app_id | |||||
| ) |
| from sqlalchemy.engine import Engine | from sqlalchemy.engine import Engine | ||||
| from sqlalchemy.orm import sessionmaker | from sqlalchemy.orm import sessionmaker | ||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | |||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||||
| from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom | from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class SQLAlchemyWorkflowNodeExecutionRepository: | |||||
| class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): | |||||
| """ | """ | ||||
| SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface. | SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface. | ||||
| """ | |||||
| WorkflowNodeExecution repository implementations. | |||||
| """ | |||||
| from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| __all__ = [ | |||||
| "SQLAlchemyWorkflowNodeExecutionRepository", | |||||
| ] |
| storage mechanism. | storage mechanism. | ||||
| """ | """ | ||||
| from core.workflow.repository.repository_factory import RepositoryFactory | |||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository | |||||
| __all__ = [ | __all__ = [ | ||||
| "RepositoryFactory", | |||||
| "OrderConfig", | |||||
| "WorkflowNodeExecutionRepository", | "WorkflowNodeExecutionRepository", | ||||
| ] | ] |
| """ | |||||
| Repository factory for creating repository instances. | |||||
| This module provides a simple factory interface for creating repository instances. | |||||
| It does not contain any implementation details or dependencies on specific repositories. | |||||
| """ | |||||
| from collections.abc import Callable, Mapping | |||||
| from typing import Any, Literal, Optional, cast | |||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||||
| # Type for factory functions - takes a dict of parameters and returns any repository type | |||||
| RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any] | |||||
| # Type for workflow node execution factory function | |||||
| WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository] | |||||
| # Repository type literals | |||||
| _RepositoryType = Literal["workflow_node_execution"] | |||||
| class RepositoryFactory: | |||||
| """ | |||||
| Factory class for creating repository instances. | |||||
| This factory delegates the actual repository creation to implementation-specific | |||||
| factory functions that are registered with the factory at runtime. | |||||
| """ | |||||
| # Dictionary to store factory functions | |||||
| _factory_functions: dict[str, RepositoryFactoryFunc] = {} | |||||
| @classmethod | |||||
| def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None: | |||||
| """ | |||||
| Register a factory function for a specific repository type. | |||||
| This is a private method and should not be called directly. | |||||
| Args: | |||||
| repository_type: The type of repository (e.g., 'workflow_node_execution') | |||||
| factory_func: A function that takes parameters and returns a repository instance | |||||
| """ | |||||
| cls._factory_functions[repository_type] = factory_func | |||||
| @classmethod | |||||
| def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any: | |||||
| """ | |||||
| Create a new repository instance with the provided parameters. | |||||
| This is a private method and should not be called directly. | |||||
| Args: | |||||
| repository_type: The type of repository to create | |||||
| params: A dictionary of parameters to pass to the factory function | |||||
| Returns: | |||||
| A new instance of the requested repository | |||||
| Raises: | |||||
| ValueError: If no factory function is registered for the repository type | |||||
| """ | |||||
| if repository_type not in cls._factory_functions: | |||||
| raise ValueError(f"No factory function registered for repository type '{repository_type}'") | |||||
| # Use empty dict if params is None | |||||
| params = params or {} | |||||
| return cls._factory_functions[repository_type](params) | |||||
| @classmethod | |||||
| def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None: | |||||
| """ | |||||
| Register a factory function for the workflow node execution repository. | |||||
| Args: | |||||
| factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance | |||||
| """ | |||||
| cls._register_factory("workflow_node_execution", factory_func) | |||||
| @classmethod | |||||
| def create_workflow_node_execution_repository( | |||||
| cls, params: Optional[Mapping[str, Any]] = None | |||||
| ) -> WorkflowNodeExecutionRepository: | |||||
| """ | |||||
| Create a new WorkflowNodeExecutionRepository instance with the provided parameters. | |||||
| Args: | |||||
| params: A dictionary of parameters to pass to the factory function | |||||
| Returns: | |||||
| A new instance of the WorkflowNodeExecutionRepository | |||||
| Raises: | |||||
| ValueError: If no factory function is registered for the workflow_node_execution repository type | |||||
| """ | |||||
| # We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc | |||||
| return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params)) |
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | ||||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager | from core.app.apps.base_app_queue_manager import AppQueueManager | ||||
| from core.app.entities.app_invoke_entities import ( | from core.app.entities.app_invoke_entities import ( | ||||
| InvokeFrom, | InvokeFrom, | ||||
| WorkflowTaskState, | WorkflowTaskState, | ||||
| ) | ) | ||||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | ||||
| from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage | |||||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.workflow.enums import SystemVariableKey | from core.workflow.enums import SystemVariableKey | ||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | ||||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.enums import CreatedByRole | from models.enums import CreatedByRole | ||||
| else: | else: | ||||
| raise ValueError(f"Invalid user type: {type(user)}") | raise ValueError(f"Invalid user type: {type(user)}") | ||||
| self._workflow_cycle_manager = WorkflowCycleManage( | |||||
| self._workflow_cycle_manager = WorkflowCycleManager( | |||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, | ||||
| workflow_system_variables={ | workflow_system_variables={ | ||||
| SystemVariableKey.FILES: application_generate_entity.files, | SystemVariableKey.FILES: application_generate_entity.files, |
| ) | ) | ||||
| class WorkflowCycleManage: | |||||
| class WorkflowCycleManager: | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| *, | *, |
| """ | |||||
| Extension for initializing repositories. | |||||
| This extension registers repository implementations with the RepositoryFactory. | |||||
| """ | |||||
| from core.repositories.repository_registry import register_repositories | |||||
| from dify_app import DifyApp | |||||
| def init_app(_app: DifyApp) -> None: | |||||
| """ | |||||
| Initialize repository implementations. | |||||
| Args: | |||||
| _app: The Flask application instance (unused) | |||||
| """ | |||||
| register_repositories() |
| from typing import Optional | from typing import Optional | ||||
| import contexts | import contexts | ||||
| from core.workflow.repository import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | from core.workflow.repository.workflow_node_execution_repository import OrderConfig | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | from libs.infinite_scroll_pagination import InfiniteScrollPagination | ||||
| return [] | return [] | ||||
| # Use the repository to get the node executions | # Use the repository to get the node executions | ||||
| repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": app_model.tenant_id, | |||||
| "app_id": app_model.id, | |||||
| "session_factory": db.session.get_bind(), | |||||
| } | |||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id | |||||
| ) | ) | ||||
| # Use the repository to get the node executions with ordering | # Use the repository to get the node executions with ordering |
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager | from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager | ||||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.variables import Variable | from core.variables import Variable | ||||
| from core.workflow.entities.node_entities import NodeRunResult | from core.workflow.entities.node_entities import NodeRunResult | ||||
| from core.workflow.errors import WorkflowNodeRunFailedError | from core.workflow.errors import WorkflowNodeRunFailedError | ||||
| from core.workflow.nodes.event import RunCompletedEvent | from core.workflow.nodes.event import RunCompletedEvent | ||||
| from core.workflow.nodes.event.types import NodeEvent | from core.workflow.nodes.event.types import NodeEvent | ||||
| from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING | from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING | ||||
| from core.workflow.repository import RepositoryFactory | |||||
| from core.workflow.workflow_entry import WorkflowEntry | from core.workflow.workflow_entry import WorkflowEntry | ||||
| from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| workflow_node_execution.workflow_id = draft_workflow.id | workflow_node_execution.workflow_id = draft_workflow.id | ||||
| # Use the repository to save the workflow node execution | # Use the repository to save the workflow node execution | ||||
| repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": app_model.tenant_id, | |||||
| "app_id": app_model.id, | |||||
| "session_factory": db.session.get_bind(), | |||||
| } | |||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id | |||||
| ) | ) | ||||
| repository.save(workflow_node_execution) | repository.save(workflow_node_execution) | ||||
| from sqlalchemy import delete | from sqlalchemy import delete | ||||
| from sqlalchemy.exc import SQLAlchemyError | from sqlalchemy.exc import SQLAlchemyError | ||||
| from core.workflow.repository import RepositoryFactory | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import AppDatasetJoin | from models.dataset import AppDatasetJoin | ||||
| from models.model import ( | from models.model import ( | ||||
| def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | ||||
| # Create a repository instance for WorkflowNodeExecution | # Create a repository instance for WorkflowNodeExecution | ||||
| repository = RepositoryFactory.create_workflow_node_execution_repository( | |||||
| params={ | |||||
| "tenant_id": tenant_id, | |||||
| "app_id": app_id, | |||||
| "session_factory": db.session.get_bind(), | |||||
| } | |||||
| repository = SQLAlchemyWorkflowNodeExecutionRepository( | |||||
| session_factory=db.engine, tenant_id=tenant_id, app_id=app_id | |||||
| ) | ) | ||||
| # Use the clear method to delete all records for this tenant_id and app_id | # Use the clear method to delete all records for this tenant_id and app_id |
| import json | |||||
| import time | |||||
| from datetime import UTC, datetime | |||||
| from unittest.mock import MagicMock, patch | |||||
| import pytest | |||||
| from sqlalchemy.orm import Session | |||||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||||
| from core.app.entities.queue_entities import ( | |||||
| QueueNodeFailedEvent, | |||||
| QueueNodeStartedEvent, | |||||
| QueueNodeSucceededEvent, | |||||
| ) | |||||
| from core.workflow.enums import SystemVariableKey | |||||
| from core.workflow.nodes import NodeType | |||||
| from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository | |||||
| from core.workflow.workflow_cycle_manager import WorkflowCycleManager | |||||
| from models.enums import CreatedByRole | |||||
| from models.workflow import ( | |||||
| Workflow, | |||||
| WorkflowNodeExecution, | |||||
| WorkflowNodeExecutionStatus, | |||||
| WorkflowRun, | |||||
| WorkflowRunStatus, | |||||
| ) | |||||
| @pytest.fixture | |||||
| def mock_app_generate_entity(): | |||||
| entity = MagicMock(spec=AdvancedChatAppGenerateEntity) | |||||
| entity.inputs = {"query": "test query"} | |||||
| entity.invoke_from = InvokeFrom.WEB_APP | |||||
| # Create app_config as a separate mock | |||||
| app_config = MagicMock() | |||||
| app_config.tenant_id = "test-tenant-id" | |||||
| app_config.app_id = "test-app-id" | |||||
| entity.app_config = app_config | |||||
| return entity | |||||
| @pytest.fixture | |||||
| def mock_workflow_system_variables(): | |||||
| return { | |||||
| SystemVariableKey.QUERY: "test query", | |||||
| SystemVariableKey.CONVERSATION_ID: "test-conversation-id", | |||||
| SystemVariableKey.USER_ID: "test-user-id", | |||||
| SystemVariableKey.APP_ID: "test-app-id", | |||||
| SystemVariableKey.WORKFLOW_ID: "test-workflow-id", | |||||
| SystemVariableKey.WORKFLOW_RUN_ID: "test-workflow-run-id", | |||||
| } | |||||
| @pytest.fixture | |||||
| def mock_node_execution_repository(): | |||||
| repo = MagicMock(spec=WorkflowNodeExecutionRepository) | |||||
| repo.get_by_node_execution_id.return_value = None | |||||
| repo.get_running_executions.return_value = [] | |||||
| return repo | |||||
| @pytest.fixture | |||||
| def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository): | |||||
| return WorkflowCycleManager( | |||||
| application_generate_entity=mock_app_generate_entity, | |||||
| workflow_system_variables=mock_workflow_system_variables, | |||||
| workflow_node_execution_repository=mock_node_execution_repository, | |||||
| ) | |||||
| @pytest.fixture | |||||
| def mock_session(): | |||||
| session = MagicMock(spec=Session) | |||||
| return session | |||||
| @pytest.fixture | |||||
| def mock_workflow(): | |||||
| workflow = MagicMock(spec=Workflow) | |||||
| workflow.id = "test-workflow-id" | |||||
| workflow.tenant_id = "test-tenant-id" | |||||
| workflow.app_id = "test-app-id" | |||||
| workflow.type = "chat" | |||||
| workflow.version = "1.0" | |||||
| workflow.graph = json.dumps({"nodes": [], "edges": []}) | |||||
| return workflow | |||||
| @pytest.fixture | |||||
| def mock_workflow_run(): | |||||
| workflow_run = MagicMock(spec=WorkflowRun) | |||||
| workflow_run.id = "test-workflow-run-id" | |||||
| workflow_run.tenant_id = "test-tenant-id" | |||||
| workflow_run.app_id = "test-app-id" | |||||
| workflow_run.workflow_id = "test-workflow-id" | |||||
| workflow_run.status = WorkflowRunStatus.RUNNING | |||||
| workflow_run.created_by_role = CreatedByRole.ACCOUNT | |||||
| workflow_run.created_by = "test-user-id" | |||||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| workflow_run.inputs_dict = {"query": "test query"} | |||||
| workflow_run.outputs_dict = {"answer": "test answer"} | |||||
| return workflow_run | |||||
| def test_init( | |||||
| workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository | |||||
| ): | |||||
| """Test initialization of WorkflowCycleManager""" | |||||
| assert workflow_cycle_manager._workflow_run is None | |||||
| assert workflow_cycle_manager._workflow_node_executions == {} | |||||
| assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity | |||||
| assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables | |||||
| assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository | |||||
| def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow): | |||||
| """Test _handle_workflow_run_start method""" | |||||
| # Mock session.scalar to return the workflow and max sequence | |||||
| mock_session.scalar.side_effect = [mock_workflow, 5] | |||||
| # Call the method | |||||
| workflow_run = workflow_cycle_manager._handle_workflow_run_start( | |||||
| session=mock_session, | |||||
| workflow_id="test-workflow-id", | |||||
| user_id="test-user-id", | |||||
| created_by_role=CreatedByRole.ACCOUNT, | |||||
| ) | |||||
| # Verify the result | |||||
| assert workflow_run.tenant_id == mock_workflow.tenant_id | |||||
| assert workflow_run.app_id == mock_workflow.app_id | |||||
| assert workflow_run.workflow_id == mock_workflow.id | |||||
| assert workflow_run.sequence_number == 6 # max_sequence + 1 | |||||
| assert workflow_run.status == WorkflowRunStatus.RUNNING | |||||
| assert workflow_run.created_by_role == CreatedByRole.ACCOUNT | |||||
| assert workflow_run.created_by == "test-user-id" | |||||
| # Verify session.add was called | |||||
| mock_session.add.assert_called_once_with(workflow_run) | |||||
| def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run): | |||||
| """Test _handle_workflow_run_success method""" | |||||
| # Mock _get_workflow_run to return the mock_workflow_run | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_run_success( | |||||
| session=mock_session, | |||||
| workflow_run_id="test-workflow-run-id", | |||||
| start_at=time.perf_counter() - 10, # 10 seconds ago | |||||
| total_tokens=100, | |||||
| total_steps=5, | |||||
| outputs={"answer": "test answer"}, | |||||
| ) | |||||
| # Verify the result | |||||
| assert result == mock_workflow_run | |||||
| assert result.status == WorkflowRunStatus.SUCCEEDED | |||||
| assert result.outputs == json.dumps({"answer": "test answer"}) | |||||
| assert result.total_tokens == 100 | |||||
| assert result.total_steps == 5 | |||||
| assert result.finished_at is not None | |||||
| def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run): | |||||
| """Test _handle_workflow_run_failed method""" | |||||
| # Mock _get_workflow_run to return the mock_workflow_run | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): | |||||
| # Mock get_running_executions to return an empty list | |||||
| workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_run_failed( | |||||
| session=mock_session, | |||||
| workflow_run_id="test-workflow-run-id", | |||||
| start_at=time.perf_counter() - 10, # 10 seconds ago | |||||
| total_tokens=50, | |||||
| total_steps=3, | |||||
| status=WorkflowRunStatus.FAILED, | |||||
| error="Test error message", | |||||
| ) | |||||
| # Verify the result | |||||
| assert result == mock_workflow_run | |||||
| assert result.status == WorkflowRunStatus.FAILED.value | |||||
| assert result.error == "Test error message" | |||||
| assert result.total_tokens == 50 | |||||
| assert result.total_steps == 3 | |||||
| assert result.finished_at is not None | |||||
| def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): | |||||
| """Test _handle_node_execution_start method""" | |||||
| # Create a mock event | |||||
| event = MagicMock(spec=QueueNodeStartedEvent) | |||||
| event.node_execution_id = "test-node-execution-id" | |||||
| event.node_id = "test-node-id" | |||||
| event.node_type = NodeType.LLM | |||||
| # Create node_data as a separate mock | |||||
| node_data = MagicMock() | |||||
| node_data.title = "Test Node" | |||||
| event.node_data = node_data | |||||
| event.predecessor_node_id = "test-predecessor-node-id" | |||||
| event.node_run_index = 1 | |||||
| event.parallel_mode_run_id = "test-parallel-mode-run-id" | |||||
| event.in_iteration_id = "test-iteration-id" | |||||
| event.in_loop_id = "test-loop-id" | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_node_execution_start( | |||||
| workflow_run=mock_workflow_run, | |||||
| event=event, | |||||
| ) | |||||
| # Verify the result | |||||
| assert result.tenant_id == mock_workflow_run.tenant_id | |||||
| assert result.app_id == mock_workflow_run.app_id | |||||
| assert result.workflow_id == mock_workflow_run.workflow_id | |||||
| assert result.workflow_run_id == mock_workflow_run.id | |||||
| assert result.node_execution_id == event.node_execution_id | |||||
| assert result.node_id == event.node_id | |||||
| assert result.node_type == event.node_type.value | |||||
| assert result.title == event.node_data.title | |||||
| assert result.status == WorkflowNodeExecutionStatus.RUNNING.value | |||||
| assert result.created_by_role == mock_workflow_run.created_by_role | |||||
| assert result.created_by == mock_workflow_run.created_by | |||||
| # Verify save was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) | |||||
| # Verify the node execution was added to the cache | |||||
| assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result | |||||
| def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): | |||||
| """Test _get_workflow_run method""" | |||||
| # Mock session.scalar to return the workflow run | |||||
| mock_session.scalar.return_value = mock_workflow_run | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._get_workflow_run( | |||||
| session=mock_session, | |||||
| workflow_run_id="test-workflow-run-id", | |||||
| ) | |||||
| # Verify the result | |||||
| assert result == mock_workflow_run | |||||
| assert workflow_cycle_manager._workflow_run == mock_workflow_run | |||||
| def test_handle_workflow_node_execution_success(workflow_cycle_manager): | |||||
| """Test _handle_workflow_node_execution_success method""" | |||||
| # Create a mock event | |||||
| event = MagicMock(spec=QueueNodeSucceededEvent) | |||||
| event.node_execution_id = "test-node-execution-id" | |||||
| event.inputs = {"input": "test input"} | |||||
| event.process_data = {"process": "test process"} | |||||
| event.outputs = {"output": "test output"} | |||||
| event.execution_metadata = {"metadata": "test metadata"} | |||||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| # Create a mock workflow node execution | |||||
| node_execution = MagicMock(spec=WorkflowNodeExecution) | |||||
| node_execution.node_execution_id = "test-node-execution-id" | |||||
| # Mock _get_workflow_node_execution to return the mock node execution | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_node_execution_success( | |||||
| event=event, | |||||
| ) | |||||
| # Verify the result | |||||
| assert result == node_execution | |||||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value | |||||
| assert result.inputs == json.dumps(event.inputs) | |||||
| assert result.process_data == json.dumps(event.process_data) | |||||
| assert result.outputs == json.dumps(event.outputs) | |||||
| assert result.finished_at is not None | |||||
| assert result.elapsed_time is not None | |||||
| # Verify update was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) | |||||
| def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): | |||||
| """Test _handle_workflow_run_partial_success method""" | |||||
| # Mock _get_workflow_run to return the mock_workflow_run | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_run_partial_success( | |||||
| session=mock_session, | |||||
| workflow_run_id="test-workflow-run-id", | |||||
| start_at=time.perf_counter() - 10, # 10 seconds ago | |||||
| total_tokens=75, | |||||
| total_steps=4, | |||||
| outputs={"partial_answer": "test partial answer"}, | |||||
| exceptions_count=2, | |||||
| ) | |||||
| # Verify the result | |||||
| assert result == mock_workflow_run | |||||
| assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value | |||||
| assert result.outputs == json.dumps({"partial_answer": "test partial answer"}) | |||||
| assert result.total_tokens == 75 | |||||
| assert result.total_steps == 4 | |||||
| assert result.exceptions_count == 2 | |||||
| assert result.finished_at is not None | |||||
| def test_handle_workflow_node_execution_failed(workflow_cycle_manager): | |||||
| """Test _handle_workflow_node_execution_failed method""" | |||||
| # Create a mock event | |||||
| event = MagicMock(spec=QueueNodeFailedEvent) | |||||
| event.node_execution_id = "test-node-execution-id" | |||||
| event.inputs = {"input": "test input"} | |||||
| event.process_data = {"process": "test process"} | |||||
| event.outputs = {"output": "test output"} | |||||
| event.execution_metadata = {"metadata": "test metadata"} | |||||
| event.start_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| event.error = "Test error message" | |||||
| # Create a mock workflow node execution | |||||
| node_execution = MagicMock(spec=WorkflowNodeExecution) | |||||
| node_execution.node_execution_id = "test-node-execution-id" | |||||
| # Mock _get_workflow_node_execution to return the mock node execution | |||||
| with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution): | |||||
| # Call the method | |||||
| result = workflow_cycle_manager._handle_workflow_node_execution_failed( | |||||
| event=event, | |||||
| ) | |||||
| # Verify the result | |||||
| assert result == node_execution | |||||
| assert result.status == WorkflowNodeExecutionStatus.FAILED.value | |||||
| assert result.error == "Test error message" | |||||
| assert result.inputs == json.dumps(event.inputs) | |||||
| assert result.process_data == json.dumps(event.process_data) | |||||
| assert result.outputs == json.dumps(event.outputs) | |||||
| assert result.finished_at is not None | |||||
| assert result.elapsed_time is not None | |||||
| assert result.execution_metadata == json.dumps(event.execution_metadata) | |||||
| # Verify update was called | |||||
| workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution) |
| from pytest_mock import MockerFixture | from pytest_mock import MockerFixture | ||||
| from sqlalchemy.orm import Session, sessionmaker | from sqlalchemy.orm import Session, sessionmaker | ||||
| from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository | |||||
| from core.workflow.repository.workflow_node_execution_repository import OrderConfig | from core.workflow.repository.workflow_node_execution_repository import OrderConfig | ||||
| from models.workflow import WorkflowNodeExecution | from models.workflow import WorkflowNodeExecution | ||||
| """Test get_by_node_execution_id method.""" | """Test get_by_node_execution_id method.""" | ||||
| session_obj, _ = session | session_obj, _ = session | ||||
| # Set up mock | # Set up mock | ||||
| mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select") | |||||
| mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") | |||||
| mock_stmt = mocker.MagicMock() | mock_stmt = mocker.MagicMock() | ||||
| mock_select.return_value = mock_stmt | mock_select.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt | ||||
| """Test get_by_workflow_run method.""" | """Test get_by_workflow_run method.""" | ||||
| session_obj, _ = session | session_obj, _ = session | ||||
| # Set up mock | # Set up mock | ||||
| mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select") | |||||
| mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") | |||||
| mock_stmt = mocker.MagicMock() | mock_stmt = mocker.MagicMock() | ||||
| mock_select.return_value = mock_stmt | mock_select.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt | ||||
| """Test get_running_executions method.""" | """Test get_running_executions method.""" | ||||
| session_obj, _ = session | session_obj, _ = session | ||||
| # Set up mock | # Set up mock | ||||
| mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select") | |||||
| mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") | |||||
| mock_stmt = mocker.MagicMock() | mock_stmt = mocker.MagicMock() | ||||
| mock_select.return_value = mock_stmt | mock_select.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt | ||||
| """Test clear method.""" | """Test clear method.""" | ||||
| session_obj, _ = session | session_obj, _ = session | ||||
| # Set up mock | # Set up mock | ||||
| mock_delete = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.delete") | |||||
| mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete") | |||||
| mock_stmt = mocker.MagicMock() | mock_stmt = mocker.MagicMock() | ||||
| mock_delete.return_value = mock_stmt | mock_delete.return_value = mock_stmt | ||||
| mock_stmt.where.return_value = mock_stmt | mock_stmt.where.return_value = mock_stmt |