Co-authored-by: liangxin <liangxin@shein.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: liangxin <xinlmain@gmail.com>
@@ -5,7 +5,7 @@ cd web && pnpm install
 pipx install uv
 echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
 echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
@@ -74,7 +74,7 @@
10. If you need to handle and debug the async tasks (e.g. dataset importing and document indexing), please start the worker service.

    ```bash
-   uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
+   uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage
    ```

    In addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
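    ```bash
    uv run celery -A app.celery beat --loglevel INFO
    ```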
@@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings):
    """

    CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
-        description="Repository implementation for WorkflowExecution. Specify as a module path",
+        description="Repository implementation for WorkflowExecution. Options: "
+        "'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
+        "'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
        default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
    )

    CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
-        description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
+        description="Repository implementation for WorkflowNodeExecution. Options: "
+        "'core.repositories.sqlalchemy_workflow_node_execution_repository."
+        "SQLAlchemyWorkflowNodeExecutionRepository' (default), "
+        "'core.repositories.celery_workflow_node_execution_repository."
+        "CeleryWorkflowNodeExecutionRepository'",
        default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
    )
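For reference, opting in to the asynchronous Celery-backed repositories is a matter of pointing these two settings at the Celery classes, for example in the API service's environment (a hypothetical snippet; the sqlalchemy values shown above remain the defaults):

```bash
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
```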
@@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
 defined in the core.workflow.repository package.
 """

+from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
+from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
 from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository

 __all__ = [
+    "CeleryWorkflowExecutionRepository",
+    "CeleryWorkflowNodeExecutionRepository",
     "DifyCoreRepositoryFactory",
     "RepositoryImportError",
     "SQLAlchemyWorkflowNodeExecutionRepository",
@@ -0,0 +1,126 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.

This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""

import logging
from typing import Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
    save_workflow_execution_task,
)

logger = logging.getLogger(__name__)


class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
    """
    Celery-based implementation of the WorkflowExecutionRepository interface.

    This implementation provides asynchronous storage capabilities by using Celery tasks
    to handle database operations in background workers. This improves performance by
    reducing the blocking time for workflow execution storage operations.

    Key features:
    - Asynchronous save operations using Celery tasks
    - Support for multi-tenancy through tenant/app filtering
    - Automatic retry and error handling through Celery
    """

    _session_factory: sessionmaker
    _tenant_id: str
    _app_id: Optional[str]
    _triggered_from: Optional[WorkflowRunTriggeredFrom]
    _creator_user_id: str
    _creator_user_role: CreatorUserRole

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Union[Account, EndUser],
        app_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom],
    ):
        """
        Initialize the repository with Celery task configuration and context information.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine for fallback operations
            user: Account or EndUser object containing tenant_id, user ID, and role information
            app_id: App ID for filtering by application (can be None)
            triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
        """
        # Store session factory for fallback operations
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
        else:
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )

        # Extract tenant_id from user
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None

        # Store app context
        self._app_id = app_id

        # Extract user context
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Determine user role based on user type
        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

        logger.info(
            "Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
            self._tenant_id,
            self._app_id,
            self._triggered_from,
        )

    def save(self, execution: WorkflowExecution) -> None:
        """
        Save or update a WorkflowExecution instance asynchronously using Celery.

        This method queues the save operation as a Celery task and returns immediately,
        providing improved performance for high-throughput scenarios.

        Args:
            execution: The WorkflowExecution instance to save or update
        """
        try:
            # Serialize execution for Celery task
            execution_data = execution.model_dump()

            # Queue the save operation as a Celery task (fire and forget)
            save_workflow_execution_task.delay(
                execution_data=execution_data,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                triggered_from=self._triggered_from.value if self._triggered_from else "",
                creator_user_id=self._creator_user_id,
                creator_user_role=self._creator_user_role.value,
            )

            logger.debug("Queued async save for workflow execution: %s", execution.id_)
        except Exception as e:
            logger.exception("Failed to queue save operation for execution %s", execution.id_)
            # In case of Celery failure, we could implement a fallback to synchronous save
            # For now, we'll re-raise the exception
            raise
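A minimal usage sketch for the class above, assuming Dify's `db.engine` and a signed-in `account` (the `account` and app ID here are illustrative, not part of the diff):

```python
from datetime import UTC, datetime
from uuid import uuid4

from core.repositories import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from extensions.ext_database import db
from models.enums import WorkflowRunTriggeredFrom

repo = CeleryWorkflowExecutionRepository(
    session_factory=db.engine,  # an Engine is wrapped into a sessionmaker internally
    user=account,               # hypothetical Account with a current_tenant_id
    app_id="example-app-id",    # illustrative value
    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)

execution = WorkflowExecution.new(
    id_=str(uuid4()),
    workflow_id=str(uuid4()),
    workflow_type=WorkflowType.WORKFLOW,
    workflow_version="1.0",
    graph={"nodes": [], "edges": []},
    inputs={"input1": "value1"},
    started_at=datetime.now(UTC).replace(tzinfo=None),
)

# Returns immediately; the actual INSERT/UPDATE happens on the workflow_storage queue.
repo.save(execution)
```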
@@ -0,0 +1,190 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.

This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""

import logging
from collections.abc import Sequence
from typing import Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
    OrderConfig,
    WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
    save_workflow_node_execution_task,
)

logger = logging.getLogger(__name__)


class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    """
    Celery-based implementation of the WorkflowNodeExecutionRepository interface.

    This implementation provides asynchronous storage capabilities by using Celery tasks
    to handle database operations in background workers. This improves performance by
    reducing the blocking time for workflow node execution storage operations.

    Key features:
    - Asynchronous save operations using Celery tasks
    - In-memory cache for immediate reads
    - Support for multi-tenancy through tenant/app filtering
    - Automatic retry and error handling through Celery
    """

    _session_factory: sessionmaker
    _tenant_id: str
    _app_id: Optional[str]
    _triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom]
    _creator_user_id: str
    _creator_user_role: CreatorUserRole
    _execution_cache: dict[str, WorkflowNodeExecution]
    _workflow_execution_mapping: dict[str, list[str]]

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Union[Account, EndUser],
        app_id: Optional[str],
        triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
    ):
        """
        Initialize the repository with Celery task configuration and context information.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine for fallback operations
            user: Account or EndUser object containing tenant_id, user ID, and role information
            app_id: App ID for filtering by application (can be None)
            triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
        """
        # Store session factory for fallback operations
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
        else:
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )

        # Extract tenant_id from user
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None

        # Store app context
        self._app_id = app_id

        # Extract user context
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Determine user role based on user type
        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

        # In-memory cache for workflow node executions
        self._execution_cache: dict[str, WorkflowNodeExecution] = {}

        # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
        self._workflow_execution_mapping: dict[str, list[str]] = {}

        logger.info(
            "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
            self._tenant_id,
            self._app_id,
            self._triggered_from,
        )

    def save(self, execution: WorkflowNodeExecution) -> None:
        """
        Save or update a WorkflowNodeExecution instance to cache and asynchronously to database.

        This method stores the execution in cache immediately for fast reads and queues
        the save operation as a Celery task without tracking the task status.

        Args:
            execution: The WorkflowNodeExecution instance to save or update
        """
        try:
            # Store in cache immediately for fast reads
            self._execution_cache[execution.id] = execution

            # Update workflow execution mapping for efficient retrieval
            if execution.workflow_execution_id:
                if execution.workflow_execution_id not in self._workflow_execution_mapping:
                    self._workflow_execution_mapping[execution.workflow_execution_id] = []
                if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
                    self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)

            # Serialize execution for Celery task
            execution_data = execution.model_dump()

            # Queue the save operation as a Celery task (fire and forget)
            save_workflow_node_execution_task.delay(
                execution_data=execution_data,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                triggered_from=self._triggered_from.value if self._triggered_from else "",
                creator_user_id=self._creator_user_id,
                creator_user_role=self._creator_user_role.value,
            )

            logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
        except Exception as e:
            logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
            # In case of Celery failure, we could implement a fallback to synchronous save
            # For now, we'll re-raise the exception
            raise

    def get_by_workflow_run(
        self,
        workflow_run_id: str,
        order_config: Optional[OrderConfig] = None,
    ) -> Sequence[WorkflowNodeExecution]:
        """
        Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache.

        Args:
            workflow_run_id: The workflow run ID
            order_config: Optional configuration for ordering results

        Returns:
            A sequence of WorkflowNodeExecution instances
        """
        try:
            # Get execution IDs for this workflow run from cache
            execution_ids = self._workflow_execution_mapping.get(workflow_run_id, [])

            # Retrieve executions from cache
            result = []
            for execution_id in execution_ids:
                if execution_id in self._execution_cache:
                    result.append(self._execution_cache[execution_id])

            # Apply ordering if specified
            if order_config and result:
                # Sort based on the order configuration
                reverse = order_config.order_direction == "desc"
                # Sort by multiple fields if specified
                for field_name in reversed(order_config.order_by):
                    result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)

            logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id)
            return result
        except Exception as e:
            logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id)
            return []
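Reads, by contrast, never touch the database: only executions saved through the same repository instance are visible. A short sketch, continuing a hypothetical `repo` built as in the constructor above (`workflow_run_id` is illustrative):

```python
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig

# Ascending by node index, mirroring how the tests later in this diff exercise ordering.
order_config = OrderConfig(order_by=["index"], order_direction="asc")
executions = repo.get_by_workflow_run(workflow_run_id, order_config)
for node_execution in executions:
    print(node_execution.node_id, node_execution.status)
```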
@@ -94,11 +94,9 @@ class DifyCoreRepositoryFactory:
    def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
        """
        Validate that a repository class constructor accepts required parameters.

        Args:
            repository_class: The class to validate
            required_params: List of required parameter names

        Raises:
            RepositoryImportError: If the constructor doesn't accept required parameters
        """
@@ -158,10 +156,8 @@ class DifyCoreRepositoryFactory:
        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
-            cls._validate_constructor_signature(
-                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
-            )

+            # All repository types now use the same constructor parameters
            return repository_class(  # type: ignore[no-any-return]
                session_factory=session_factory,
                user=user,
@@ -204,10 +200,8 @@ class DifyCoreRepositoryFactory:
        try:
            repository_class = cls._import_class(class_path)
            cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
-            cls._validate_constructor_signature(
-                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
-            )

+            # All repository types now use the same constructor parameters
            return repository_class(  # type: ignore[no-any-return]
                session_factory=session_factory,
                user=user,
@@ -1,4 +1,5 @@
 from collections.abc import Mapping
+from decimal import Decimal
 from typing import Any

 from pydantic import BaseModel
@@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter:
            return value
        if isinstance(value, (bool, int, str, float)):
            return value
+        if isinstance(value, Decimal):
+            # Convert Decimal to float for JSON serialization
+            return float(value)
        if isinstance(value, Segment):
            return self._to_json_encodable_recursive(value.value)
        if isinstance(value, File):
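The rationale for the new `Decimal` branch: the stdlib JSON encoder rejects `Decimal` values, so converting to `float` first keeps runtime variables serializable. A small standalone illustration (not part of the diff):

```python
import json
from decimal import Decimal

value = Decimal("9.99")
# json.dumps(value) raises TypeError: Object of type Decimal is not JSON serializable
print(json.dumps(float(value)))  # 9.99 (with the usual binary-float caveats)
```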
@@ -32,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then
  exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
    --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
+    -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage}

elif [[ "${MODE}" == "beat" ]]; then
  exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
@@ -0,0 +1,136 @@
"""
Celery tasks for asynchronous workflow execution storage operations.

These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""

import json
import logging

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom

logger = logging.getLogger(__name__)


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
    self,
    execution_data: dict,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save or update a workflow execution to the database.

    Args:
        execution_data: Serialized WorkflowExecution data
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        triggered_from: Source of the execution trigger
        creator_user_id: ID of the user who created the execution
        creator_user_role: Role of the user who created the execution

    Returns:
        True if successful, False otherwise
    """
    try:
        # Create a new session for this task
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        with session_factory() as session:
            # Deserialize execution data
            execution = WorkflowExecution.model_validate(execution_data)

            # Check if workflow run already exists
            existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))

            if existing_run:
                # Update existing workflow run
                _update_workflow_run_from_execution(existing_run, execution)
                logger.debug("Updated existing workflow run: %s", execution.id_)
            else:
                # Create new workflow run
                workflow_run = _create_workflow_run_from_execution(
                    execution=execution,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    triggered_from=WorkflowRunTriggeredFrom(triggered_from),
                    creator_user_id=creator_user_id,
                    creator_user_role=CreatorUserRole(creator_user_role),
                )
                session.add(workflow_run)
                logger.debug("Created new workflow run: %s", execution.id_)

            session.commit()
            return True
    except Exception as e:
        logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))


def _create_workflow_run_from_execution(
    execution: WorkflowExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowRunTriggeredFrom,
    creator_user_id: str,
    creator_user_role: CreatorUserRole,
) -> WorkflowRun:
    """
    Create a WorkflowRun database model from a WorkflowExecution domain entity.
    """
    workflow_run = WorkflowRun()
    workflow_run.id = execution.id_
    workflow_run.tenant_id = tenant_id
    workflow_run.app_id = app_id
    workflow_run.workflow_id = execution.workflow_id
    workflow_run.type = execution.workflow_type.value
    workflow_run.triggered_from = triggered_from.value
    workflow_run.version = execution.workflow_version
    json_converter = WorkflowRuntimeTypeConverter()
    workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
    workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
    workflow_run.status = execution.status.value
    workflow_run.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    workflow_run.created_by_role = creator_user_role.value
    workflow_run.created_by = creator_user_id
    workflow_run.created_at = execution.started_at
    workflow_run.finished_at = execution.finished_at
    return workflow_run


def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
    """
    Update a WorkflowRun database model from a WorkflowExecution domain entity.
    """
    json_converter = WorkflowRuntimeTypeConverter()
    workflow_run.status = execution.status.value
    workflow_run.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    workflow_run.finished_at = execution.finished_at
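The task above relies on Pydantic's dict round-trip to move the domain entity across the Celery broker; a minimal sketch of that contract, assuming `WorkflowExecution` is the Pydantic model used throughout this diff (`execution` is a hypothetical instance):

```python
# Producer side (repository): serialize the entity to a plain dict for the broker.
data = execution.model_dump()

# Worker side (task): rebuild the domain entity before writing to the database.
restored = WorkflowExecution.model_validate(data)
assert restored.id_ == execution.id_
```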
@@ -0,0 +1,171 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.

These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""

import json
import logging

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
    self,
    execution_data: dict,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save or update a workflow node execution to the database.

    Args:
        execution_data: Serialized WorkflowNodeExecution data
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        triggered_from: Source of the execution trigger
        creator_user_id: ID of the user who created the execution
        creator_user_role: Role of the user who created the execution

    Returns:
        True if successful, False otherwise
    """
    try:
        # Create a new session for this task
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
        with session_factory() as session:
            # Deserialize execution data
            execution = WorkflowNodeExecution.model_validate(execution_data)

            # Check if node execution already exists
            existing_execution = session.scalar(
                select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
            )

            if existing_execution:
                # Update existing node execution
                _update_node_execution_from_domain(existing_execution, execution)
                logger.debug("Updated existing workflow node execution: %s", execution.id)
            else:
                # Create new node execution
                node_execution = _create_node_execution_from_domain(
                    execution=execution,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
                    creator_user_id=creator_user_id,
                    creator_user_role=CreatorUserRole(creator_user_role),
                )
                session.add(node_execution)
                logger.debug("Created new workflow node execution: %s", execution.id)

            session.commit()
            return True
    except Exception as e:
        logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))


def _create_node_execution_from_domain(
    execution: WorkflowNodeExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowNodeExecutionTriggeredFrom,
    creator_user_id: str,
    creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
    """
    Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """
    node_execution = WorkflowNodeExecutionModel()
    node_execution.id = execution.id
    node_execution.tenant_id = tenant_id
    node_execution.app_id = app_id
    node_execution.workflow_id = execution.workflow_id
    node_execution.triggered_from = triggered_from.value
    node_execution.workflow_run_id = execution.workflow_execution_id
    node_execution.index = execution.index
    node_execution.predecessor_node_id = execution.predecessor_node_id
    node_execution.node_id = execution.node_id
    node_execution.node_type = execution.node_type.value
    node_execution.title = execution.title
    node_execution.node_execution_id = execution.node_execution_id

    # Serialize complex data as JSON
    json_converter = WorkflowRuntimeTypeConverter()
    node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
    node_execution.process_data = (
        json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
    )
    node_execution.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )

    # Convert metadata enum keys to strings for JSON serialization
    if execution.metadata:
        metadata_for_json = {
            key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
        }
        node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
    else:
        node_execution.execution_metadata = "{}"

    node_execution.status = execution.status.value
    node_execution.error = execution.error
    node_execution.elapsed_time = execution.elapsed_time
    node_execution.created_by_role = creator_user_role.value
    node_execution.created_by = creator_user_id
    node_execution.created_at = execution.created_at
    node_execution.finished_at = execution.finished_at

    return node_execution


def _update_node_execution_from_domain(
    node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
    """
    Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """
    # Update serialized data
    json_converter = WorkflowRuntimeTypeConverter()
    node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
    node_execution.process_data = (
        json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
    )
    node_execution.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )

    # Convert metadata enum keys to strings for JSON serialization
    if execution.metadata:
        metadata_for_json = {
            key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
        }
        node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
    else:
        node_execution.execution_metadata = "{}"

    # Update other fields
    node_execution.status = execution.status.value
    node_execution.error = execution.error
    node_execution.elapsed_time = execution.elapsed_time
    node_execution.finished_at = execution.finished_at
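The enum-key normalization in both helpers exists because `json.dumps` only accepts str/int/float/bool/None as object keys, so enum-keyed metadata would otherwise fail to serialize. A standalone illustration (the `MetaKey` enum is hypothetical):

```python
import json
from enum import Enum

class MetaKey(Enum):
    TOTAL_TOKENS = "total_tokens"

metadata = {MetaKey.TOTAL_TOKENS: 42}
# json.dumps(metadata) raises TypeError: keys must be str, int, float, bool or None
safe = {k.value if hasattr(k, "value") else str(k): v for k, v in metadata.items()}
print(json.dumps(safe))  # {"total_tokens": 42}
```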
@@ -0,0 +1,247 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.

These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""

from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest

from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Create a real sessionmaker with in-memory SQLite for testing
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = Mock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = Mock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_execution():
    """Sample WorkflowExecution for testing."""
    return WorkflowExecution.new(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input1": "value1"},
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )


class TestCeleryWorkflowExecutionRepository:
    """Test cases for CeleryWorkflowExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_basic_functionality(self, mock_session_factory, mock_account):
        """Test repository initialization basic functionality."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )

        # Verify basic initialization
        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == "test-app"
        assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # Create a mock Account with no tenant_id
        user = Mock(spec=Account)
        user.current_tenant_id = None
        user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
        """Test that save operation queues a Celery task without tracking."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        repo.save(sample_workflow_execution)

        # Verify Celery task was queued with correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["execution_data"] == sample_workflow_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
        assert call_args["creator_user_id"] == mock_account.id

        # Verify no task tracking occurs (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_execution)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_operation_fire_and_forget(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation works in fire-and-forget mode."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Test that save doesn't block or maintain state
        repo.save(sample_workflow_execution)

        # Verify no pending saves are tracked (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
        """Test multiple save operations work correctly."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Create multiple executions
        exec1 = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input1": "value1"},
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )
        exec2 = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input2": "value2"},
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )

        # Save both executions
        repo.save(exec1)
        repo.save(exec2)

        # Should work without issues and not maintain state (no _pending_saves attribute)
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
        """Test save operation with different user types."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        execution = WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs={"input1": "value1"},
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )

        repo.save(execution)

        # Verify task was called with EndUser context
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["tenant_id"] == mock_end_user.tenant_id
        assert call_args["creator_user_id"] == mock_end_user.id
@@ -0,0 +1,349 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.

These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""

from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest

from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Mock SQLAlchemy session factory."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Create a real sessionmaker with in-memory SQLite for testing
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = Mock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = Mock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_node_execution():
    """Sample WorkflowNodeExecution for testing."""
    return WorkflowNodeExecution(
        id=str(uuid4()),
        node_execution_id=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_execution_id=str(uuid4()),
        index=1,
        node_id="test_node",
        node_type=NodeType.START,
        title="Test Node",
        inputs={"input1": "value1"},
        status=WorkflowNodeExecutionStatus.RUNNING,
        created_at=datetime.now(UTC).replace(tzinfo=None),
    )


class TestCeleryWorkflowNodeExecutionRepository:
    """Test cases for CeleryWorkflowNodeExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
        """Test repository initialization with cache properly initialized."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        assert repo._execution_cache == {}
        assert repo._workflow_execution_mapping == {}

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # Create a mock Account with no tenant_id
        user = Mock(spec=Account)
        user.current_tenant_id = None
        user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowNodeExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_caches_and_queues_celery_task(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation caches execution and queues a Celery task."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        repo.save(sample_workflow_node_execution)

        # Verify Celery task was queued with correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]
        assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        assert call_args["creator_user_id"] == mock_account.id

        # Verify execution is cached
        assert sample_workflow_node_execution.id in repo._execution_cache
        assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution

        # Verify workflow execution mapping is updated
        assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
        assert (
            sample_workflow_node_execution.id
            in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
        )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_node_execution)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_get_by_workflow_run_from_cache(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that get_by_workflow_run retrieves executions from cache."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Save execution to cache first
        repo.save(sample_workflow_node_execution)

        workflow_run_id = sample_workflow_node_execution.workflow_execution_id
        order_config = OrderConfig(order_by=["index"], order_direction="asc")

        result = repo.get_by_workflow_run(workflow_run_id, order_config)

        # Verify results were retrieved from cache
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id
        assert result[0] is sample_workflow_node_execution

    def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account):
        """Test get_by_workflow_run without order configuration."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        result = repo.get_by_workflow_run("workflow-run-id")

        # Should return empty list since nothing in cache
        assert len(result) == 0

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
        """Test cache operations work correctly."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Test saving to cache
        repo.save(sample_workflow_node_execution)

        # Verify cache contains the execution
        assert sample_workflow_node_execution.id in repo._execution_cache

        # Test retrieving from cache
        result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id)
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
        """Test multiple executions for the same workflow."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Create multiple executions for the same workflow
        workflow_run_id = str(uuid4())
        exec1 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=1,
            node_id="node1",
            node_type=NodeType.START,
            title="Node 1",
            inputs={"input1": "value1"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        exec2 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=2,
            node_id="node2",
            node_type=NodeType.LLM,
            title="Node 2",
            inputs={"input2": "value2"},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )

        # Save both executions
        repo.save(exec1)
        repo.save(exec2)

        # Verify both are cached and mapped
        assert len(repo._execution_cache) == 2
        assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2

        # Test retrieval
        result = repo.get_by_workflow_run(workflow_run_id)
        assert len(result) == 2

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
        """Test ordering functionality works correctly."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Create executions with different indices
        workflow_run_id = str(uuid4())
        exec1 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=2,
            node_id="node2",
            node_type=NodeType.START,
            title="Node 2",
            inputs={},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )
        exec2 = WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_run_id,
            index=1,
            node_id="node1",
            node_type=NodeType.LLM,
            title="Node 1",
            inputs={},
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=datetime.now(UTC).replace(tzinfo=None),
        )

        # Save in random order
        repo.save(exec1)
        repo.save(exec2)

        # Test ascending order
        order_config = OrderConfig(order_by=["index"], order_direction="asc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)

        assert len(result) == 2
        assert result[0].index == 1
        assert result[1].index == 2

        # Test descending order
        order_config = OrderConfig(order_by=["index"], order_direction="desc")
        result = repo.get_by_workflow_run(workflow_run_id, order_config)

        assert len(result) == 2
        assert result[0].index == 2
        assert result[1].index == 1
| @@ -59,7 +59,7 @@ class TestRepositoryFactory: | |||
| def get_by_id(self): | |||
| pass | |||
| # Create a mock interface with the same methods | |||
| # Create a mock interface class | |||
| class MockInterface: | |||
| def save(self): | |||
| pass | |||
| @@ -67,20 +67,20 @@ class TestRepositoryFactory: | |||
| def get_by_id(self): | |||
| pass | |||
| # Should not raise an exception | |||
| # Should not raise an exception when all methods are present | |||
| DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) | |||
| def test_validate_repository_interface_missing_methods(self): | |||
| """Test interface validation with missing methods.""" | |||
| # Create a mock class that doesn't implement all required methods | |||
| # Create a mock class that's missing required methods | |||
| class IncompleteRepository: | |||
| def save(self): | |||
| pass | |||
| # Missing get_by_id method | |||
| # Create a mock interface with required methods | |||
| # Create a mock interface that requires both methods | |||
| class MockInterface: | |||
| def save(self): | |||
| pass | |||
| @@ -88,57 +88,39 @@ class TestRepositoryFactory: | |||
| def get_by_id(self): | |||
| pass | |||
| def missing_method(self): | |||
| pass | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) | |||
| assert "does not implement required methods" in str(exc_info.value) | |||
| assert "get_by_id" in str(exc_info.value) | |||
| def test_validate_constructor_signature_success(self): | |||
| """Test successful constructor signature validation.""" | |||
| def test_validate_repository_interface_with_private_methods(self): | |||
| """Test that private methods are ignored during interface validation.""" | |||
| class MockRepository: | |||
| def __init__(self, session_factory, user, app_id, triggered_from): | |||
| def save(self): | |||
| pass | |||
| # Should not raise an exception | |||
| DifyCoreRepositoryFactory._validate_constructor_signature( | |||
| MockRepository, ["session_factory", "user", "app_id", "triggered_from"] | |||
| ) | |||
| def test_validate_constructor_signature_missing_params(self): | |||
| """Test constructor validation with missing parameters.""" | |||
| class IncompleteRepository: | |||
| def __init__(self, session_factory, user): | |||
| # Missing app_id and triggered_from parameters | |||
| def _private_method(self): | |||
| pass | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory._validate_constructor_signature( | |||
| IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] | |||
| ) | |||
| assert "does not accept required parameters" in str(exc_info.value) | |||
| assert "app_id" in str(exc_info.value) | |||
| assert "triggered_from" in str(exc_info.value) | |||
| def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture): | |||
| """Test constructor validation when inspection fails.""" | |||
| # Mock inspect.signature to raise an exception | |||
| mocker.patch("inspect.signature", side_effect=Exception("Inspection failed")) | |||
| # Create a mock interface with private methods | |||
| class MockInterface: | |||
| def save(self): | |||
| pass | |||
| class MockRepository: | |||
| def __init__(self, session_factory): | |||
| def _private_method(self): | |||
| pass | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) | |||
| assert "Failed to validate constructor signature" in str(exc_info.value) | |||
| # Should not raise exception - private methods should be ignored | |||
| DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) | |||
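Taken together, the validation tests pin down the contract of `_validate_repository_interface`: every public interface method must exist on the repository, and names starting with an underscore are skipped. A sketch consistent with that contract (not the factory's actual code):

```python
from core.repositories.factory import RepositoryImportError


def validate_repository_interface_sketch(repository_class, interface_class):
    # Collect the interface's public callables; private names are ignored,
    # as the private-methods test above requires.
    required = {
        name
        for name in dir(interface_class)
        if not name.startswith("_") and callable(getattr(interface_class, name))
    }
    missing = {name for name in required if not callable(getattr(repository_class, name, None))}
    if missing:
        # Message wording matches the assertions in the missing-methods test.
        raise RepositoryImportError(
            f"{repository_class.__name__} does not implement required methods: {sorted(missing)}"
        )
```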
| @patch("core.repositories.factory.dify_config") | |||
| def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture): | |||
| """Test successful creation of WorkflowExecutionRepository.""" | |||
| def test_create_workflow_execution_repository_success(self, mock_config): | |||
| """Test successful WorkflowExecutionRepository creation.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| # Create mock dependencies | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| @@ -146,7 +128,7 @@ class TestRepositoryFactory: | |||
| app_id = "test-app-id" | |||
| triggered_from = WorkflowRunTriggeredFrom.APP_RUN | |||
| # Mock the imported class to be a valid repository | |||
| # Create mock repository class and instance | |||
| mock_repository_class = MagicMock() | |||
| mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) | |||
| mock_repository_class.return_value = mock_repository_instance | |||
| @@ -155,7 +137,6 @@ class TestRepositoryFactory: | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), | |||
| ): | |||
| result = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| @@ -177,7 +158,7 @@ class TestRepositoryFactory: | |||
| def test_create_workflow_execution_repository_import_error(self, mock_config): | |||
| """Test WorkflowExecutionRepository creation with import error.""" | |||
| # Setup mock configuration with invalid class path | |||
| mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" | |||
| mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=Account) | |||
| @@ -195,45 +176,46 @@ class TestRepositoryFactory: | |||
| def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): | |||
| """Test WorkflowExecutionRepository creation with validation error.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=Account) | |||
| # Mock import to succeed but validation to fail | |||
| # Mock the import to succeed but validation to fail | |||
| mock_repository_class = MagicMock() | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object( | |||
| DifyCoreRepositoryFactory, | |||
| "_validate_repository_interface", | |||
| side_effect=RepositoryImportError("Interface validation failed"), | |||
| ), | |||
| ): | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowRunTriggeredFrom.APP_RUN, | |||
| ) | |||
| assert "Interface validation failed" in str(exc_info.value) | |||
| mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class) | |||
| mocker.patch.object( | |||
| DifyCoreRepositoryFactory, | |||
| "_validate_repository_interface", | |||
| side_effect=RepositoryImportError("Interface validation failed"), | |||
| ) | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowRunTriggeredFrom.APP_RUN, | |||
| ) | |||
| assert "Interface validation failed" in str(exc_info.value) | |||
| @patch("core.repositories.factory.dify_config") | |||
| def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture): | |||
| def test_create_workflow_execution_repository_instantiation_error(self, mock_config): | |||
| """Test WorkflowExecutionRepository creation with instantiation error.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=Account) | |||
| # Mock import and validation to succeed but instantiation to fail | |||
| mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) | |||
| # Create a mock repository class that raises an exception on instantiation | |||
| mock_repository_class = MagicMock() | |||
| mock_repository_class.side_effect = Exception("Instantiation failed") | |||
| # Mock the validation methods to succeed | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), | |||
| ): | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| @@ -245,18 +227,18 @@ class TestRepositoryFactory: | |||
| assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) | |||
| @patch("core.repositories.factory.dify_config") | |||
| def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture): | |||
| """Test successful creation of WorkflowNodeExecutionRepository.""" | |||
| def test_create_workflow_node_execution_repository_success(self, mock_config): | |||
| """Test successful WorkflowNodeExecutionRepository creation.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| # Create mock dependencies | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=EndUser) | |||
| app_id = "test-app-id" | |||
| triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN | |||
| triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP | |||
| # Mock the imported class to be a valid repository | |||
| # Create mock repository class and instance | |||
| mock_repository_class = MagicMock() | |||
| mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) | |||
| mock_repository_class.return_value = mock_repository_instance | |||
| @@ -265,7 +247,6 @@ class TestRepositoryFactory: | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), | |||
| ): | |||
| result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| @@ -287,7 +268,7 @@ class TestRepositoryFactory: | |||
| def test_create_workflow_node_execution_repository_import_error(self, mock_config): | |||
| """Test WorkflowNodeExecutionRepository creation with import error.""" | |||
| # Setup mock configuration with invalid class path | |||
| mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" | |||
| mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=EndUser) | |||
| @@ -297,159 +278,104 @@ class TestRepositoryFactory: | |||
| session_factory=mock_session_factory, | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| assert "Cannot import repository class" in str(exc_info.value) | |||
| def test_repository_import_error_exception(self): | |||
| """Test RepositoryImportError exception.""" | |||
| error_message = "Test error message" | |||
| exception = RepositoryImportError(error_message) | |||
| assert str(exception) == error_message | |||
| assert isinstance(exception, Exception) | |||
| @patch("core.repositories.factory.dify_config") | |||
| def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture): | |||
| """Test repository creation with Engine instead of sessionmaker.""" | |||
| def test_create_workflow_node_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): | |||
| """Test WorkflowNodeExecutionRepository creation with validation error.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| # Create mock dependencies with Engine instead of sessionmaker | |||
| mock_engine = MagicMock(spec=Engine) | |||
| mock_user = MagicMock(spec=Account) | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=EndUser) | |||
| # Mock the imported class to be a valid repository | |||
| # Mock the import to succeed but validation to fail | |||
| mock_repository_class = MagicMock() | |||
| mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) | |||
| mock_repository_class.return_value = mock_repository_instance | |||
| # Mock the validation methods | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), | |||
| ): | |||
| result = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=mock_engine, # Using Engine instead of sessionmaker | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowRunTriggeredFrom.APP_RUN, | |||
| ) | |||
| mocker.patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class) | |||
| mocker.patch.object( | |||
| DifyCoreRepositoryFactory, | |||
| "_validate_repository_interface", | |||
| side_effect=RepositoryImportError("Interface validation failed"), | |||
| ) | |||
| # Verify the repository was created with the Engine | |||
| mock_repository_class.assert_called_once_with( | |||
| session_factory=mock_engine, | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowRunTriggeredFrom.APP_RUN, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| assert result is mock_repository_instance | |||
| assert "Interface validation failed" in str(exc_info.value) | |||
| @patch("core.repositories.factory.dify_config") | |||
| def test_create_workflow_node_execution_repository_validation_error(self, mock_config): | |||
| """Test WorkflowNodeExecutionRepository creation with validation error.""" | |||
| def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): | |||
| """Test WorkflowNodeExecutionRepository creation with instantiation error.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=EndUser) | |||
| # Mock import to succeed but validation to fail | |||
| # Create a mock repository class that raises an exception on instantiation | |||
| mock_repository_class = MagicMock() | |||
| mock_repository_class.side_effect = Exception("Instantiation failed") | |||
| # Mock the validation methods to succeed | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object( | |||
| DifyCoreRepositoryFactory, | |||
| "_validate_repository_interface", | |||
| side_effect=RepositoryImportError("Interface validation failed"), | |||
| ), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), | |||
| ): | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, | |||
| ) | |||
| assert "Interface validation failed" in str(exc_info.value) | |||
| assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) | |||
| def test_repository_import_error_exception(self): | |||
| """Test RepositoryImportError exception handling.""" | |||
| error_message = "Custom error message" | |||
| error = RepositoryImportError(error_message) | |||
| assert str(error) == error_message | |||
| @patch("core.repositories.factory.dify_config") | |||
| def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): | |||
| """Test WorkflowNodeExecutionRepository creation with instantiation error.""" | |||
| def test_create_with_engine_instead_of_sessionmaker(self, mock_config): | |||
| """Test repository creation with Engine instead of sessionmaker.""" | |||
| # Setup mock configuration | |||
| mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" | |||
| mock_session_factory = MagicMock(spec=sessionmaker) | |||
| mock_user = MagicMock(spec=EndUser) | |||
| # Create mock dependencies using Engine instead of sessionmaker | |||
| mock_engine = MagicMock(spec=Engine) | |||
| mock_user = MagicMock(spec=Account) | |||
| app_id = "test-app-id" | |||
| triggered_from = WorkflowRunTriggeredFrom.APP_RUN | |||
| # Mock import and validation to succeed but instantiation to fail | |||
| mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) | |||
| # Create mock repository class and instance | |||
| mock_repository_class = MagicMock() | |||
| mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) | |||
| mock_repository_class.return_value = mock_repository_instance | |||
| # Mock the validation methods | |||
| with ( | |||
| patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), | |||
| patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), | |||
| ): | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory.create_workflow_node_execution_repository( | |||
| session_factory=mock_session_factory, | |||
| user=mock_user, | |||
| app_id="test-app-id", | |||
| triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, | |||
| ) | |||
| assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) | |||
| def test_validate_repository_interface_with_private_methods(self): | |||
| """Test interface validation ignores private methods.""" | |||
| # Create a mock class with private methods | |||
| class MockRepository: | |||
| def save(self): | |||
| pass | |||
| def get_by_id(self): | |||
| pass | |||
| def _private_method(self): | |||
| pass | |||
| # Create a mock interface with private methods | |||
| class MockInterface: | |||
| def save(self): | |||
| pass | |||
| def get_by_id(self): | |||
| pass | |||
| def _private_method(self): | |||
| pass | |||
| # Should not raise an exception (private methods are ignored) | |||
| DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) | |||
| def test_validate_constructor_signature_with_extra_params(self): | |||
| """Test constructor validation with extra parameters (should pass).""" | |||
| class MockRepository: | |||
| def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None): | |||
| pass | |||
| # Should not raise an exception (extra parameters are allowed) | |||
| DifyCoreRepositoryFactory._validate_constructor_signature( | |||
| MockRepository, ["session_factory", "user", "app_id", "triggered_from"] | |||
| ) | |||
| def test_validate_constructor_signature_with_kwargs(self): | |||
| """Test constructor validation with **kwargs (current implementation doesn't support this).""" | |||
| class MockRepository: | |||
| def __init__(self, session_factory, user, **kwargs): | |||
| pass | |||
| result = DifyCoreRepositoryFactory.create_workflow_execution_repository( | |||
| session_factory=mock_engine, # Using Engine instead of sessionmaker | |||
| user=mock_user, | |||
| app_id=app_id, | |||
| triggered_from=triggered_from, | |||
| ) | |||
| # Current implementation doesn't handle **kwargs, so this should raise an exception | |||
| with pytest.raises(RepositoryImportError) as exc_info: | |||
| DifyCoreRepositoryFactory._validate_constructor_signature( | |||
| MockRepository, ["session_factory", "user", "app_id", "triggered_from"] | |||
| # Verify the repository was created with correct parameters | |||
| mock_repository_class.assert_called_once_with( | |||
| session_factory=mock_engine, | |||
| user=mock_user, | |||
| app_id=app_id, | |||
| triggered_from=triggered_from, | |||
| ) | |||
| assert "does not accept required parameters" in str(exc_info.value) | |||
| assert "app_id" in str(exc_info.value) | |||
| assert "triggered_from" in str(exc_info.value) | |||
| assert result is mock_repository_instance | |||
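Since this test verifies the `Engine` is forwarded untouched, any normalization to a session factory presumably happens inside the repository implementations themselves. A sketch of that assumed handling:

```python
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import sessionmaker


def to_session_factory(session_factory: sessionmaker | Engine) -> sessionmaker:
    # Accept either form, wrapping a bare Engine in a sessionmaker.
    if isinstance(session_factory, Engine):
        return sessionmaker(bind=session_factory)
    return session_factory


# Example: both calls yield a usable session factory.
engine = create_engine("sqlite://")
factory = to_session_factory(engine)
same_factory = to_session_factory(sessionmaker(bind=engine))
```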
| @@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.." | |||
| uv --directory api run \ | |||
| celery -A app.celery worker \ | |||
| -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion | |||
| -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage | |||
| @@ -861,17 +861,23 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms | |||
| # Repository configuration | |||
| # Core workflow execution repository implementation | |||
| # Options: | |||
| # - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) | |||
| # - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository | |||
| CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository | |||
| # Core workflow node execution repository implementation | |||
| # Options: | |||
| # - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) | |||
| # - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository | |||
| CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository | |||
| # API workflow node execution repository implementation | |||
| API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository | |||
| # API workflow run repository implementation | |||
| API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository | |||
| # API workflow node execution repository implementation | |||
| API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository | |||
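For context, a hedged sketch of how these settings are consumed at runtime, based on the factory calls exercised in the tests above; the enum import path is an assumption:

```python
from core.repositories.factory import DifyCoreRepositoryFactory
from models.workflow import WorkflowRunTriggeredFrom  # import path assumed


def build_execution_repository(session_factory, user, app_id: str):
    # Instantiates whichever class CORE_WORKFLOW_EXECUTION_REPOSITORY names,
    # e.g. the SQLAlchemy default or the Celery-backed implementation.
    return DifyCoreRepositoryFactory.create_workflow_execution_repository(
        session_factory=session_factory,  # sessionmaker or Engine, per the tests
        user=user,
        app_id=app_id,
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
```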
| # HTTP request node in workflow configuration | |||
| HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 | |||
| HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 | |||
| @@ -390,8 +390,8 @@ x-shared-env: &shared-api-worker-env | |||
| WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} | |||
| CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository} | |||
| CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} | |||
| API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} | |||
| API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} | |||
| API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} | |||
| HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} | |||
| HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} | |||
| HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} | |||