| @@ -3,6 +3,7 @@ import logging | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, fields, marshal_with, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from sqlalchemy import exists, select | |||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||
| from controllers.console import api | |||
| @@ -94,21 +95,18 @@ class ChatMessageListApi(Resource): | |||
| .all() | |||
| ) | |||
| has_more = False | |||
| if len(history_messages) == args["limit"]: | |||
| current_page_first_message = history_messages[-1] | |||
| rest_count = ( | |||
| db.session.query(Message) | |||
| .where( | |||
| has_more = db.session.scalar( | |||
| select( | |||
| exists().where( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| if rest_count > 0: | |||
| has_more = True | |||
| ) | |||
| history_messages = list(reversed(history_messages)) | |||
| @@ -8,20 +8,21 @@ from uuid import UUID | |||
| import numpy as np | |||
| import pytz | |||
| from flask_login import current_user | |||
| from core.file import File, FileTransferMethod, FileType | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| logger = logging.getLogger(__name__) | |||
| def safe_json_value(v): | |||
| if isinstance(v, datetime): | |||
| tz_name = getattr(current_user, "timezone", None) if current_user is not None else None | |||
| if not tz_name: | |||
| tz_name = "UTC" | |||
| tz_name = "UTC" | |||
| if isinstance(current_user, Account) and current_user.timezone is not None: | |||
| tz_name = current_user.timezone | |||
| return v.astimezone(pytz.timezone(tz_name)).isoformat() | |||
| elif isinstance(v, date): | |||
| return v.isoformat() | |||
| @@ -46,7 +47,7 @@ def safe_json_value(v): | |||
| return v | |||
| def safe_json_dict(d): | |||
| def safe_json_dict(d: dict): | |||
| if not isinstance(d, dict): | |||
| raise TypeError("safe_json_dict() expects a dictionary (dict) as input") | |||
| return {k: safe_json_value(v) for k, v in d.items()} | |||
| @@ -3,8 +3,6 @@ import logging | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional, cast | |||
| from flask_login import current_user | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| @@ -17,8 +15,8 @@ from core.tools.entities.tool_entities import ( | |||
| from core.tools.errors import ToolInvokeError | |||
| from extensions.ext_database import db | |||
| from factories.file_factory import build_from_mapping | |||
| from models.account import Account | |||
| from models.model import App, EndUser | |||
| from libs.login import current_user | |||
| from models.model import App | |||
| from models.workflow import Workflow | |||
| logger = logging.getLogger(__name__) | |||
| @@ -79,11 +77,11 @@ class WorkflowTool(Tool): | |||
| generator = WorkflowAppGenerator() | |||
| assert self.runtime is not None | |||
| assert self.runtime.invoke_from is not None | |||
| assert current_user is not None | |||
| result = generator.generate( | |||
| app_model=app, | |||
| workflow=workflow, | |||
| user=cast("Account | EndUser", current_user), | |||
| user=current_user, | |||
| args={"inputs": tool_parameters, "files": files}, | |||
| invoke_from=self.runtime.invoke_from, | |||
| streaming=False, | |||
| @@ -66,6 +66,7 @@ class NodeExecutionType(StrEnum): | |||
| RESPONSE = "response" # Response nodes that stream outputs (Answer, End) | |||
| BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier) | |||
| CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph) | |||
| ROOT = "root" # Nodes that can serve as execution entry points | |||
| class ErrorStrategy(StrEnum): | |||
| @@ -1,9 +1,9 @@ | |||
| import logging | |||
| from collections import defaultdict | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, Protocol, cast | |||
| from typing import Any, Protocol, cast | |||
| from core.workflow.enums import NodeType | |||
| from core.workflow.enums import NodeExecutionType, NodeState, NodeType | |||
| from core.workflow.nodes.base.node import Node | |||
| from .edge import Edge | |||
| @@ -36,10 +36,10 @@ class Graph: | |||
| def __init__( | |||
| self, | |||
| *, | |||
| nodes: Optional[dict[str, Node]] = None, | |||
| edges: Optional[dict[str, Edge]] = None, | |||
| in_edges: Optional[dict[str, list[str]]] = None, | |||
| out_edges: Optional[dict[str, list[str]]] = None, | |||
| nodes: dict[str, Node] | None = None, | |||
| edges: dict[str, Edge] | None = None, | |||
| in_edges: dict[str, list[str]] | None = None, | |||
| out_edges: dict[str, list[str]] | None = None, | |||
| root_node: Node, | |||
| ): | |||
| """ | |||
| @@ -81,7 +81,7 @@ class Graph: | |||
| cls, | |||
| node_configs_map: dict[str, dict[str, Any]], | |||
| edge_configs: list[dict[str, Any]], | |||
| root_node_id: Optional[str] = None, | |||
| root_node_id: str | None = None, | |||
| ) -> str: | |||
| """ | |||
| Find the root node ID if not specified. | |||
| @@ -186,13 +186,79 @@ class Graph: | |||
| return nodes | |||
| @classmethod | |||
| def _mark_inactive_root_branches( | |||
| cls, | |||
| nodes: dict[str, Node], | |||
| edges: dict[str, Edge], | |||
| in_edges: dict[str, list[str]], | |||
| out_edges: dict[str, list[str]], | |||
| active_root_id: str, | |||
| ) -> None: | |||
| """ | |||
| Mark nodes and edges from inactive root branches as skipped. | |||
| Algorithm: | |||
| 1. Mark inactive root nodes as skipped | |||
| 2. For skipped nodes, mark all their outgoing edges as skipped | |||
| 3. For each edge marked as skipped, check its target node: | |||
| - If ALL incoming edges are skipped, mark the node as skipped | |||
| - Otherwise, leave the node state unchanged | |||
| :param nodes: mapping of node ID to node instance | |||
| :param edges: mapping of edge ID to edge instance | |||
| :param in_edges: mapping of node ID to incoming edge IDs | |||
| :param out_edges: mapping of node ID to outgoing edge IDs | |||
| :param active_root_id: ID of the active root node | |||
| """ | |||
| # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) | |||
| top_level_roots: list[str] = [ | |||
| node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT | |||
| ] | |||
| # If there's only one root or the active root is not a top-level root, no marking needed | |||
| if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: | |||
| return | |||
| # Mark inactive root nodes as skipped | |||
| inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] | |||
| for root_id in inactive_roots: | |||
| if root_id in nodes: | |||
| nodes[root_id].state = NodeState.SKIPPED | |||
| # Recursively mark downstream nodes and edges | |||
| def mark_downstream(node_id: str) -> None: | |||
| """Recursively mark downstream nodes and edges as skipped.""" | |||
| if nodes[node_id].state != NodeState.SKIPPED: | |||
| return | |||
| # If this node is skipped, mark all its outgoing edges as skipped | |||
| out_edge_ids = out_edges.get(node_id, []) | |||
| for edge_id in out_edge_ids: | |||
| edge = edges[edge_id] | |||
| edge.state = NodeState.SKIPPED | |||
| # Check the target node of this edge | |||
| target_node = nodes[edge.head] | |||
| in_edge_ids = in_edges.get(target_node.id, []) | |||
| in_edge_states = [edges[eid].state for eid in in_edge_ids] | |||
| # If all incoming edges are skipped, mark the node as skipped | |||
| if all(state == NodeState.SKIPPED for state in in_edge_states): | |||
| target_node.state = NodeState.SKIPPED | |||
| # Recursively process downstream nodes | |||
| mark_downstream(target_node.id) | |||
| # Process each inactive root and its downstream nodes | |||
| for root_id in inactive_roots: | |||
| mark_downstream(root_id) | |||
| @classmethod | |||
| def init( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_factory: "NodeFactory", | |||
| root_node_id: Optional[str] = None, | |||
| root_node_id: str | None = None, | |||
| ) -> "Graph": | |||
| """ | |||
| Initialize graph | |||
| @@ -227,6 +293,9 @@ class Graph: | |||
| # Get root node instance | |||
| root_node = nodes[root_node_id] | |||
| # Mark inactive root branches as skipped | |||
| cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) | |||
| # Create and return the graph | |||
| return cls( | |||
| nodes=nodes, | |||
| @@ -6,10 +6,12 @@ within a single process. Each instance handles commands for one workflow executi | |||
| """ | |||
| from queue import Queue | |||
| from typing import final | |||
| from ..entities.commands import GraphEngineCommand | |||
| @final | |||
| class InMemoryChannel: | |||
| """ | |||
| In-memory command channel implementation using a thread-safe queue. | |||
| @@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue. | |||
| """ | |||
| import json | |||
| from typing import TYPE_CHECKING, Optional | |||
| from typing import TYPE_CHECKING, final | |||
| from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand | |||
| @@ -15,6 +15,7 @@ if TYPE_CHECKING: | |||
| from extensions.ext_redis import RedisClientWrapper | |||
| @final | |||
| class RedisChannel: | |||
| """ | |||
| Redis-based command channel implementation for distributed systems. | |||
| @@ -86,7 +87,7 @@ class RedisChannel: | |||
| pipe.expire(self._key, self._command_ttl) | |||
| pipe.execute() | |||
| def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]: | |||
| def _deserialize_command(self, data: dict) -> GraphEngineCommand | None: | |||
| """ | |||
| Deserialize a command from dictionary data. | |||
| @@ -3,6 +3,7 @@ Command handler implementations. | |||
| """ | |||
| import logging | |||
| from typing import final | |||
| from ..domain.graph_execution import GraphExecution | |||
| from ..entities.commands import AbortCommand, GraphEngineCommand | |||
| @@ -11,6 +12,7 @@ from .command_processor import CommandHandler | |||
| logger = logging.getLogger(__name__) | |||
| @final | |||
| class AbortCommandHandler(CommandHandler): | |||
| """Handles abort commands.""" | |||
| @@ -3,7 +3,7 @@ Main command processor for handling external commands. | |||
| """ | |||
| import logging | |||
| from typing import Protocol | |||
| from typing import Protocol, final | |||
| from ..domain.graph_execution import GraphExecution | |||
| from ..entities.commands import GraphEngineCommand | |||
| @@ -18,6 +18,7 @@ class CommandHandler(Protocol): | |||
| def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... | |||
| @final | |||
| class CommandProcessor: | |||
| """ | |||
| Processes external commands sent to the engine. | |||
| @@ -3,7 +3,6 @@ GraphExecution aggregate root managing the overall graph execution state. | |||
| """ | |||
| from dataclasses import dataclass, field | |||
| from typing import Optional | |||
| from .node_execution import NodeExecution | |||
| @@ -21,7 +20,7 @@ class GraphExecution: | |||
| started: bool = False | |||
| completed: bool = False | |||
| aborted: bool = False | |||
| error: Optional[Exception] = None | |||
| error: Exception | None = None | |||
| node_executions: dict[str, NodeExecution] = field(default_factory=dict) | |||
| def start(self) -> None: | |||
| @@ -3,7 +3,6 @@ NodeExecution entity representing a node's execution state. | |||
| """ | |||
| from dataclasses import dataclass | |||
| from typing import Optional | |||
| from core.workflow.enums import NodeState | |||
| @@ -20,8 +19,8 @@ class NodeExecution: | |||
| node_id: str | |||
| state: NodeState = NodeState.UNKNOWN | |||
| retry_count: int = 0 | |||
| execution_id: Optional[str] = None | |||
| error: Optional[str] = None | |||
| execution_id: str | None = None | |||
| error: str | None = None | |||
| def mark_started(self, execution_id: str) -> None: | |||
| """Mark the node as started with an execution ID.""" | |||
| @@ -6,7 +6,7 @@ instance to control its execution flow. | |||
| """ | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from typing import Any | |||
| from pydantic import BaseModel, Field | |||
| @@ -23,11 +23,11 @@ class GraphEngineCommand(BaseModel): | |||
| """Base class for all GraphEngine commands.""" | |||
| command_type: CommandType = Field(..., description="Type of command") | |||
| payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload") | |||
| payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") | |||
| class AbortCommand(GraphEngineCommand): | |||
| """Command to abort a running workflow execution.""" | |||
| command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") | |||
| reason: Optional[str] = Field(default=None, description="Optional reason for abort") | |||
| reason: str | None = Field(default=None, description="Optional reason for abort") | |||
| @@ -8,7 +8,6 @@ the Strategy pattern for clean separation of concerns. | |||
| from .abort_strategy import AbortStrategy | |||
| from .default_value_strategy import DefaultValueStrategy | |||
| from .error_handler import ErrorHandler | |||
| from .error_strategy import ErrorStrategy | |||
| from .fail_branch_strategy import FailBranchStrategy | |||
| from .retry_strategy import RetryStrategy | |||
| @@ -16,7 +15,6 @@ __all__ = [ | |||
| "AbortStrategy", | |||
| "DefaultValueStrategy", | |||
| "ErrorHandler", | |||
| "ErrorStrategy", | |||
| "FailBranchStrategy", | |||
| "RetryStrategy", | |||
| ] | |||
| @@ -3,7 +3,7 @@ Abort error strategy implementation. | |||
| """ | |||
| import logging | |||
| from typing import Optional | |||
| from typing import final | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent | |||
| @@ -11,6 +11,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent | |||
| logger = logging.getLogger(__name__) | |||
| @final | |||
| class AbortStrategy: | |||
| """ | |||
| Error strategy that aborts execution on failure. | |||
| @@ -19,7 +20,7 @@ class AbortStrategy: | |||
| It stops the entire graph execution when a node fails. | |||
| """ | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: | |||
| """ | |||
| Handle error by aborting execution. | |||
| @@ -2,7 +2,7 @@ | |||
| Default value error strategy implementation. | |||
| """ | |||
| from typing import Optional | |||
| from typing import final | |||
| from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.graph import Graph | |||
| @@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent | |||
| from core.workflow.node_events import NodeRunResult | |||
| @final | |||
| class DefaultValueStrategy: | |||
| """ | |||
| Error strategy that uses default values on failure. | |||
| @@ -18,7 +19,7 @@ class DefaultValueStrategy: | |||
| predefined default output values. | |||
| """ | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: | |||
| """ | |||
| Handle error by using default values. | |||
| @@ -2,7 +2,7 @@ | |||
| Main error handler that coordinates error strategies. | |||
| """ | |||
| from typing import TYPE_CHECKING, Optional | |||
| from typing import TYPE_CHECKING, final | |||
| from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum | |||
| from core.workflow.graph import Graph | |||
| @@ -17,6 +17,7 @@ if TYPE_CHECKING: | |||
| from ..domain import GraphExecution | |||
| @final | |||
| class ErrorHandler: | |||
| """ | |||
| Coordinates error handling strategies for node failures. | |||
| @@ -34,16 +35,16 @@ class ErrorHandler: | |||
| graph: The workflow graph | |||
| graph_execution: The graph execution state | |||
| """ | |||
| self.graph = graph | |||
| self.graph_execution = graph_execution | |||
| self._graph = graph | |||
| self._graph_execution = graph_execution | |||
| # Initialize strategies | |||
| self.abort_strategy = AbortStrategy() | |||
| self.retry_strategy = RetryStrategy() | |||
| self.fail_branch_strategy = FailBranchStrategy() | |||
| self.default_value_strategy = DefaultValueStrategy() | |||
| self._abort_strategy = AbortStrategy() | |||
| self._retry_strategy = RetryStrategy() | |||
| self._fail_branch_strategy = FailBranchStrategy() | |||
| self._default_value_strategy = DefaultValueStrategy() | |||
| def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]: | |||
| def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: | |||
| """ | |||
| Handle a node failure event. | |||
| @@ -56,14 +57,14 @@ class ErrorHandler: | |||
| Returns: | |||
| Optional new event to process, or None to abort | |||
| """ | |||
| node = self.graph.nodes[event.node_id] | |||
| node = self._graph.nodes[event.node_id] | |||
| # Get retry count from NodeExecution | |||
| node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) | |||
| retry_count = node_execution.retry_count | |||
| # First check if retry is configured and not exhausted | |||
| if node.retry and retry_count < node.retry_config.max_retries: | |||
| result = self.retry_strategy.handle_error(event, self.graph, retry_count) | |||
| result = self._retry_strategy.handle_error(event, self._graph, retry_count) | |||
| if result: | |||
| # Retry count will be incremented when NodeRunRetryEvent is handled | |||
| return result | |||
| @@ -71,12 +72,10 @@ class ErrorHandler: | |||
| # Apply configured error strategy | |||
| strategy = node.error_strategy | |||
| if strategy is None: | |||
| return self.abort_strategy.handle_error(event, self.graph, retry_count) | |||
| elif strategy == ErrorStrategyEnum.FAIL_BRANCH: | |||
| return self.fail_branch_strategy.handle_error(event, self.graph, retry_count) | |||
| elif strategy == ErrorStrategyEnum.DEFAULT_VALUE: | |||
| return self.default_value_strategy.handle_error(event, self.graph, retry_count) | |||
| else: | |||
| # Unknown strategy, default to abort | |||
| return self.abort_strategy.handle_error(event, self.graph, retry_count) | |||
| match strategy: | |||
| case None: | |||
| return self._abort_strategy.handle_error(event, self._graph, retry_count) | |||
| case ErrorStrategyEnum.FAIL_BRANCH: | |||
| return self._fail_branch_strategy.handle_error(event, self._graph, retry_count) | |||
| case ErrorStrategyEnum.DEFAULT_VALUE: | |||
| return self._default_value_strategy.handle_error(event, self._graph, retry_count) | |||
| @@ -2,7 +2,7 @@ | |||
| Fail branch error strategy implementation. | |||
| """ | |||
| from typing import Optional | |||
| from typing import final | |||
| from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.graph import Graph | |||
| @@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent | |||
| from core.workflow.node_events import NodeRunResult | |||
| @final | |||
| class FailBranchStrategy: | |||
| """ | |||
| Error strategy that continues execution via a fail branch. | |||
| @@ -18,7 +19,7 @@ class FailBranchStrategy: | |||
| through a designated fail-branch edge. | |||
| """ | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: | |||
| """ | |||
| Handle error by taking the fail branch. | |||
| @@ -3,12 +3,13 @@ Retry error strategy implementation. | |||
| """ | |||
| import time | |||
| from typing import Optional | |||
| from typing import final | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent | |||
| @final | |||
| class RetryStrategy: | |||
| """ | |||
| Error strategy that retries failed nodes. | |||
| @@ -17,7 +18,7 @@ class RetryStrategy: | |||
| maximum number of retries with configurable intervals. | |||
| """ | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: | |||
| """ | |||
| Handle error by retrying the node. | |||
| @@ -3,12 +3,92 @@ Event collector for buffering and managing events. | |||
| """ | |||
| import threading | |||
| from typing import final | |||
| from core.workflow.graph_events import GraphEngineEvent | |||
| from ..layers.base import Layer | |||
| @final | |||
| class ReadWriteLock: | |||
| """ | |||
| A read-write lock implementation that allows multiple concurrent readers | |||
| but only one writer at a time. | |||
| """ | |||
| def __init__(self) -> None: | |||
| self._read_ready = threading.Condition(threading.RLock()) | |||
| self._readers = 0 | |||
| def acquire_read(self) -> None: | |||
| """Acquire a read lock.""" | |||
| self._read_ready.acquire() | |||
| try: | |||
| self._readers += 1 | |||
| finally: | |||
| self._read_ready.release() | |||
| def release_read(self) -> None: | |||
| """Release a read lock.""" | |||
| self._read_ready.acquire() | |||
| try: | |||
| self._readers -= 1 | |||
| if self._readers == 0: | |||
| self._read_ready.notify_all() | |||
| finally: | |||
| self._read_ready.release() | |||
| def acquire_write(self) -> None: | |||
| """Acquire a write lock.""" | |||
| self._read_ready.acquire() | |||
| while self._readers > 0: | |||
| self._read_ready.wait() | |||
| def release_write(self) -> None: | |||
| """Release a write lock.""" | |||
| self._read_ready.release() | |||
| def read_lock(self) -> "ReadLockContext": | |||
| """Return a context manager for read locking.""" | |||
| return ReadLockContext(self) | |||
| def write_lock(self) -> "WriteLockContext": | |||
| """Return a context manager for write locking.""" | |||
| return WriteLockContext(self) | |||
| @final | |||
| class ReadLockContext: | |||
| """Context manager for read locks.""" | |||
| def __init__(self, lock: ReadWriteLock) -> None: | |||
| self._lock = lock | |||
| def __enter__(self) -> "ReadLockContext": | |||
| self._lock.acquire_read() | |||
| return self | |||
| def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: | |||
| self._lock.release_read() | |||
| @final | |||
| class WriteLockContext: | |||
| """Context manager for write locks.""" | |||
| def __init__(self, lock: ReadWriteLock) -> None: | |||
| self._lock = lock | |||
| def __enter__(self) -> "WriteLockContext": | |||
| self._lock.acquire_write() | |||
| return self | |||
| def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: | |||
| self._lock.release_write() | |||
| @final | |||
| class EventCollector: | |||
| """ | |||
| Collects and buffers events for later retrieval. | |||
| @@ -20,7 +100,7 @@ class EventCollector: | |||
| def __init__(self) -> None: | |||
| """Initialize the event collector.""" | |||
| self._events: list[GraphEngineEvent] = [] | |||
| self._lock = threading.Lock() | |||
| self._lock = ReadWriteLock() | |||
| self._layers: list[Layer] = [] | |||
| def set_layers(self, layers: list[Layer]) -> None: | |||
| @@ -39,7 +119,7 @@ class EventCollector: | |||
| Args: | |||
| event: The event to collect | |||
| """ | |||
| with self._lock: | |||
| with self._lock.write_lock(): | |||
| self._events.append(event) | |||
| self._notify_layers(event) | |||
| @@ -50,7 +130,7 @@ class EventCollector: | |||
| Returns: | |||
| List of collected events | |||
| """ | |||
| with self._lock: | |||
| with self._lock.read_lock(): | |||
| return list(self._events) | |||
| def get_new_events(self, start_index: int) -> list[GraphEngineEvent]: | |||
| @@ -63,7 +143,7 @@ class EventCollector: | |||
| Returns: | |||
| List of new events | |||
| """ | |||
| with self._lock: | |||
| with self._lock.read_lock(): | |||
| return list(self._events[start_index:]) | |||
| def event_count(self) -> int: | |||
| @@ -73,12 +153,12 @@ class EventCollector: | |||
| Returns: | |||
| Number of collected events | |||
| """ | |||
| with self._lock: | |||
| with self._lock.read_lock(): | |||
| return len(self._events) | |||
| def clear(self) -> None: | |||
| """Clear all collected events.""" | |||
| with self._lock: | |||
| with self._lock.write_lock(): | |||
| self._events.clear() | |||
| def _notify_layers(self, event: GraphEngineEvent) -> None: | |||
| @@ -5,12 +5,14 @@ Event emitter for yielding events to external consumers. | |||
| import threading | |||
| import time | |||
| from collections.abc import Generator | |||
| from typing import final | |||
| from core.workflow.graph_events import GraphEngineEvent | |||
| from .event_collector import EventCollector | |||
| @final | |||
| class EventEmitter: | |||
| """ | |||
| Emits collected events as a generator for external consumption. | |||
| @@ -3,7 +3,7 @@ Event handler implementations for different event types. | |||
| """ | |||
| import logging | |||
| from typing import TYPE_CHECKING, Optional | |||
| from typing import TYPE_CHECKING, final | |||
| from core.workflow.entities import GraphRuntimeState | |||
| from core.workflow.enums import NodeExecutionType | |||
| @@ -38,6 +38,7 @@ if TYPE_CHECKING: | |||
| logger = logging.getLogger(__name__) | |||
| @final | |||
| class EventHandlerRegistry: | |||
| """ | |||
| Registry of event handlers for different event types. | |||
| @@ -52,12 +53,12 @@ class EventHandlerRegistry: | |||
| graph_runtime_state: GraphRuntimeState, | |||
| graph_execution: GraphExecution, | |||
| response_coordinator: ResponseStreamCoordinator, | |||
| event_collector: Optional["EventCollector"] = None, | |||
| branch_handler: Optional["BranchHandler"] = None, | |||
| edge_processor: Optional["EdgeProcessor"] = None, | |||
| node_state_manager: Optional["NodeStateManager"] = None, | |||
| execution_tracker: Optional["ExecutionTracker"] = None, | |||
| error_handler: Optional["ErrorHandler"] = None, | |||
| event_collector: "EventCollector", | |||
| branch_handler: "BranchHandler", | |||
| edge_processor: "EdgeProcessor", | |||
| node_state_manager: "NodeStateManager", | |||
| execution_tracker: "ExecutionTracker", | |||
| error_handler: "ErrorHandler", | |||
| ) -> None: | |||
| """ | |||
| Initialize the event handler registry. | |||
| @@ -67,23 +68,23 @@ class EventHandlerRegistry: | |||
| graph_runtime_state: Runtime state with variable pool | |||
| graph_execution: Graph execution aggregate | |||
| response_coordinator: Response stream coordinator | |||
| event_collector: Optional event collector for collecting events | |||
| branch_handler: Optional branch handler for branch node processing | |||
| edge_processor: Optional edge processor for edge traversal | |||
| node_state_manager: Optional node state manager | |||
| execution_tracker: Optional execution tracker | |||
| error_handler: Optional error handler | |||
| event_collector: Event collector for collecting events | |||
| branch_handler: Branch handler for branch node processing | |||
| edge_processor: Edge processor for edge traversal | |||
| node_state_manager: Node state manager | |||
| execution_tracker: Execution tracker | |||
| error_handler: Error handler | |||
| """ | |||
| self.graph = graph | |||
| self.graph_runtime_state = graph_runtime_state | |||
| self.graph_execution = graph_execution | |||
| self.response_coordinator = response_coordinator | |||
| self.event_collector = event_collector | |||
| self.branch_handler = branch_handler | |||
| self.edge_processor = edge_processor | |||
| self.node_state_manager = node_state_manager | |||
| self.execution_tracker = execution_tracker | |||
| self.error_handler = error_handler | |||
| self._graph = graph | |||
| self._graph_runtime_state = graph_runtime_state | |||
| self._graph_execution = graph_execution | |||
| self._response_coordinator = response_coordinator | |||
| self._event_collector = event_collector | |||
| self._branch_handler = branch_handler | |||
| self._edge_processor = edge_processor | |||
| self._node_state_manager = node_state_manager | |||
| self._execution_tracker = execution_tracker | |||
| self._error_handler = error_handler | |||
| def handle_event(self, event: GraphNodeEventBase) -> None: | |||
| """ | |||
| @@ -93,9 +94,8 @@ class EventHandlerRegistry: | |||
| event: The event to handle | |||
| """ | |||
| # Events in loops or iterations are always collected | |||
| if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id): | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| if event.in_loop_id or event.in_iteration_id: | |||
| self._event_collector.collect(event) | |||
| return | |||
| # Handle specific event types | |||
| @@ -125,12 +125,10 @@ class EventHandlerRegistry: | |||
| ), | |||
| ): | |||
| # Iteration and loop events are collected directly | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| self._event_collector.collect(event) | |||
| else: | |||
| # Collect unhandled events | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| self._event_collector.collect(event) | |||
| logger.warning("Unhandled event type: %s", type(event).__name__) | |||
| def _handle_node_started(self, event: NodeRunStartedEvent) -> None: | |||
| @@ -141,15 +139,14 @@ class EventHandlerRegistry: | |||
| event: The node started event | |||
| """ | |||
| # Track execution in domain model | |||
| node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution.mark_started(event.id) | |||
| # Track in response coordinator for stream ordering | |||
| self.response_coordinator.track_node_execution(event.node_id, event.id) | |||
| self._response_coordinator.track_node_execution(event.node_id, event.id) | |||
| # Collect the event | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| self._event_collector.collect(event) | |||
| def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: | |||
| """ | |||
| @@ -159,12 +156,11 @@ class EventHandlerRegistry: | |||
| event: The stream chunk event | |||
| """ | |||
| # Process with response coordinator | |||
| streaming_events = list(self.response_coordinator.intercept_event(event)) | |||
| streaming_events = list(self._response_coordinator.intercept_event(event)) | |||
| # Collect all events | |||
| if self.event_collector: | |||
| for stream_event in streaming_events: | |||
| self.event_collector.collect(stream_event) | |||
| for stream_event in streaming_events: | |||
| self._event_collector.collect(stream_event) | |||
| def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: | |||
| """ | |||
| @@ -177,55 +173,44 @@ class EventHandlerRegistry: | |||
| event: The node succeeded event | |||
| """ | |||
| # Update domain model | |||
| node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution.mark_taken() | |||
| # Store outputs in variable pool | |||
| self._store_node_outputs(event) | |||
| # Forward to response coordinator and emit streaming events | |||
| streaming_events = list(self.response_coordinator.intercept_event(event)) | |||
| if self.event_collector: | |||
| for stream_event in streaming_events: | |||
| self.event_collector.collect(stream_event) | |||
| streaming_events = self._response_coordinator.intercept_event(event) | |||
| for stream_event in streaming_events: | |||
| self._event_collector.collect(stream_event) | |||
| # Process edges and get ready nodes | |||
| node = self.graph.nodes[event.node_id] | |||
| node = self._graph.nodes[event.node_id] | |||
| if node.execution_type == NodeExecutionType.BRANCH: | |||
| if self.branch_handler: | |||
| ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion( | |||
| event.node_id, event.node_run_result.edge_source_handle | |||
| ) | |||
| else: | |||
| ready_nodes, edge_streaming_events = [], [] | |||
| ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion( | |||
| event.node_id, event.node_run_result.edge_source_handle | |||
| ) | |||
| else: | |||
| if self.edge_processor: | |||
| ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id) | |||
| else: | |||
| ready_nodes, edge_streaming_events = [], [] | |||
| ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) | |||
| # Collect streaming events from edge processing | |||
| if self.event_collector: | |||
| for edge_event in edge_streaming_events: | |||
| self.event_collector.collect(edge_event) | |||
| for edge_event in edge_streaming_events: | |||
| self._event_collector.collect(edge_event) | |||
| # Enqueue ready nodes | |||
| if self.node_state_manager and self.execution_tracker: | |||
| for node_id in ready_nodes: | |||
| self.node_state_manager.enqueue_node(node_id) | |||
| self.execution_tracker.add(node_id) | |||
| for node_id in ready_nodes: | |||
| self._node_state_manager.enqueue_node(node_id) | |||
| self._execution_tracker.add(node_id) | |||
| # Update execution tracking | |||
| if self.execution_tracker: | |||
| self.execution_tracker.remove(event.node_id) | |||
| self._execution_tracker.remove(event.node_id) | |||
| # Handle response node outputs | |||
| if node.execution_type == NodeExecutionType.RESPONSE: | |||
| self._update_response_outputs(event) | |||
| # Collect the event | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| self._event_collector.collect(event) | |||
| def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: | |||
| """ | |||
| @@ -235,29 +220,19 @@ class EventHandlerRegistry: | |||
| event: The node failed event | |||
| """ | |||
| # Update domain model | |||
| node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution.mark_failed(event.error) | |||
| if self.error_handler: | |||
| result = self.error_handler.handle_node_failure(event) | |||
| result = self._error_handler.handle_node_failure(event) | |||
| if result: | |||
| # Process the resulting event (retry, exception, etc.) | |||
| self.handle_event(result) | |||
| else: | |||
| # Abort execution | |||
| self.graph_execution.fail(RuntimeError(event.error)) | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| if self.execution_tracker: | |||
| self.execution_tracker.remove(event.node_id) | |||
| if result: | |||
| # Process the resulting event (retry, exception, etc.) | |||
| self.handle_event(result) | |||
| else: | |||
| # Without error handler, just fail | |||
| self.graph_execution.fail(RuntimeError(event.error)) | |||
| if self.event_collector: | |||
| self.event_collector.collect(event) | |||
| if self.execution_tracker: | |||
| self.execution_tracker.remove(event.node_id) | |||
| # Abort execution | |||
| self._graph_execution.fail(RuntimeError(event.error)) | |||
| self._event_collector.collect(event) | |||
| self._execution_tracker.remove(event.node_id) | |||
| def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: | |||
| """ | |||
| @@ -267,7 +242,7 @@ class EventHandlerRegistry: | |||
| event: The node exception event | |||
| """ | |||
| # Node continues via fail-branch, so it's technically "succeeded" | |||
| node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution.mark_taken() | |||
| def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: | |||
| @@ -277,7 +252,7 @@ class EventHandlerRegistry: | |||
| Args: | |||
| event: The node retry event | |||
| """ | |||
| node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) | |||
| node_execution.increment_retry() | |||
| def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None: | |||
| @@ -288,16 +263,16 @@ class EventHandlerRegistry: | |||
| event: The node succeeded event containing outputs | |||
| """ | |||
| for variable_name, variable_value in event.node_run_result.outputs.items(): | |||
| self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) | |||
| self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) | |||
| def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None: | |||
| """Update response outputs for response nodes.""" | |||
| for key, value in event.node_run_result.outputs.items(): | |||
| if key == "answer": | |||
| existing = self.graph_runtime_state.outputs.get("answer", "") | |||
| existing = self._graph_runtime_state.outputs.get("answer", "") | |||
| if existing: | |||
| self.graph_runtime_state.outputs["answer"] = f"{existing}{value}" | |||
| self._graph_runtime_state.outputs["answer"] = f"{existing}{value}" | |||
| else: | |||
| self.graph_runtime_state.outputs["answer"] = value | |||
| self._graph_runtime_state.outputs["answer"] = value | |||
| else: | |||
| self.graph_runtime_state.outputs[key] = value | |||
| self._graph_runtime_state.outputs[key] = value | |||
| @@ -9,7 +9,7 @@ import contextvars | |||
| import logging | |||
| import queue | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Optional | |||
| from typing import final | |||
| from flask import Flask, current_app | |||
| @@ -20,6 +20,7 @@ from core.workflow.enums import NodeExecutionType | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import ( | |||
| GraphEngineEvent, | |||
| GraphNodeEventBase, | |||
| GraphRunAbortedEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunStartedEvent, | |||
| @@ -44,6 +45,7 @@ from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, Wo | |||
| logger = logging.getLogger(__name__) | |||
| @final | |||
| class GraphEngine: | |||
| """ | |||
| Queue-based graph execution engine. | |||
| @@ -62,7 +64,7 @@ class GraphEngine: | |||
| invoke_from: InvokeFrom, | |||
| call_depth: int, | |||
| graph: Graph, | |||
| graph_config: Mapping[str, Any], | |||
| graph_config: Mapping[str, object], | |||
| graph_runtime_state: GraphRuntimeState, | |||
| max_execution_steps: int, | |||
| max_execution_time: int, | |||
| @@ -103,7 +105,7 @@ class GraphEngine: | |||
| # Initialize queues | |||
| self.ready_queue: queue.Queue[str] = queue.Queue() | |||
| self.event_queue: queue.Queue = queue.Queue() | |||
| self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() | |||
| # Initialize subsystems | |||
| self._initialize_subsystems() | |||
| @@ -185,7 +187,7 @@ class GraphEngine: | |||
| event_handler=self.event_handler_registry, | |||
| event_collector=self.event_collector, | |||
| command_processor=self.command_processor, | |||
| worker_pool=self.worker_pool, | |||
| worker_pool=self._worker_pool, | |||
| ) | |||
| self.dispatcher = Dispatcher( | |||
| @@ -209,7 +211,7 @@ class GraphEngine: | |||
| def _setup_worker_management(self) -> None: | |||
| """Initialize worker management subsystem.""" | |||
| # Capture context for workers | |||
| flask_app: Optional[Flask] = None | |||
| flask_app: Flask | None = None | |||
| try: | |||
| flask_app = current_app._get_current_object() # type: ignore | |||
| except RuntimeError: | |||
| @@ -218,8 +220,8 @@ class GraphEngine: | |||
| context_vars = contextvars.copy_context() | |||
| # Create worker management components | |||
| self.activity_tracker = ActivityTracker() | |||
| self.dynamic_scaler = DynamicScaler( | |||
| self._activity_tracker = ActivityTracker() | |||
| self._dynamic_scaler = DynamicScaler( | |||
| min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS), | |||
| max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS), | |||
| scale_up_threshold=( | |||
| @@ -233,15 +235,15 @@ class GraphEngine: | |||
| else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME | |||
| ), | |||
| ) | |||
| self.worker_factory = WorkerFactory(flask_app, context_vars) | |||
| self._worker_factory = WorkerFactory(flask_app, context_vars) | |||
| self.worker_pool = WorkerPool( | |||
| self._worker_pool = WorkerPool( | |||
| ready_queue=self.ready_queue, | |||
| event_queue=self.event_queue, | |||
| graph=self.graph, | |||
| worker_factory=self.worker_factory, | |||
| dynamic_scaler=self.dynamic_scaler, | |||
| activity_tracker=self.activity_tracker, | |||
| worker_factory=self._worker_factory, | |||
| dynamic_scaler=self._dynamic_scaler, | |||
| activity_tracker=self._activity_tracker, | |||
| ) | |||
| def _validate_graph_state_consistency(self) -> None: | |||
| @@ -319,10 +321,10 @@ class GraphEngine: | |||
| def _start_execution(self) -> None: | |||
| """Start execution subsystems.""" | |||
| # Calculate initial worker count | |||
| initial_workers = self.dynamic_scaler.calculate_initial_workers(self.graph) | |||
| initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph) | |||
| # Start worker pool | |||
| self.worker_pool.start(initial_workers) | |||
| self._worker_pool.start(initial_workers) | |||
| # Register response nodes | |||
| for node in self.graph.nodes.values(): | |||
| @@ -340,7 +342,7 @@ class GraphEngine: | |||
| def _stop_execution(self) -> None: | |||
| """Stop execution subsystems.""" | |||
| self.dispatcher.stop() | |||
| self.worker_pool.stop() | |||
| self._worker_pool.stop() | |||
| # Don't mark complete here as the dispatcher already does it | |||
| # Notify layers | |||
| @@ -2,15 +2,18 @@ | |||
| Branch node handling for graph traversal. | |||
| """ | |||
| from typing import Optional | |||
| from collections.abc import Sequence | |||
| from typing import final | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events.node import NodeRunStreamChunkEvent | |||
| from ..state_management import EdgeStateManager | |||
| from .edge_processor import EdgeProcessor | |||
| from .skip_propagator import SkipPropagator | |||
| @final | |||
| class BranchHandler: | |||
| """ | |||
| Handles branch node logic during graph traversal. | |||
| @@ -40,7 +43,9 @@ class BranchHandler: | |||
| self.skip_propagator = skip_propagator | |||
| self.edge_state_manager = edge_state_manager | |||
| def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]: | |||
| def handle_branch_completion( | |||
| self, node_id: str, selected_handle: str | None | |||
| ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: | |||
| """ | |||
| Handle completion of a branch node. | |||
| @@ -58,10 +63,10 @@ class BranchHandler: | |||
| raise ValueError(f"Branch node {node_id} completed without selecting a branch") | |||
| # Categorize edges into selected and unselected | |||
| selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| _, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| # Skip all unselected paths | |||
| self.skip_propagator.skip_branch_paths(node_id, unselected_edges) | |||
| self.skip_propagator.skip_branch_paths(unselected_edges) | |||
| # Process selected edges and get ready nodes and streaming events | |||
| return self.edge_processor.process_node_success(node_id, selected_handle) | |||
| @@ -2,13 +2,18 @@ | |||
| Edge processing logic for graph traversal. | |||
| """ | |||
| from collections.abc import Sequence | |||
| from typing import final | |||
| from core.workflow.enums import NodeExecutionType | |||
| from core.workflow.graph import Edge, Graph | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||
| from ..response_coordinator import ResponseStreamCoordinator | |||
| from ..state_management import EdgeStateManager, NodeStateManager | |||
| @final | |||
| class EdgeProcessor: | |||
| """ | |||
| Processes edges during graph execution. | |||
| @@ -38,7 +43,9 @@ class EdgeProcessor: | |||
| self.node_state_manager = node_state_manager | |||
| self.response_coordinator = response_coordinator | |||
| def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]: | |||
| def process_node_success( | |||
| self, node_id: str, selected_handle: str | None = None | |||
| ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: | |||
| """ | |||
| Process edges after a node succeeds. | |||
| @@ -56,7 +63,7 @@ class EdgeProcessor: | |||
| else: | |||
| return self._process_non_branch_node_edges(node_id) | |||
| def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]: | |||
| def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: | |||
| """ | |||
| Process edges for non-branch nodes (mark all as TAKEN). | |||
| @@ -66,8 +73,8 @@ class EdgeProcessor: | |||
| Returns: | |||
| Tuple of (list of downstream nodes ready for execution, list of streaming events) | |||
| """ | |||
| ready_nodes = [] | |||
| all_streaming_events = [] | |||
| ready_nodes: list[str] = [] | |||
| all_streaming_events: list[NodeRunStreamChunkEvent] = [] | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| for edge in outgoing_edges: | |||
| @@ -77,7 +84,9 @@ class EdgeProcessor: | |||
| return ready_nodes, all_streaming_events | |||
| def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]: | |||
| def _process_branch_node_edges( | |||
| self, node_id: str, selected_handle: str | None | |||
| ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: | |||
| """ | |||
| Process edges for branch nodes. | |||
| @@ -94,8 +103,8 @@ class EdgeProcessor: | |||
| if not selected_handle: | |||
| raise ValueError(f"Branch node {node_id} did not select any edge") | |||
| ready_nodes = [] | |||
| all_streaming_events = [] | |||
| ready_nodes: list[str] = [] | |||
| all_streaming_events: list[NodeRunStreamChunkEvent] = [] | |||
| # Categorize edges | |||
| selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| @@ -112,7 +121,7 @@ class EdgeProcessor: | |||
| return ready_nodes, all_streaming_events | |||
| def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]: | |||
| def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: | |||
| """ | |||
| Mark edge as taken and check downstream node. | |||
| @@ -129,11 +138,11 @@ class EdgeProcessor: | |||
| streaming_events = self.response_coordinator.on_edge_taken(edge.id) | |||
| # Check if downstream node is ready | |||
| ready_nodes = [] | |||
| ready_nodes: list[str] = [] | |||
| if self.node_state_manager.is_node_ready(edge.head): | |||
| ready_nodes.append(edge.head) | |||
| return ready_nodes, list(streaming_events) | |||
| return ready_nodes, streaming_events | |||
| def _process_skipped_edge(self, edge: Edge) -> None: | |||
| """ | |||
| @@ -2,10 +2,13 @@ | |||
| Node readiness checking for execution. | |||
| """ | |||
| from typing import final | |||
| from core.workflow.enums import NodeState | |||
| from core.workflow.graph import Graph | |||
| @final | |||
| class NodeReadinessChecker: | |||
| """ | |||
| Checks if nodes are ready for execution based on their dependencies. | |||
| @@ -71,7 +74,7 @@ class NodeReadinessChecker: | |||
| Returns: | |||
| List of node IDs that are now ready | |||
| """ | |||
| ready_nodes = [] | |||
| ready_nodes: list[str] = [] | |||
| outgoing_edges = self.graph.get_outgoing_edges(from_node_id) | |||
| for edge in outgoing_edges: | |||
| @@ -2,11 +2,15 @@ | |||
| Skip state propagation through the graph. | |||
| """ | |||
| from core.workflow.graph import Graph | |||
| from collections.abc import Sequence | |||
| from typing import final | |||
| from core.workflow.graph import Edge, Graph | |||
| from ..state_management import EdgeStateManager, NodeStateManager | |||
| @final | |||
| class SkipPropagator: | |||
| """ | |||
| Propagates skip states through the graph. | |||
| @@ -57,9 +61,8 @@ class SkipPropagator: | |||
| # If any edge is taken, node may still execute | |||
| if edge_states["has_taken"]: | |||
| # Check if node is ready and enqueue if so | |||
| if self.node_state_manager.is_node_ready(downstream_node_id): | |||
| self.node_state_manager.enqueue_node(downstream_node_id) | |||
| # Enqueue node | |||
| self.node_state_manager.enqueue_node(downstream_node_id) | |||
| return | |||
| # All edges are skipped, propagate skip to this node | |||
| @@ -83,12 +86,11 @@ class SkipPropagator: | |||
| # Recursively propagate skip | |||
| self.propagate_skip_from_edge(edge.id) | |||
| def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None: | |||
| def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: | |||
| """ | |||
| Skip all paths from unselected branch edges. | |||
| Args: | |||
| node_id: The ID of the branch node | |||
| unselected_edges: List of edges not taken by the branch | |||
| """ | |||
| for edge in unselected_edges: | |||
| @@ -6,7 +6,6 @@ intercept and respond to GraphEngine events. | |||
| """ | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional | |||
| from core.workflow.entities import GraphRuntimeState | |||
| from core.workflow.graph_engine.protocols.command_channel import CommandChannel | |||
| @@ -28,8 +27,8 @@ class Layer(ABC): | |||
| def __init__(self) -> None: | |||
| """Initialize the layer. Subclasses can override with custom parameters.""" | |||
| self.graph_runtime_state: Optional[GraphRuntimeState] = None | |||
| self.command_channel: Optional[CommandChannel] = None | |||
| self.graph_runtime_state: GraphRuntimeState | None = None | |||
| self.command_channel: CommandChannel | None = None | |||
| def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None: | |||
| """ | |||
| @@ -73,7 +72,7 @@ class Layer(ABC): | |||
| pass | |||
| @abstractmethod | |||
| def on_graph_end(self, error: Optional[Exception]) -> None: | |||
| def on_graph_end(self, error: Exception | None) -> None: | |||
| """ | |||
| Called when graph execution ends. | |||
| @@ -7,7 +7,7 @@ graph execution for debugging purposes. | |||
| import logging | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from typing import Any, final | |||
| from core.workflow.graph_events import ( | |||
| GraphEngineEvent, | |||
| @@ -34,6 +34,7 @@ from core.workflow.graph_events import ( | |||
| from .base import Layer | |||
| @final | |||
| class DebugLoggingLayer(Layer): | |||
| """ | |||
| A layer that provides comprehensive logging of GraphEngine execution. | |||
| @@ -221,7 +222,7 @@ class DebugLoggingLayer(Layer): | |||
| # Log unknown events at debug level | |||
| self.logger.debug("Event: %s", event_class) | |||
| def on_graph_end(self, error: Optional[Exception]) -> None: | |||
| def on_graph_end(self, error: Exception | None) -> None: | |||
| """Log graph execution end with summary statistics.""" | |||
| self.logger.info("=" * 80) | |||
| @@ -11,7 +11,7 @@ When limits are exceeded, the layer automatically aborts execution. | |||
| import logging | |||
| import time | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from typing import final | |||
| from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType | |||
| from core.workflow.graph_engine.layers import Layer | |||
| @@ -29,6 +29,7 @@ class LimitType(Enum): | |||
| TIME_LIMIT = "time_limit" | |||
| @final | |||
| class ExecutionLimitsLayer(Layer): | |||
| """ | |||
| Layer that enforces execution limits for workflows. | |||
| @@ -53,7 +54,7 @@ class ExecutionLimitsLayer(Layer): | |||
| self.max_time = max_time | |||
| # Runtime tracking | |||
| self.start_time: Optional[float] = None | |||
| self.start_time: float | None = None | |||
| self.step_count = 0 | |||
| self.logger = logging.getLogger(__name__) | |||
| @@ -94,7 +95,7 @@ class ExecutionLimitsLayer(Layer): | |||
| if self._reached_time_limitation(): | |||
| self._send_abort_command(LimitType.TIME_LIMIT) | |||
| def on_graph_end(self, error: Optional[Exception]) -> None: | |||
| def on_graph_end(self, error: Exception | None) -> None: | |||
| """Called when graph execution ends.""" | |||
| if self._execution_started and not self._execution_ended: | |||
| self._execution_ended = True | |||
| @@ -6,13 +6,14 @@ using the new Redis command channel, without requiring user permission checks. | |||
| Supports stop, pause, and resume operations. | |||
| """ | |||
| from typing import Optional | |||
| from typing import final | |||
| from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel | |||
| from core.workflow.graph_engine.entities.commands import AbortCommand | |||
| from extensions.ext_redis import redis_client | |||
| @final | |||
| class GraphEngineManager: | |||
| """ | |||
| Manager for sending control commands to GraphEngine instances. | |||
| @@ -23,7 +24,7 @@ class GraphEngineManager: | |||
| """ | |||
| @staticmethod | |||
| def send_stop_command(task_id: str, reason: Optional[str] = None) -> None: | |||
| def send_stop_command(task_id: str, reason: str | None = None) -> None: | |||
| """ | |||
| Send a stop command to a running workflow. | |||
| @@ -6,7 +6,9 @@ import logging | |||
| import queue | |||
| import threading | |||
| import time | |||
| from typing import TYPE_CHECKING, Optional | |||
| from typing import TYPE_CHECKING, final | |||
| from core.workflow.graph_events.base import GraphNodeEventBase | |||
| from ..event_management import EventCollector, EventEmitter | |||
| from .execution_coordinator import ExecutionCoordinator | |||
| @@ -17,6 +19,7 @@ if TYPE_CHECKING: | |||
| logger = logging.getLogger(__name__) | |||
| @final | |||
| class Dispatcher: | |||
| """ | |||
| Main dispatcher that processes events from the event queue. | |||
| @@ -27,12 +30,12 @@ class Dispatcher: | |||
| def __init__( | |||
| self, | |||
| event_queue: queue.Queue, | |||
| event_queue: queue.Queue[GraphNodeEventBase], | |||
| event_handler: "EventHandlerRegistry", | |||
| event_collector: EventCollector, | |||
| execution_coordinator: ExecutionCoordinator, | |||
| max_execution_time: int, | |||
| event_emitter: Optional[EventEmitter] = None, | |||
| event_emitter: EventEmitter | None = None, | |||
| ) -> None: | |||
| """ | |||
| Initialize the dispatcher. | |||
| @@ -52,9 +55,9 @@ class Dispatcher: | |||
| self.max_execution_time = max_execution_time | |||
| self.event_emitter = event_emitter | |||
| self._thread: Optional[threading.Thread] = None | |||
| self._thread: threading.Thread | None = None | |||
| self._stop_event = threading.Event() | |||
| self._start_time: Optional[float] = None | |||
| self._start_time: float | None = None | |||
| def start(self) -> None: | |||
| """Start the dispatcher thread.""" | |||
| @@ -2,7 +2,7 @@ | |||
| Execution coordinator for managing overall workflow execution. | |||
| """ | |||
| from typing import TYPE_CHECKING | |||
| from typing import TYPE_CHECKING, final | |||
| from ..command_processing import CommandProcessor | |||
| from ..domain import GraphExecution | |||
| @@ -14,6 +14,7 @@ if TYPE_CHECKING: | |||
| from ..event_management import EventHandlerRegistry | |||
| @final | |||
| class ExecutionCoordinator: | |||
| """ | |||
| Coordinates overall execution flow between subsystems. | |||
| @@ -7,7 +7,7 @@ thread-safe storage for node outputs. | |||
| from collections.abc import Sequence | |||
| from threading import RLock | |||
| from typing import TYPE_CHECKING, Optional, Union | |||
| from typing import TYPE_CHECKING, Union, final | |||
| from core.variables import Segment | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -18,6 +18,7 @@ if TYPE_CHECKING: | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||
| @final | |||
| class OutputRegistry: | |||
| """ | |||
| Thread-safe registry for storing and retrieving node outputs. | |||
| @@ -47,7 +48,7 @@ class OutputRegistry: | |||
| with self._lock: | |||
| self._scalars.add(selector, value) | |||
| def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]: | |||
| def get_scalar(self, selector: Sequence[str]) -> "Segment | None": | |||
| """ | |||
| Get a scalar value for the given selector. | |||
| @@ -81,7 +82,7 @@ class OutputRegistry: | |||
| except ValueError: | |||
| raise ValueError(f"Stream {'.'.join(selector)} is already closed") | |||
| def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]: | |||
| def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None": | |||
| """ | |||
| Pop the next unread NodeRunStreamChunkEvent from the stream. | |||
| @@ -5,12 +5,13 @@ This module contains the private Stream class used internally by OutputRegistry | |||
| to manage streaming data chunks. | |||
| """ | |||
| from typing import TYPE_CHECKING, Optional | |||
| from typing import TYPE_CHECKING, final | |||
| if TYPE_CHECKING: | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||
| @final | |||
| class Stream: | |||
| """ | |||
| A stream that holds NodeRunStreamChunkEvent objects and tracks read position. | |||
| @@ -41,7 +42,7 @@ class Stream: | |||
| raise ValueError("Cannot append to a closed stream") | |||
| self.events.append(event) | |||
| def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]: | |||
| def pop_next(self) -> "NodeRunStreamChunkEvent | None": | |||
| """ | |||
| Pop the next unread NodeRunStreamChunkEvent from the stream. | |||
| @@ -2,7 +2,7 @@ | |||
| Base error strategy protocol. | |||
| """ | |||
| from typing import Optional, Protocol | |||
| from typing import Protocol | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent | |||
| @@ -16,7 +16,7 @@ class ErrorStrategy(Protocol): | |||
| node execution failures. | |||
| """ | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: | |||
| def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: | |||
| """ | |||
| Handle a node failure event. | |||
| @@ -9,7 +9,7 @@ import logging | |||
| from collections import deque | |||
| from collections.abc import Sequence | |||
| from threading import RLock | |||
| from typing import Optional, TypeAlias | |||
| from typing import TypeAlias, final | |||
| from uuid import uuid4 | |||
| from core.workflow.enums import NodeExecutionType, NodeState | |||
| @@ -28,6 +28,7 @@ NodeID: TypeAlias = str | |||
| EdgeID: TypeAlias = str | |||
| @final | |||
| class ResponseStreamCoordinator: | |||
| """ | |||
| Manages response streaming sessions without relying on global state. | |||
| @@ -45,7 +46,7 @@ class ResponseStreamCoordinator: | |||
| """ | |||
| self.registry = registry | |||
| self.graph = graph | |||
| self.active_session: Optional[ResponseSession] = None | |||
| self.active_session: ResponseSession | None = None | |||
| self.waiting_sessions: deque[ResponseSession] = deque() | |||
| self.lock = RLock() | |||
| @@ -3,7 +3,8 @@ Manager for edge states during graph execution. | |||
| """ | |||
| import threading | |||
| from typing import TypedDict | |||
| from collections.abc import Sequence | |||
| from typing import TypedDict, final | |||
| from core.workflow.enums import NodeState | |||
| from core.workflow.graph import Edge, Graph | |||
| @@ -17,6 +18,7 @@ class EdgeStateAnalysis(TypedDict): | |||
| all_skipped: bool | |||
| @final | |||
| class EdgeStateManager: | |||
| """ | |||
| Manages edge states and transitions during graph execution. | |||
| @@ -87,7 +89,7 @@ class EdgeStateManager: | |||
| with self._lock: | |||
| return self.graph.edges[edge_id].state | |||
| def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]: | |||
| def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: | |||
| """ | |||
| Categorize branch edges into selected and unselected. | |||
| @@ -100,8 +102,8 @@ class EdgeStateManager: | |||
| """ | |||
| with self._lock: | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| selected_edges = [] | |||
| unselected_edges = [] | |||
| selected_edges: list[Edge] = [] | |||
| unselected_edges: list[Edge] = [] | |||
| for edge in outgoing_edges: | |||
| if edge.source_handle == selected_handle: | |||
| @@ -3,8 +3,10 @@ Tracker for currently executing nodes. | |||
| """ | |||
| import threading | |||
| from typing import final | |||
| @final | |||
| class ExecutionTracker: | |||
| """ | |||
| Tracks nodes that are currently being executed. | |||
| @@ -4,11 +4,13 @@ Manager for node states during graph execution. | |||
| import queue | |||
| import threading | |||
| from typing import final | |||
| from core.workflow.enums import NodeState | |||
| from core.workflow.graph import Graph | |||
| @final | |||
| class NodeStateManager: | |||
| """ | |||
| Manages node states and the ready queue for execution. | |||
| @@ -11,7 +11,7 @@ import threading | |||
| import time | |||
| from collections.abc import Callable | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from typing import final | |||
| from uuid import uuid4 | |||
| from flask import Flask | |||
| @@ -23,6 +23,7 @@ from core.workflow.nodes.base.node import Node | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| @final | |||
| class Worker(threading.Thread): | |||
| """ | |||
| Worker thread that executes nodes from the ready queue. | |||
| @@ -38,10 +39,10 @@ class Worker(threading.Thread): | |||
| event_queue: queue.Queue[GraphNodeEventBase], | |||
| graph: Graph, | |||
| worker_id: int = 0, | |||
| flask_app: Optional[Flask] = None, | |||
| context_vars: Optional[contextvars.Context] = None, | |||
| on_idle_callback: Optional[Callable[[int], None]] = None, | |||
| on_active_callback: Optional[Callable[[int], None]] = None, | |||
| flask_app: Flask | None = None, | |||
| context_vars: contextvars.Context | None = None, | |||
| on_idle_callback: Callable[[int], None] | None = None, | |||
| on_active_callback: Callable[[int], None] | None = None, | |||
| ) -> None: | |||
| """ | |||
| Initialize worker thread. | |||
| @@ -4,8 +4,10 @@ Activity tracker for monitoring worker activity. | |||
| import threading | |||
| import time | |||
| from typing import final | |||
| @final | |||
| class ActivityTracker: | |||
| """ | |||
| Tracks worker activity for scaling decisions. | |||
| @@ -2,9 +2,12 @@ | |||
| Dynamic scaler for worker pool sizing. | |||
| """ | |||
| from typing import final | |||
| from core.workflow.graph import Graph | |||
| @final | |||
| class DynamicScaler: | |||
| """ | |||
| Manages dynamic scaling decisions for the worker pool. | |||
| @@ -5,7 +5,7 @@ Factory for creating worker instances. | |||
| import contextvars | |||
| import queue | |||
| from collections.abc import Callable | |||
| from typing import Optional | |||
| from typing import final | |||
| from flask import Flask | |||
| @@ -14,6 +14,7 @@ from core.workflow.graph import Graph | |||
| from ..worker import Worker | |||
| @final | |||
| class WorkerFactory: | |||
| """ | |||
| Factory for creating worker instances with proper context. | |||
| @@ -24,7 +25,7 @@ class WorkerFactory: | |||
| def __init__( | |||
| self, | |||
| flask_app: Optional[Flask], | |||
| flask_app: Flask | None, | |||
| context_vars: contextvars.Context, | |||
| ) -> None: | |||
| """ | |||
| @@ -43,8 +44,8 @@ class WorkerFactory: | |||
| ready_queue: queue.Queue[str], | |||
| event_queue: queue.Queue, | |||
| graph: Graph, | |||
| on_idle_callback: Optional[Callable[[int], None]] = None, | |||
| on_active_callback: Optional[Callable[[int], None]] = None, | |||
| on_idle_callback: Callable[[int], None] | None = None, | |||
| on_active_callback: Callable[[int], None] | None = None, | |||
| ) -> Worker: | |||
| """ | |||
| Create a new worker instance. | |||
| @@ -4,6 +4,7 @@ Worker pool management. | |||
| import queue | |||
| import threading | |||
| from typing import final | |||
| from core.workflow.graph import Graph | |||
| @@ -13,6 +14,7 @@ from .dynamic_scaler import DynamicScaler | |||
| from .worker_factory import WorkerFactory | |||
| @final | |||
| class WorkerPool: | |||
| """ | |||
| Manages a pool of worker threads for executing nodes. | |||
| @@ -2,7 +2,7 @@ from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID | |||
| from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus | |||
| from core.workflow.node_events import NodeRunResult | |||
| from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig | |||
| from core.workflow.nodes.base.node import Node | |||
| @@ -11,6 +11,7 @@ from core.workflow.nodes.start.entities import StartNodeData | |||
| class StartNode(Node): | |||
| node_type = NodeType.START | |||
| execution_type = NodeExecutionType.ROOT | |||
| _node_data: StartNodeData | |||
| @@ -65,7 +65,7 @@ class Storage: | |||
| from extensions.storage.volcengine_tos_storage import VolcengineTosStorage | |||
| return VolcengineTosStorage | |||
| case StorageType.SUPBASE: | |||
| case StorageType.SUPABASE: | |||
| from extensions.storage.supabase_storage import SupabaseStorage | |||
| return SupabaseStorage | |||
| @@ -14,4 +14,4 @@ class StorageType(StrEnum): | |||
| S3 = "s3" | |||
| TENCENT_COS = "tencent-cos" | |||
| VOLCENGINE_TOS = "volcengine-tos" | |||
| SUPBASE = "supabase" | |||
| SUPABASE = "supabase" | |||
| @@ -137,10 +137,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen | |||
| return cast(Variable, result) | |||
| def infer_segment_type_from_value(value: Any, /) -> SegmentType: | |||
| return build_segment(value).value_type | |||
| def build_segment(value: Any, /) -> Segment: | |||
| # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` | |||
| # below | |||
| @@ -301,8 +301,8 @@ class TokenManager: | |||
| if expiry_minutes is None: | |||
| raise ValueError(f"Expiry minutes for {token_type} token is not set") | |||
| token_key = cls._get_token_key(token, token_type) | |||
| expiry_time = int(expiry_minutes * 60) | |||
| redis_client.setex(token_key, expiry_time, json.dumps(token_data)) | |||
| expiry_seconds = int(expiry_minutes * 60) | |||
| redis_client.setex(token_key, expiry_seconds, json.dumps(token_data)) | |||
| if account_id: | |||
| cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes) | |||
| @@ -336,11 +336,11 @@ class TokenManager: | |||
| @classmethod | |||
| def _set_current_token_for_account( | |||
| cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float] | |||
| cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float] | |||
| ): | |||
| key = cls._get_account_token_key(account_id, token_type) | |||
| expiry_time = int(expiry_hours * 60 * 60) | |||
| redis_client.setex(key, expiry_time, token) | |||
| expiry_seconds = int(expiry_minutes * 60) | |||
| redis_client.setex(key, expiry_seconds, token) | |||
| @classmethod | |||
| def _get_account_token_key(cls, account_id: str, token_type: str) -> str: | |||
| @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast | |||
| import sqlalchemy as sa | |||
| from flask import request | |||
| from flask_login import UserMixin | |||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text | |||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column | |||
| from configs import dify_config | |||
| @@ -1556,7 +1556,7 @@ class ApiToken(Base): | |||
| def generate_api_key(prefix, n): | |||
| while True: | |||
| result = prefix + generate_string(n) | |||
| if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0: | |||
| if db.session.scalar(select(exists().where(ApiToken.token == result))): | |||
| continue | |||
| return result | |||
| @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from uuid import uuid4 | |||
| import sqlalchemy as sa | |||
| from sqlalchemy import DateTime, orm | |||
| from sqlalchemy import DateTime, exists, orm, select | |||
| from core.file.constants import maybe_file_object | |||
| from core.file.models import File | |||
| @@ -348,12 +348,13 @@ class Workflow(Base): | |||
| """ | |||
| from models.tools import WorkflowToolProvider | |||
| return ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) | |||
| .count() | |||
| > 0 | |||
| stmt = select( | |||
| exists().where( | |||
| WorkflowToolProvider.tenant_id == self.tenant_id, | |||
| WorkflowToolProvider.app_id == self.app_id, | |||
| ) | |||
| ) | |||
| return db.session.execute(stmt).scalar_one() | |||
| @property | |||
| def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | |||
| @@ -952,7 +953,7 @@ def _naive_utc_datetime(): | |||
| class WorkflowDraftVariable(Base): | |||
| """`WorkflowDraftVariable` record variables and outputs generated during | |||
| debugging worfklow or chatflow. | |||
| debugging workflow or chatflow. | |||
| IMPORTANT: This model maintains multiple invariant rules that must be preserved. | |||
| Do not instantiate this class directly with the constructor. | |||
| @@ -9,7 +9,7 @@ from collections import Counter | |||
| from typing import Any, Literal, Optional | |||
| from flask_login import current_user | |||
| from sqlalchemy import func, select | |||
| from sqlalchemy import exists, func, select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -845,10 +845,8 @@ class DatasetService: | |||
| @staticmethod | |||
| def dataset_use_check(dataset_id) -> bool: | |||
| count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() | |||
| if count > 0: | |||
| return True | |||
| return False | |||
| stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id)) | |||
| return db.session.execute(stmt).scalar_one() | |||
| @staticmethod | |||
| def check_dataset_permission(dataset, user): | |||
| @@ -4,6 +4,7 @@ from collections.abc import Mapping | |||
| from pathlib import Path | |||
| from typing import Any, Optional | |||
| from sqlalchemy import exists, select | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| @@ -190,11 +191,14 @@ class BuiltinToolManageService: | |||
| # update name if provided | |||
| if name and name != db_provider.name: | |||
| # check if the name is already used | |||
| if ( | |||
| session.query(BuiltinToolProvider) | |||
| .filter_by(tenant_id=tenant_id, provider=provider, name=name) | |||
| .count() | |||
| > 0 | |||
| if session.scalar( | |||
| select( | |||
| exists().where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| BuiltinToolProvider.name == name, | |||
| ) | |||
| ) | |||
| ): | |||
| raise ValueError(f"the credential name '{name}' is already used") | |||
| @@ -246,11 +250,14 @@ class BuiltinToolManageService: | |||
| ) | |||
| else: | |||
| # check if the name is already used | |||
| if ( | |||
| session.query(BuiltinToolProvider) | |||
| .filter_by(tenant_id=tenant_id, provider=provider, name=name) | |||
| .count() | |||
| > 0 | |||
| if session.scalar( | |||
| select( | |||
| exists().where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| BuiltinToolProvider.name == name, | |||
| ) | |||
| ) | |||
| ): | |||
| raise ValueError(f"the credential name '{name}' is already used") | |||
| @@ -4,7 +4,7 @@ import uuid | |||
| from collections.abc import Callable, Generator, Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy import exists, select | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from core.app.app_config.entities import VariableEntityType | |||
| @@ -83,15 +83,14 @@ class WorkflowService: | |||
| ) | |||
| def is_workflow_exist(self, app_model: App) -> bool: | |||
| return ( | |||
| db.session.query(Workflow) | |||
| .where( | |||
| stmt = select( | |||
| exists().where( | |||
| Workflow.tenant_id == app_model.tenant_id, | |||
| Workflow.app_id == app_model.id, | |||
| Workflow.version == Workflow.VERSION_DRAFT, | |||
| ) | |||
| .count() | |||
| ) > 0 | |||
| ) | |||
| return db.session.execute(stmt).scalar_one() | |||
| def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: | |||
| """ | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import exists, select | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from extensions.ext_database import db | |||
| @@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): | |||
| start_at = time.perf_counter() | |||
| # get app info | |||
| app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() | |||
| annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() | |||
| annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) | |||
| if not app: | |||
| logger.info(click.style(f"App not found: {app_id}", fg="red")) | |||
| db.session.close() | |||
| @@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): | |||
| ) | |||
| try: | |||
| if annotations_count > 0: | |||
| if annotations_exists: | |||
| vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) | |||
| vector.delete() | |||
| except Exception: | |||
| @@ -39,7 +39,7 @@ def test_page_result(text, cursor, maxlen, expected): | |||
| # Tests: get_url | |||
| # --------------------------- | |||
| @pytest.fixture | |||
| def stub_support_types(monkeypatch): | |||
| def stub_support_types(monkeypatch: pytest.MonkeyPatch): | |||
| """Stub supported content types list.""" | |||
| import core.tools.utils.web_reader_tool as mod | |||
| @@ -48,7 +48,7 @@ def stub_support_types(monkeypatch): | |||
| return mod | |||
| def test_get_url_unsupported_content_type(monkeypatch, stub_support_types): | |||
| def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| # HEAD 200 but content-type not supported and not text/html | |||
| def fake_head(url, headers=None, follow_redirects=True, timeout=None): | |||
| return FakeResponse( | |||
| @@ -62,7 +62,7 @@ def test_get_url_unsupported_content_type(monkeypatch, stub_support_types): | |||
| assert result == "Unsupported content-type [image/png] of URL." | |||
| def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types): | |||
| def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """ | |||
| When content-type is in SUPPORT_URL_CONTENT_TYPES, | |||
| should call ExtractProcessor.load_from_url and return its text. | |||
| @@ -88,7 +88,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_ | |||
| assert result == "PDF extracted text" | |||
| def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types): | |||
| def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """200 + text/html → GET, chardet detects encoding, readability returns article which is templated.""" | |||
| def fake_head(url, headers=None, follow_redirects=True, timeout=None): | |||
| @@ -121,7 +121,7 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_suppor | |||
| assert "Hello world" in out | |||
| def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types): | |||
| def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """If readability returns no text, should return empty string.""" | |||
| def fake_head(url, headers=None, follow_redirects=True, timeout=None): | |||
| @@ -142,7 +142,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_su | |||
| assert out == "" | |||
| def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types): | |||
| def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed.""" | |||
| def fake_head(url, headers=None, follow_redirects=True, timeout=None): | |||
| @@ -175,7 +175,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types): | |||
| assert "X" in out | |||
| def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types): | |||
| def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """HEAD returns non-200 and non-403 → should directly return code message.""" | |||
| def fake_head(url, headers=None, follow_redirects=True, timeout=None): | |||
| @@ -189,7 +189,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types): | |||
| assert out == "URL returned status code 500." | |||
| def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types): | |||
| def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """ | |||
| If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type, | |||
| it should route to ExtractProcessor.load_from_url. | |||
| @@ -213,7 +213,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch, stub_suppor | |||
| assert out == "From ExtractProcessor via filename" | |||
| def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types): | |||
| def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types): | |||
| """ | |||
| If chardet returns an encoding but content.decode raises, should fallback to response.text. | |||
| """ | |||
| @@ -250,7 +250,7 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_supp | |||
| # --------------------------- | |||
| def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch): | |||
| def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch): | |||
| # stub readabilipy.simple_json_from_html_string | |||
| def fake_simple_json_from_html_string(html, use_readability=True): | |||
| return { | |||
| @@ -271,7 +271,7 @@ def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch): | |||
| assert article.text[0]["text"] == "world" | |||
| def test_extract_using_readabilipy_defaults_when_missing(monkeypatch): | |||
| def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch): | |||
| def fake_simple_json_from_html_string(html, use_readability=True): | |||
| return {} # all missing | |||
| @@ -8,7 +8,7 @@ from core.tools.errors import ToolInvokeError | |||
| from core.tools.workflow_as_tool.tool import WorkflowTool | |||
| def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch): | |||
| def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): | |||
| """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when | |||
| `WorkflowAppGenerator.generate` returns a result with `error` key inside | |||
| the `data` element. | |||
| @@ -40,7 +40,7 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel | |||
| "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", | |||
| lambda *args, **kwargs: {"data": {"error": "oops"}}, | |||
| ) | |||
| monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None) | |||
| monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) | |||
| with pytest.raises(ToolInvokeError) as exc_info: | |||
| # WorkflowTool always returns a generator, so we need to iterate to | |||
| @@ -0,0 +1,281 @@ | |||
| """Unit tests for Graph class methods.""" | |||
| from unittest.mock import Mock | |||
| from core.workflow.enums import NodeExecutionType, NodeState, NodeType | |||
| from core.workflow.graph.edge import Edge | |||
| from core.workflow.graph.graph import Graph | |||
| from core.workflow.nodes.base.node import Node | |||
| def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: | |||
| """Create a mock node for testing.""" | |||
| node = Mock(spec=Node) | |||
| node.id = node_id | |||
| node.execution_type = execution_type | |||
| node.state = state | |||
| node.node_type = NodeType.START | |||
| return node | |||
| class TestMarkInactiveRootBranches: | |||
| """Test cases for _mark_inactive_root_branches method.""" | |||
| def test_single_root_no_marking(self): | |||
| """Test that single root graph doesn't mark anything as skipped.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), | |||
| } | |||
| in_edges = {"child1": ["edge1"]} | |||
| out_edges = {"root1": ["edge1"]} | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") | |||
| assert nodes["root1"].state == NodeState.UNKNOWN | |||
| assert nodes["child1"].state == NodeState.UNKNOWN | |||
| assert edges["edge1"].state == NodeState.UNKNOWN | |||
| def test_multiple_roots_mark_inactive(self): | |||
| """Test marking inactive root branches with multiple root nodes.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "root2": create_mock_node("root2", NodeExecutionType.ROOT), | |||
| "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), | |||
| "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), | |||
| "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), | |||
| } | |||
| in_edges = {"child1": ["edge1"], "child2": ["edge2"]} | |||
| out_edges = {"root1": ["edge1"], "root2": ["edge2"]} | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") | |||
| assert nodes["root1"].state == NodeState.UNKNOWN | |||
| assert nodes["root2"].state == NodeState.SKIPPED | |||
| assert nodes["child1"].state == NodeState.UNKNOWN | |||
| assert nodes["child2"].state == NodeState.SKIPPED | |||
| assert edges["edge1"].state == NodeState.UNKNOWN | |||
| assert edges["edge2"].state == NodeState.SKIPPED | |||
| def test_shared_downstream_node(self): | |||
| """Test that shared downstream nodes are not skipped if at least one path is active.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "root2": create_mock_node("root2", NodeExecutionType.ROOT), | |||
| "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), | |||
| "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), | |||
| "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), | |||
| "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), | |||
| "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"), | |||
| "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), | |||
| } | |||
| in_edges = { | |||
| "child1": ["edge1"], | |||
| "child2": ["edge2"], | |||
| "shared": ["edge3", "edge4"], | |||
| } | |||
| out_edges = { | |||
| "root1": ["edge1"], | |||
| "root2": ["edge2"], | |||
| "child1": ["edge3"], | |||
| "child2": ["edge4"], | |||
| } | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") | |||
| assert nodes["root1"].state == NodeState.UNKNOWN | |||
| assert nodes["root2"].state == NodeState.SKIPPED | |||
| assert nodes["child1"].state == NodeState.UNKNOWN | |||
| assert nodes["child2"].state == NodeState.SKIPPED | |||
| assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active | |||
| assert edges["edge1"].state == NodeState.UNKNOWN | |||
| assert edges["edge2"].state == NodeState.SKIPPED | |||
| assert edges["edge3"].state == NodeState.UNKNOWN | |||
| assert edges["edge4"].state == NodeState.SKIPPED | |||
| def test_deep_branch_marking(self): | |||
| """Test marking deep branches with multiple levels.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "root2": create_mock_node("root2", NodeExecutionType.ROOT), | |||
| "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), | |||
| "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), | |||
| "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), | |||
| "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), | |||
| "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), | |||
| "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), | |||
| "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), | |||
| "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), | |||
| "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), | |||
| } | |||
| in_edges = { | |||
| "level1_a": ["edge1"], | |||
| "level1_b": ["edge2"], | |||
| "level2_a": ["edge3"], | |||
| "level2_b": ["edge4"], | |||
| "level3": ["edge5"], | |||
| } | |||
| out_edges = { | |||
| "root1": ["edge1"], | |||
| "root2": ["edge2"], | |||
| "level1_a": ["edge3"], | |||
| "level1_b": ["edge4"], | |||
| "level2_b": ["edge5"], | |||
| } | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") | |||
| assert nodes["root1"].state == NodeState.UNKNOWN | |||
| assert nodes["root2"].state == NodeState.SKIPPED | |||
| assert nodes["level1_a"].state == NodeState.UNKNOWN | |||
| assert nodes["level1_b"].state == NodeState.SKIPPED | |||
| assert nodes["level2_a"].state == NodeState.UNKNOWN | |||
| assert nodes["level2_b"].state == NodeState.SKIPPED | |||
| assert nodes["level3"].state == NodeState.SKIPPED | |||
| assert edges["edge1"].state == NodeState.UNKNOWN | |||
| assert edges["edge2"].state == NodeState.SKIPPED | |||
| assert edges["edge3"].state == NodeState.UNKNOWN | |||
| assert edges["edge4"].state == NodeState.SKIPPED | |||
| assert edges["edge5"].state == NodeState.SKIPPED | |||
| def test_non_root_execution_type(self): | |||
| """Test that nodes with non-ROOT execution type are not treated as root nodes.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), | |||
| "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), | |||
| "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), | |||
| "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), | |||
| } | |||
| in_edges = {"child1": ["edge1"], "child2": ["edge2"]} | |||
| out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") | |||
| assert nodes["root1"].state == NodeState.UNKNOWN | |||
| assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped | |||
| assert nodes["child1"].state == NodeState.UNKNOWN | |||
| assert nodes["child2"].state == NodeState.UNKNOWN | |||
| assert edges["edge1"].state == NodeState.UNKNOWN | |||
| assert edges["edge2"].state == NodeState.UNKNOWN | |||
| def test_empty_graph(self): | |||
| """Test handling of empty graph structures.""" | |||
| nodes = {} | |||
| edges = {} | |||
| in_edges = {} | |||
| out_edges = {} | |||
| # Should not raise any errors | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") | |||
| def test_three_roots_mark_two_inactive(self): | |||
| """Test with three root nodes where two should be marked inactive.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "root2": create_mock_node("root2", NodeExecutionType.ROOT), | |||
| "root3": create_mock_node("root3", NodeExecutionType.ROOT), | |||
| "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), | |||
| "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), | |||
| "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), | |||
| "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), | |||
| "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), | |||
| } | |||
| in_edges = { | |||
| "child1": ["edge1"], | |||
| "child2": ["edge2"], | |||
| "child3": ["edge3"], | |||
| } | |||
| out_edges = { | |||
| "root1": ["edge1"], | |||
| "root2": ["edge2"], | |||
| "root3": ["edge3"], | |||
| } | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") | |||
| assert nodes["root1"].state == NodeState.SKIPPED | |||
| assert nodes["root2"].state == NodeState.UNKNOWN # Active root | |||
| assert nodes["root3"].state == NodeState.SKIPPED | |||
| assert nodes["child1"].state == NodeState.SKIPPED | |||
| assert nodes["child2"].state == NodeState.UNKNOWN | |||
| assert nodes["child3"].state == NodeState.SKIPPED | |||
| assert edges["edge1"].state == NodeState.SKIPPED | |||
| assert edges["edge2"].state == NodeState.UNKNOWN | |||
| assert edges["edge3"].state == NodeState.SKIPPED | |||
| def test_convergent_paths(self): | |||
| """Test convergent paths where multiple inactive branches lead to same node.""" | |||
| nodes = { | |||
| "root1": create_mock_node("root1", NodeExecutionType.ROOT), | |||
| "root2": create_mock_node("root2", NodeExecutionType.ROOT), | |||
| "root3": create_mock_node("root3", NodeExecutionType.ROOT), | |||
| "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), | |||
| "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), | |||
| "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), | |||
| } | |||
| edges = { | |||
| "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), | |||
| "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), | |||
| "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), | |||
| "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), | |||
| "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), | |||
| } | |||
| in_edges = { | |||
| "mid1": ["edge1"], | |||
| "mid2": ["edge2"], | |||
| "convergent": ["edge3", "edge4", "edge5"], | |||
| } | |||
| out_edges = { | |||
| "root1": ["edge1"], | |||
| "root2": ["edge2"], | |||
| "root3": ["edge3"], | |||
| "mid1": ["edge4"], | |||
| "mid2": ["edge5"], | |||
| } | |||
| Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") | |||
| assert nodes["root1"].state == NodeState.UNKNOWN | |||
| assert nodes["root2"].state == NodeState.SKIPPED | |||
| assert nodes["root3"].state == NodeState.SKIPPED | |||
| assert nodes["mid1"].state == NodeState.UNKNOWN | |||
| assert nodes["mid2"].state == NodeState.SKIPPED | |||
| assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 | |||
| assert edges["edge1"].state == NodeState.UNKNOWN | |||
| assert edges["edge2"].state == NodeState.SKIPPED | |||
| assert edges["edge3"].state == NodeState.SKIPPED | |||
| assert edges["edge4"].state == NodeState.UNKNOWN | |||
| assert edges["edge5"].state == NodeState.SKIPPED | |||
| @@ -21,7 +21,6 @@ from .test_mock_config import MockConfigBuilder | |||
| from .test_table_runner import TableTestRunner, WorkflowTestCase | |||
| @pytest.mark.skip | |||
| class TestComplexBranchWorkflow: | |||
| """Test suite for complex branch workflow with parallel execution.""" | |||
| @@ -30,6 +29,7 @@ class TestComplexBranchWorkflow: | |||
| self.runner = TableTestRunner() | |||
| self.fixture_path = "test_complex_branch" | |||
| @pytest.mark.skip(reason="output in this workflow can be random") | |||
| def test_hello_branch_with_llm(self): | |||
| """ | |||
| Test when query contains 'hello' - should trigger true branch. | |||
| @@ -12,7 +12,7 @@ This module provides a robust table-driven testing framework with support for: | |||
| import logging | |||
| import time | |||
| from collections.abc import Callable | |||
| from collections.abc import Callable, Sequence | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from dataclasses import dataclass, field | |||
| from pathlib import Path | |||
| @@ -34,7 +34,11 @@ from core.workflow.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_engine import GraphEngine | |||
| from core.workflow.graph_engine.command_channels import InMemoryChannel | |||
| from core.workflow.graph_events import GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent | |||
| from core.workflow.graph_events import ( | |||
| GraphEngineEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| ) | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models.enums import UserFrom | |||
| @@ -57,7 +61,7 @@ class WorkflowTestCase: | |||
| timeout: float = 30.0 | |||
| mock_config: Optional[MockConfig] = None | |||
| use_auto_mock: bool = False | |||
| expected_event_sequence: Optional[list[type[GraphEngineEvent]]] = None | |||
| expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None | |||
| tags: list[str] = field(default_factory=list) | |||
| skip: bool = False | |||
| skip_reason: str = "" | |||
| @@ -9,13 +9,6 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ | |||
| from .test_table_runner import TableTestRunner, WorkflowTestCase | |||
| def mock_template_transform_run(self): | |||
| """Mock the TemplateTransformNode._run() method to return results based on node title.""" | |||
| title = self._node_data.title | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) | |||
| @pytest.mark.skip | |||
| class TestVariableAggregator: | |||
| """Test cases for the variable aggregator workflow.""" | |||
| @@ -37,6 +30,12 @@ class TestVariableAggregator: | |||
| description: str, | |||
| ) -> None: | |||
| """Test all four combinations of switch1 and switch2.""" | |||
| def mock_template_transform_run(self): | |||
| """Mock the TemplateTransformNode._run() method to return results based on node title.""" | |||
| title = self._node_data.title | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) | |||
| with patch.object( | |||
| TemplateTransformNode, | |||
| "_run", | |||
| @@ -1,353 +0,0 @@ | |||
| import httpx | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.file import File, FileTransferMethod, FileType | |||
| from core.variables import ArrayFileVariable, FileVariable | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.enums import WorkflowNodeExecutionStatus | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute | |||
| from core.workflow.nodes.end.entities import EndStreamParam | |||
| from core.workflow.nodes.http_request import ( | |||
| BodyData, | |||
| HttpRequestNode, | |||
| HttpRequestNodeAuthorization, | |||
| HttpRequestNodeBody, | |||
| HttpRequestNodeData, | |||
| ) | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models.enums import UserFrom | |||
| @pytest.mark.skip( | |||
| reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - " | |||
| "needs rewrite for new architecture" | |||
| ) | |||
| def test_http_request_node_binary_file(monkeypatch): | |||
| data = HttpRequestNodeData( | |||
| title="test", | |||
| method="post", | |||
| url="http://example.org/post", | |||
| authorization=HttpRequestNodeAuthorization(type="no-auth"), | |||
| headers="", | |||
| params="", | |||
| body=HttpRequestNodeBody( | |||
| type="binary", | |||
| data=[ | |||
| BodyData( | |||
| key="file", | |||
| type="file", | |||
| value="", | |||
| file=["1111", "file"], | |||
| ) | |||
| ], | |||
| ), | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| ) | |||
| variable_pool.add( | |||
| ["1111", "file"], | |||
| FileVariable( | |||
| name="file", | |||
| value=File( | |||
| tenant_id="1", | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1111", | |||
| storage_key="", | |||
| ), | |||
| ), | |||
| ) | |||
| node_config = { | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| } | |||
| node = HttpRequestNode( | |||
| id="1", | |||
| config=node_config, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| ), | |||
| ) | |||
| # Initialize node data | |||
| node.init_node_data(node_config["data"]) | |||
| monkeypatch.setattr( | |||
| "core.workflow.nodes.http_request.executor.file_manager.download", | |||
| lambda *args, **kwargs: b"test", | |||
| ) | |||
| monkeypatch.setattr( | |||
| "core.helper.ssrf_proxy.post", | |||
| lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]), | |||
| ) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs["body"] == "test" | |||
| @pytest.mark.skip( | |||
| reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - " | |||
| "needs rewrite for new architecture" | |||
| ) | |||
| def test_http_request_node_form_with_file(monkeypatch): | |||
| data = HttpRequestNodeData( | |||
| title="test", | |||
| method="post", | |||
| url="http://example.org/post", | |||
| authorization=HttpRequestNodeAuthorization(type="no-auth"), | |||
| headers="", | |||
| params="", | |||
| body=HttpRequestNodeBody( | |||
| type="form-data", | |||
| data=[ | |||
| BodyData( | |||
| key="file", | |||
| type="file", | |||
| file=["1111", "file"], | |||
| ), | |||
| BodyData( | |||
| key="name", | |||
| type="text", | |||
| value="test", | |||
| ), | |||
| ], | |||
| ), | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| ) | |||
| variable_pool.add( | |||
| ["1111", "file"], | |||
| FileVariable( | |||
| name="file", | |||
| value=File( | |||
| tenant_id="1", | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="1111", | |||
| storage_key="", | |||
| ), | |||
| ), | |||
| ) | |||
| node_config = { | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| } | |||
| node = HttpRequestNode( | |||
| id="1", | |||
| config=node_config, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| ), | |||
| ) | |||
| # Initialize node data | |||
| node.init_node_data(node_config["data"]) | |||
| monkeypatch.setattr( | |||
| "core.workflow.nodes.http_request.executor.file_manager.download", | |||
| lambda *args, **kwargs: b"test", | |||
| ) | |||
| def attr_checker(*args, **kwargs): | |||
| assert kwargs["data"] == {"name": "test"} | |||
| assert kwargs["files"] == [("file", (None, b"test", "application/octet-stream"))] | |||
| return httpx.Response(200, content=b"") | |||
| monkeypatch.setattr( | |||
| "core.helper.ssrf_proxy.post", | |||
| attr_checker, | |||
| ) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs["body"] == "" | |||
| @pytest.mark.skip( | |||
| reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - " | |||
| "needs rewrite for new architecture" | |||
| ) | |||
| def test_http_request_node_form_with_multiple_files(monkeypatch): | |||
| data = HttpRequestNodeData( | |||
| title="test", | |||
| method="post", | |||
| url="http://example.org/upload", | |||
| authorization=HttpRequestNodeAuthorization(type="no-auth"), | |||
| headers="", | |||
| params="", | |||
| body=HttpRequestNodeBody( | |||
| type="form-data", | |||
| data=[ | |||
| BodyData( | |||
| key="files", | |||
| type="file", | |||
| file=["1111", "files"], | |||
| ), | |||
| BodyData( | |||
| key="name", | |||
| type="text", | |||
| value="test", | |||
| ), | |||
| ], | |||
| ), | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| ) | |||
| files = [ | |||
| File( | |||
| tenant_id="1", | |||
| type=FileType.IMAGE, | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="file1", | |||
| filename="image1.jpg", | |||
| mime_type="image/jpeg", | |||
| storage_key="", | |||
| ), | |||
| File( | |||
| tenant_id="1", | |||
| type=FileType.DOCUMENT, | |||
| transfer_method=FileTransferMethod.LOCAL_FILE, | |||
| related_id="file2", | |||
| filename="document.pdf", | |||
| mime_type="application/pdf", | |||
| storage_key="", | |||
| ), | |||
| ] | |||
| variable_pool.add( | |||
| ["1111", "files"], | |||
| ArrayFileVariable( | |||
| name="files", | |||
| value=files, | |||
| ), | |||
| ) | |||
| node_config = { | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| } | |||
| node = HttpRequestNode( | |||
| id="1", | |||
| config=node_config, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| ), | |||
| ) | |||
| # Initialize node data | |||
| node.init_node_data(node_config["data"]) | |||
| monkeypatch.setattr( | |||
| "core.workflow.nodes.http_request.executor.file_manager.download", | |||
| lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", | |||
| ) | |||
| def attr_checker(*args, **kwargs): | |||
| assert kwargs["data"] == {"name": "test"} | |||
| assert len(kwargs["files"]) == 2 | |||
| assert kwargs["files"][0][0] == "files" | |||
| assert kwargs["files"][1][0] == "files" | |||
| file_tuples = [f[1] for f in kwargs["files"]] | |||
| file_contents = [f[1] for f in file_tuples] | |||
| file_types = [f[2] for f in file_tuples] | |||
| assert b"test_image_data" in file_contents | |||
| assert b"test_pdf_data" in file_contents | |||
| assert "image/jpeg" in file_types | |||
| assert "application/pdf" in file_types | |||
| return httpx.Response(200, content=b'{"status":"success"}') | |||
| monkeypatch.setattr( | |||
| "core.helper.ssrf_proxy.post", | |||
| attr_checker, | |||
| ) | |||
| result = node._run() | |||
| assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert result.outputs is not None | |||
| assert result.outputs["body"] == '{"status":"success"}' | |||
| print(result.outputs["body"]) | |||
| @@ -1,909 +0,0 @@ | |||
| import time | |||
| import uuid | |||
| from unittest.mock import patch | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.variables.segments import ArrayAnySegment, ArrayStringSegment | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.enums import WorkflowNodeExecutionStatus | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.node_events import NodeRunResult, StreamCompletedEvent | |||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode | |||
| from core.workflow.nodes.iteration.iteration_node import IterationNode | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models.enums import UserFrom | |||
| @pytest.mark.skip( | |||
| reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" | |||
| ) | |||
| def test_run(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-pe-target", | |||
| "source": "start", | |||
| "target": "pe", | |||
| }, | |||
| { | |||
| "id": "iteration-1-source-answer-3-target", | |||
| "source": "iteration-1", | |||
| "target": "answer-3", | |||
| }, | |||
| { | |||
| "id": "tt-source-if-else-target", | |||
| "source": "tt", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "if-else-true-answer-2-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "true", | |||
| "target": "answer-2", | |||
| }, | |||
| { | |||
| "id": "if-else-false-answer-4-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "false", | |||
| "target": "answer-4", | |||
| }, | |||
| { | |||
| "id": "pe-source-iteration-1-target", | |||
| "source": "pe", | |||
| "target": "iteration-1", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "tt", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "answer": "{{#tt.output#}}", | |||
| "iteration_id": "iteration-1", | |||
| "title": "answer 2", | |||
| "type": "answer", | |||
| }, | |||
| "id": "answer-2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 123", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt", | |||
| }, | |||
| { | |||
| "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, | |||
| "id": "answer-3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "is", | |||
| "id": "1721916275284", | |||
| "value": "hi", | |||
| "variable_selector": ["sys", "query"], | |||
| } | |||
| ], | |||
| "iteration_id": "iteration-1", | |||
| "logical_operator": "and", | |||
| "title": "if", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else", | |||
| }, | |||
| { | |||
| "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, | |||
| "id": "answer-4", | |||
| }, | |||
| { | |||
| "data": { | |||
| "instruction": "test1", | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "parameters": [ | |||
| {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} | |||
| ], | |||
| "query": ["sys", "query"], | |||
| "reasoning_mode": "prompt", | |||
| "title": "pe", | |||
| "type": "parameter-extractor", | |||
| }, | |||
| "id": "pe", | |||
| }, | |||
| ], | |||
| } | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables=SystemVariable( | |||
| user_id="1", | |||
| files=[], | |||
| query="dify", | |||
| conversation_id="abababa", | |||
| ), | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| node_config = { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "tt", | |||
| "title": "迭代", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| } | |||
| iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| config=node_config, | |||
| ) | |||
| # Initialize node data | |||
| iteration_node.init_node_data(node_config["data"]) | |||
| def tt_generator(self): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={"iterator_selector": "dify"}, | |||
| outputs={"output": "dify 123"}, | |||
| ) | |||
| with patch.object(TemplateTransformNode, "_run", new=tt_generator): | |||
| # execute node | |||
| result = iteration_node._run() | |||
| count = 0 | |||
| for item in result: | |||
| # print(type(item), item) | |||
| count += 1 | |||
| if isinstance(item, StreamCompletedEvent): | |||
| assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} | |||
| assert count == 20 | |||
| @pytest.mark.skip( | |||
| reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" | |||
| ) | |||
| def test_run_parallel(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-pe-target", | |||
| "source": "start", | |||
| "target": "pe", | |||
| }, | |||
| { | |||
| "id": "iteration-1-source-answer-3-target", | |||
| "source": "iteration-1", | |||
| "target": "answer-3", | |||
| }, | |||
| { | |||
| "id": "iteration-start-source-tt-target", | |||
| "source": "iteration-start", | |||
| "target": "tt", | |||
| }, | |||
| { | |||
| "id": "iteration-start-source-tt-2-target", | |||
| "source": "iteration-start", | |||
| "target": "tt-2", | |||
| }, | |||
| { | |||
| "id": "tt-source-if-else-target", | |||
| "source": "tt", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "tt-2-source-if-else-target", | |||
| "source": "tt-2", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "if-else-true-answer-2-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "true", | |||
| "target": "answer-2", | |||
| }, | |||
| { | |||
| "id": "if-else-false-answer-4-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "false", | |||
| "target": "answer-4", | |||
| }, | |||
| { | |||
| "id": "pe-source-iteration-1-target", | |||
| "source": "pe", | |||
| "target": "iteration-1", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "answer": "{{#tt.output#}}", | |||
| "iteration_id": "iteration-1", | |||
| "title": "answer 2", | |||
| "type": "answer", | |||
| }, | |||
| "id": "answer-2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "title": "iteration-start", | |||
| "type": "iteration-start", | |||
| }, | |||
| "id": "iteration-start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 123", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 321", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt-2", | |||
| }, | |||
| { | |||
| "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, | |||
| "id": "answer-3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "is", | |||
| "id": "1721916275284", | |||
| "value": "hi", | |||
| "variable_selector": ["sys", "query"], | |||
| } | |||
| ], | |||
| "iteration_id": "iteration-1", | |||
| "logical_operator": "and", | |||
| "title": "if", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else", | |||
| }, | |||
| { | |||
| "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, | |||
| "id": "answer-4", | |||
| }, | |||
| { | |||
| "data": { | |||
| "instruction": "test1", | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "parameters": [ | |||
| {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} | |||
| ], | |||
| "query": ["sys", "query"], | |||
| "reasoning_mode": "prompt", | |||
| "title": "pe", | |||
| "type": "parameter-extractor", | |||
| }, | |||
| "id": "pe", | |||
| }, | |||
| ], | |||
| } | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables=SystemVariable( | |||
| user_id="1", | |||
| files=[], | |||
| query="dify", | |||
| conversation_id="abababa", | |||
| ), | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) | |||
| node_config = { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "迭代", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| } | |||
| iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| config=node_config, | |||
| ) | |||
| # Initialize node data | |||
| iteration_node.init_node_data(node_config["data"]) | |||
| def tt_generator(self): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={"iterator_selector": "dify"}, | |||
| outputs={"output": "dify 123"}, | |||
| ) | |||
| with patch.object(TemplateTransformNode, "_run", new=tt_generator): | |||
| # execute node | |||
| result = iteration_node._run() | |||
| count = 0 | |||
| for item in result: | |||
| count += 1 | |||
| if isinstance(item, StreamCompletedEvent): | |||
| assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} | |||
| assert count == 32 | |||
| @pytest.mark.skip( | |||
| reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" | |||
| ) | |||
| def test_iteration_run_in_parallel_mode(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-pe-target", | |||
| "source": "start", | |||
| "target": "pe", | |||
| }, | |||
| { | |||
| "id": "iteration-1-source-answer-3-target", | |||
| "source": "iteration-1", | |||
| "target": "answer-3", | |||
| }, | |||
| { | |||
| "id": "iteration-start-source-tt-target", | |||
| "source": "iteration-start", | |||
| "target": "tt", | |||
| }, | |||
| { | |||
| "id": "iteration-start-source-tt-2-target", | |||
| "source": "iteration-start", | |||
| "target": "tt-2", | |||
| }, | |||
| { | |||
| "id": "tt-source-if-else-target", | |||
| "source": "tt", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "tt-2-source-if-else-target", | |||
| "source": "tt-2", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "if-else-true-answer-2-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "true", | |||
| "target": "answer-2", | |||
| }, | |||
| { | |||
| "id": "if-else-false-answer-4-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "false", | |||
| "target": "answer-4", | |||
| }, | |||
| { | |||
| "id": "pe-source-iteration-1-target", | |||
| "source": "pe", | |||
| "target": "iteration-1", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "answer": "{{#tt.output#}}", | |||
| "iteration_id": "iteration-1", | |||
| "title": "answer 2", | |||
| "type": "answer", | |||
| }, | |||
| "id": "answer-2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "title": "iteration-start", | |||
| "type": "iteration-start", | |||
| }, | |||
| "id": "iteration-start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 123", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 321", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt-2", | |||
| }, | |||
| { | |||
| "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, | |||
| "id": "answer-3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "is", | |||
| "id": "1721916275284", | |||
| "value": "hi", | |||
| "variable_selector": ["sys", "query"], | |||
| } | |||
| ], | |||
| "iteration_id": "iteration-1", | |||
| "logical_operator": "and", | |||
| "title": "if", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else", | |||
| }, | |||
| { | |||
| "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, | |||
| "id": "answer-4", | |||
| }, | |||
| { | |||
| "data": { | |||
| "instruction": "test1", | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "parameters": [ | |||
| {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} | |||
| ], | |||
| "query": ["sys", "query"], | |||
| "reasoning_mode": "prompt", | |||
| "title": "pe", | |||
| "type": "parameter-extractor", | |||
| }, | |||
| "id": "pe", | |||
| }, | |||
| ], | |||
| } | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables=SystemVariable( | |||
| user_id="1", | |||
| files=[], | |||
| query="dify", | |||
| conversation_id="abababa", | |||
| ), | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) | |||
| parallel_node_config = { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "迭代", | |||
| "type": "iteration", | |||
| "is_parallel": True, | |||
| }, | |||
| "id": "iteration-1", | |||
| } | |||
| parallel_iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| config=parallel_node_config, | |||
| ) | |||
| # Initialize node data | |||
| parallel_iteration_node.init_node_data(parallel_node_config["data"]) | |||
| sequential_node_config = { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "迭代", | |||
| "type": "iteration", | |||
| "is_parallel": True, | |||
| }, | |||
| "id": "iteration-1", | |||
| } | |||
| sequential_iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| config=sequential_node_config, | |||
| ) | |||
| # Initialize node data | |||
| sequential_iteration_node.init_node_data(sequential_node_config["data"]) | |||
| def tt_generator(self): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={"iterator_selector": "dify"}, | |||
| outputs={"output": "dify 123"}, | |||
| ) | |||
| with patch.object(TemplateTransformNode, "_run", new=tt_generator): | |||
| # execute node | |||
| parallel_result = parallel_iteration_node._run() | |||
| sequential_result = sequential_iteration_node._run() | |||
| assert parallel_iteration_node._node_data.parallel_nums == 10 | |||
| assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED | |||
| count = 0 | |||
| parallel_arr = [] | |||
| sequential_arr = [] | |||
| for item in parallel_result: | |||
| count += 1 | |||
| parallel_arr.append(item) | |||
| if isinstance(item, StreamCompletedEvent): | |||
| assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} | |||
| assert count == 32 | |||
| for item in sequential_result: | |||
| sequential_arr.append(item) | |||
| count += 1 | |||
| if isinstance(item, StreamCompletedEvent): | |||
| assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} | |||
| assert count == 64 | |||
| @pytest.mark.skip( | |||
| reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" | |||
| ) | |||
| def test_iteration_run_error_handle(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-pe-target", | |||
| "source": "start", | |||
| "target": "pe", | |||
| }, | |||
| { | |||
| "id": "iteration-1-source-answer-3-target", | |||
| "source": "iteration-1", | |||
| "target": "answer-3", | |||
| }, | |||
| { | |||
| "id": "tt-source-if-else-target", | |||
| "source": "iteration-start", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "if-else-true-answer-2-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "true", | |||
| "target": "tt", | |||
| }, | |||
| { | |||
| "id": "if-else-false-answer-4-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "false", | |||
| "target": "tt2", | |||
| }, | |||
| { | |||
| "id": "pe-source-iteration-1-target", | |||
| "source": "pe", | |||
| "target": "iteration-1", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt2", "output"], | |||
| "output_type": "array[string]", | |||
| "start_node_id": "if-else", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1.split(arg2) }}", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [ | |||
| {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, | |||
| {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, | |||
| ], | |||
| }, | |||
| "id": "tt", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }}", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [ | |||
| {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, | |||
| ], | |||
| }, | |||
| "id": "tt2", | |||
| }, | |||
| { | |||
| "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, | |||
| "id": "answer-3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "title": "iteration-start", | |||
| "type": "iteration-start", | |||
| }, | |||
| "id": "iteration-start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "is", | |||
| "id": "1721916275284", | |||
| "value": "1", | |||
| "variable_selector": ["iteration-1", "item"], | |||
| } | |||
| ], | |||
| "iteration_id": "iteration-1", | |||
| "logical_operator": "and", | |||
| "title": "if", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else", | |||
| }, | |||
| { | |||
| "data": { | |||
| "instruction": "test1", | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "parameters": [ | |||
| {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} | |||
| ], | |||
| "query": ["sys", "query"], | |||
| "reasoning_mode": "prompt", | |||
| "title": "pe", | |||
| "type": "parameter-extractor", | |||
| }, | |||
| "id": "pe", | |||
| }, | |||
| ], | |||
| } | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables=SystemVariable( | |||
| user_id="1", | |||
| files=[], | |||
| query="dify", | |||
| conversation_id="abababa", | |||
| ), | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) | |||
| node_factory = DifyNodeFactory( | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| ) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| pool.add(["pe", "list_output"], ["1", "1"]) | |||
| error_node_config = { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| "is_parallel": True, | |||
| "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, | |||
| }, | |||
| "id": "iteration-1", | |||
| } | |||
| iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph_runtime_state=graph_runtime_state, | |||
| config=error_node_config, | |||
| ) | |||
| # Initialize node data | |||
| iteration_node.init_node_data(error_node_config["data"]) | |||
| # execute continue on error node | |||
| result = iteration_node._run() | |||
| result_arr = [] | |||
| count = 0 | |||
| for item in result: | |||
| result_arr.append(item) | |||
| count += 1 | |||
| if isinstance(item, StreamCompletedEvent): | |||
| assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[None, None])} | |||
| assert count == 14 | |||
| # execute remove abnormal output | |||
| iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT | |||
| result = iteration_node._run() | |||
| count = 0 | |||
| for item in result: | |||
| count += 1 | |||
| if isinstance(item, StreamCompletedEvent): | |||
| assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[])} | |||
| assert count == 14 | |||
| @@ -1,624 +0,0 @@ | |||
| import time | |||
| from unittest.mock import patch | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.enums import ( | |||
| WorkflowNodeExecutionMetadataKey, | |||
| WorkflowNodeExecutionStatus, | |||
| ) | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_engine import GraphEngine | |||
| from core.workflow.graph_engine.command_channels import InMemoryChannel | |||
| from core.workflow.graph_events import ( | |||
| GraphRunPartialSucceededEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| ) | |||
| from core.workflow.node_events import NodeRunResult, StreamCompletedEvent | |||
| from core.workflow.nodes.llm.node import LLMNode | |||
| from core.workflow.nodes.node_factory import DifyNodeFactory | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models.enums import UserFrom | |||
| class ContinueOnErrorTestHelper: | |||
| @staticmethod | |||
| def get_code_node( | |||
| code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} | |||
| ): | |||
| """Helper method to create a code node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "outputs": {"result": {"type": "number"}}, | |||
| "error_strategy": error_strategy, | |||
| "title": "code", | |||
| "variables": [], | |||
| "code_language": "python3", | |||
| "code": "\n".join([line[4:] for line in code.split("\n")]), | |||
| "type": "code", | |||
| **retry_config, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_http_node( | |||
| error_strategy: str = "fail-branch", | |||
| default_value: dict | None = None, | |||
| authorization_success: bool = False, | |||
| retry_config: dict = {}, | |||
| ): | |||
| """Helper method to create a http node configuration""" | |||
| authorization = ( | |||
| { | |||
| "type": "api-key", | |||
| "config": { | |||
| "type": "basic", | |||
| "api_key": "ak-xxx", | |||
| "header": "api-key", | |||
| }, | |||
| } | |||
| if authorization_success | |||
| else { | |||
| "type": "api-key", | |||
| # missing config field | |||
| } | |||
| ) | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "title": "http", | |||
| "desc": "", | |||
| "method": "get", | |||
| "url": "http://example.com", | |||
| "authorization": authorization, | |||
| "headers": "X-Header:123", | |||
| "params": "A:b", | |||
| "body": None, | |||
| "type": "http-request", | |||
| "error_strategy": error_strategy, | |||
| **retry_config, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a http node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "type": "http-request", | |||
| "title": "HTTP Request", | |||
| "desc": "", | |||
| "variables": [], | |||
| "method": "get", | |||
| "url": "https://api.github.com/issues", | |||
| "authorization": {"type": "no-auth", "config": None}, | |||
| "headers": "", | |||
| "params": "", | |||
| "body": {"type": "none", "data": []}, | |||
| "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a tool node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "title": "a", | |||
| "desc": "a", | |||
| "provider_id": "maths", | |||
| "provider_type": "builtin", | |||
| "provider_name": "maths", | |||
| "tool_name": "eval_expression", | |||
| "tool_label": "eval_expression", | |||
| "tool_configurations": {}, | |||
| "tool_parameters": { | |||
| "expression": { | |||
| "type": "variable", | |||
| "value": ["1", "123", "args1"], | |||
| } | |||
| }, | |||
| "type": "tool", | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node.node_data.default_value = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a llm node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "llm", | |||
| "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, | |||
| "prompt_template": [ | |||
| {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, | |||
| {"role": "user", "text": "{{#sys.query#}}"}, | |||
| ], | |||
| "memory": None, | |||
| "context": {"enabled": False}, | |||
| "vision": {"enabled": False}, | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): | |||
| """Helper method to create a graph engine instance for testing""" | |||
| # Create graph initialization parameters | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable( | |||
| user_id="aaa", | |||
| files=[], | |||
| query="clear", | |||
| conversation_id="abababa", | |||
| ), | |||
| user_inputs=user_inputs or {"uid": "takato"}, | |||
| ) | |||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | |||
| node_factory = DifyNodeFactory(init_params, graph_runtime_state) | |||
| graph = Graph.init(graph_config=graph_config, node_factory=node_factory) | |||
| return GraphEngine( | |||
| tenant_id="111", | |||
| app_id="222", | |||
| workflow_id="333", | |||
| graph_config=graph_config, | |||
| user_id="444", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| call_depth=0, | |||
| graph=graph, | |||
| graph_runtime_state=graph_runtime_state, | |||
| max_execution_steps=500, | |||
| max_execution_time=1200, | |||
| command_channel=InMemoryChannel(), | |||
| ) | |||
| DEFAULT_VALUE_EDGE = [ | |||
| { | |||
| "id": "start-source-node-target", | |||
| "source": "start", | |||
| "target": "node", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-source-answer-target", | |||
| "source": "node", | |||
| "target": "answer", | |||
| "sourceHandle": "source", | |||
| }, | |||
| ] | |||
| FAIL_BRANCH_EDGES = [ | |||
| { | |||
| "id": "start-source-node-target", | |||
| "source": "start", | |||
| "target": "node", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-true-success-target", | |||
| "source": "node", | |||
| "target": "success", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-false-error-target", | |||
| "source": "node", | |||
| "target": "error", | |||
| "sourceHandle": "fail-branch", | |||
| }, | |||
| ] | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_code_default_value_continue_on_error(): | |||
| error_code = """ | |||
| def main() -> dict: | |||
| return { | |||
| "result": 1 / 0, | |||
| } | |||
| """ | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_code_node( | |||
| error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_code_fail_branch_continue_on_error(): | |||
| error_code = """ | |||
| def main() -> dict: | |||
| return { | |||
| "result": 1 / 0, | |||
| } | |||
| """ | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_code_node(error_code), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events | |||
| ) | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_http_node_default_value_continue_on_error(): | |||
| """Test HTTP node with default value error strategy""" | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_http_node( | |||
| "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} | |||
| for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_http_node_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| # def test_tool_node_default_value_continue_on_error(): | |||
| # """Test tool node with default value error strategy""" | |||
| # graph_config = { | |||
| # "edges": DEFAULT_VALUE_EDGE, | |||
| # "nodes": [ | |||
| # {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| # {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, | |||
| # ContinueOnErrorTestHelper.get_tool_node( | |||
| # "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] | |||
| # ), | |||
| # ], | |||
| # } | |||
| # graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| # events = list(graph_engine.run()) | |||
| # assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| # assert any( | |||
| # isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events # noqa: E501 | |||
| # ) | |||
| # assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| # def test_tool_node_fail_branch_continue_on_error(): | |||
| # """Test HTTP node with fail-branch error strategy""" | |||
| # graph_config = { | |||
| # "edges": FAIL_BRANCH_EDGES, | |||
| # "nodes": [ | |||
| # {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| # { | |||
| # "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, | |||
| # "id": "success", | |||
| # }, | |||
| # { | |||
| # "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, | |||
| # "id": "error", | |||
| # }, | |||
| # ContinueOnErrorTestHelper.get_tool_node(), | |||
| # ], | |||
| # } | |||
| # graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| # events = list(graph_engine.run()) | |||
| # assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| # assert any( | |||
| # isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events # noqa: E501 | |||
| # ) | |||
| # assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_llm_node_default_value_continue_on_error(): | |||
| """Test LLM node with default value error strategy""" | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_llm_node( | |||
| "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_llm_node_fail_branch_continue_on_error(): | |||
| """Test LLM node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_llm_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_status_code_error_http_node_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_error_status_code_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_variable_pool_error_type_variable(): | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_error_status_code_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| list(graph_engine.run()) | |||
| error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) | |||
| error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) | |||
| assert error_message != None | |||
| assert error_type.value == "HTTPResponseCodeError" | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_no_node_in_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES[:-1], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"}, | |||
| ContinueOnErrorTestHelper.get_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 | |||
| @pytest.mark.skip( | |||
| reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " | |||
| "not fully implemented in MVP of queue-based engine" | |||
| ) | |||
| def test_stream_output_with_fail_branch_continue_on_error(): | |||
| """Test stream output with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_llm_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| def llm_generator(self): | |||
| contents = ["hi", "bye", "good morning"] | |||
| yield NodeRunStreamChunkEvent( | |||
| node_id=self.node_id, | |||
| node_type=self._node_type, | |||
| selector=[self.node_id, "text"], | |||
| chunk=contents[0], | |||
| is_final=False, | |||
| ) | |||
| yield StreamCompletedEvent( | |||
| node_run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={}, | |||
| process_data={}, | |||
| outputs={}, | |||
| metadata={ | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1, | |||
| WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 1, | |||
| WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", | |||
| }, | |||
| ) | |||
| ) | |||
| with patch.object(LLMNode, "_run", new=llm_generator): | |||
| events = list(graph_engine.run()) | |||
| assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1 | |||
| assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events) | |||
| @@ -1,116 +0,0 @@ | |||
| from collections.abc import Generator | |||
| import pytest | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||
| from core.tools.errors import ToolInvokeError | |||
| from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool | |||
| from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionStatus | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.node_events import NodeRunResult, StreamCompletedEvent | |||
| from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute | |||
| from core.workflow.nodes.end.entities import EndStreamParam | |||
| from core.workflow.nodes.tool import ToolNode | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| from core.workflow.system_variable import SystemVariable | |||
| from models import UserFrom | |||
| def _create_tool_node(): | |||
| data = ToolNodeData( | |||
| title="Test Tool", | |||
| tool_parameters={}, | |||
| provider_id="test_tool", | |||
| provider_type=ToolProviderType.WORKFLOW, | |||
| provider_name="test tool", | |||
| tool_name="test tool", | |||
| tool_label="test tool", | |||
| tool_configurations={}, | |||
| plugin_unique_identifier=None, | |||
| desc="Exception handling test tool", | |||
| error_strategy=ErrorStrategy.FAIL_BRANCH, | |||
| version="1", | |||
| ) | |||
| variable_pool = VariablePool( | |||
| system_variables=SystemVariable.empty(), | |||
| user_inputs={}, | |||
| ) | |||
| node_config = { | |||
| "id": "1", | |||
| "data": data.model_dump(), | |||
| } | |||
| node = ToolNode( | |||
| id="1", | |||
| config=node_config, | |||
| graph_init_params=GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_id="1", | |||
| graph_config={}, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| call_depth=0, | |||
| ), | |||
| graph=Graph( | |||
| root_node_id="1", | |||
| answer_stream_generate_routes=AnswerStreamGenerateRoute( | |||
| answer_dependencies={}, | |||
| answer_generate_route={}, | |||
| ), | |||
| end_stream_param=EndStreamParam( | |||
| end_dependencies={}, | |||
| end_stream_variable_selector_mapping={}, | |||
| ), | |||
| ), | |||
| graph_runtime_state=GraphRuntimeState( | |||
| variable_pool=variable_pool, | |||
| start_at=0, | |||
| ), | |||
| ) | |||
| # Initialize node data | |||
| node.init_node_data(node_config["data"]) | |||
| return node | |||
| class MockToolRuntime: | |||
| def get_merged_runtime_parameters(self): | |||
| pass | |||
| def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]: | |||
| yield from [] | |||
| raise ToolInvokeError("oops") | |||
| @pytest.mark.skip( | |||
| reason="Tool node test uses old Graph constructor incompatible with new queue-based engine - " | |||
| "needs rewrite for new architecture" | |||
| ) | |||
| def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch): | |||
| """Ensure that ToolNode can handle ToolInvokeError when transforming | |||
| messages generated by ToolEngine.generic_invoke. | |||
| """ | |||
| tool_node = _create_tool_node() | |||
| # Need to patch ToolManager and ToolEngine so that we don't | |||
| # have to set up a database. | |||
| monkeypatch.setattr( | |||
| "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime() | |||
| ) | |||
| monkeypatch.setattr( | |||
| "core.tools.tool_engine.ToolEngine.generic_invoke", | |||
| lambda *args, **kwargs: mock_message_stream(), | |||
| ) | |||
| streams = list(tool_node._run()) | |||
| assert len(streams) == 1 | |||
| stream = streams[0] | |||
| assert isinstance(stream, StreamCompletedEvent) | |||
| result = stream.node_run_result | |||
| assert isinstance(result, NodeRunResult) | |||
| assert result.status == WorkflowNodeExecutionStatus.FAILED | |||
| assert "oops" in result.error | |||
| assert "Failed to invoke tool" in result.error | |||
| assert result.error_type == "ToolInvokeError" | |||
| @@ -14,6 +14,22 @@ interface HeaderParams { | |||
| interface User { | |||
| } | |||
| interface DifyFileBase { | |||
| type: "image" | |||
| } | |||
| export interface DifyRemoteFile extends DifyFileBase { | |||
| transfer_method: "remote_url" | |||
| url: string | |||
| } | |||
| export interface DifyLocalFile extends DifyFileBase { | |||
| transfer_method: "local_file" | |||
| upload_file_id: string | |||
| } | |||
| export type DifyFile = DifyRemoteFile | DifyLocalFile; | |||
| export declare class DifyClient { | |||
| constructor(apiKey: string, baseUrl?: string); | |||
| @@ -44,7 +60,7 @@ export declare class CompletionClient extends DifyClient { | |||
| inputs: any, | |||
| user: User, | |||
| stream?: boolean, | |||
| files?: File[] | null | |||
| files?: DifyFile[] | null | |||
| ): Promise<any>; | |||
| } | |||
| @@ -55,7 +71,7 @@ export declare class ChatClient extends DifyClient { | |||
| user: User, | |||
| stream?: boolean, | |||
| conversation_id?: string | null, | |||
| files?: File[] | null | |||
| files?: DifyFile[] | null | |||
| ): Promise<any>; | |||
| getSuggested(message_id: string, user: User): Promise<any>; | |||
| @@ -32,6 +32,7 @@ export const checkOrSetAccessToken = async (appCode?: string | null) => { | |||
| [userId || 'DEFAULT']: res.access_token, | |||
| } | |||
| localStorage.setItem('token', JSON.stringify(accessTokenJson)) | |||
| localStorage.removeItem(CONVERSATION_ID_INFO) | |||
| } | |||
| } | |||
| @@ -11,6 +11,7 @@ import type { FC, PropsWithChildren } from 'react' | |||
| import { useEffect } from 'react' | |||
| import { useState } from 'react' | |||
| import { create } from 'zustand' | |||
| import { useGlobalPublicStore } from './global-public-context' | |||
| type WebAppStore = { | |||
| shareCode: string | null | |||
| @@ -56,6 +57,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => { | |||
| } | |||
| const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => { | |||
| const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending) | |||
| const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode) | |||
| const updateShareCode = useWebAppStore(state => state.updateShareCode) | |||
| const pathname = usePathname() | |||
| @@ -69,7 +71,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => { | |||
| }, [shareCode, updateShareCode]) | |||
| const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode) | |||
| const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(false) | |||
| const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true) | |||
| useEffect(() => { | |||
| if (accessModeResult?.accessMode) { | |||
| @@ -86,7 +88,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => { | |||
| } | |||
| }, [accessModeResult, updateWebAppAccessMode, shareCode]) | |||
| if (isFetching || isFetchingAccessToken) { | |||
| if (isGlobalPending || isFetching || isFetchingAccessToken) { | |||
| return <div className='flex h-full w-full items-center justify-center'> | |||
| <Loading /> | |||
| </div> | |||
| @@ -430,9 +430,7 @@ export const ssePost = async ( | |||
| .then((res) => { | |||
| if (!/^[23]\d{2}$/.test(String(res.status))) { | |||
| if (res.status === 401) { | |||
| refreshAccessTokenOrRelogin(TIME_OUT).then(() => { | |||
| ssePost(url, fetchOptions, otherOptions) | |||
| }).catch(() => { | |||
| if (isPublicAPI) { | |||
| res.json().then((data: any) => { | |||
| if (isPublicAPI) { | |||
| if (data.code === 'web_app_access_denied') | |||
| @@ -449,7 +447,14 @@ export const ssePost = async ( | |||
| } | |||
| } | |||
| }) | |||
| }) | |||
| } | |||
| else { | |||
| refreshAccessTokenOrRelogin(TIME_OUT).then(() => { | |||
| ssePost(url, fetchOptions, otherOptions) | |||
| }).catch((err) => { | |||
| console.error(err) | |||
| }) | |||
| } | |||
| } | |||
| else { | |||
| res.json().then((data) => { | |||
| @@ -1,20 +1,12 @@ | |||
| import { useGlobalPublicStore } from '@/context/global-public-context' | |||
| import { AccessMode } from '@/models/access-control' | |||
| import { useQuery } from '@tanstack/react-query' | |||
| import { fetchAppInfo, fetchAppMeta, fetchAppParams, getAppAccessModeByAppCode } from './share' | |||
| const NAME_SPACE = 'webapp' | |||
| export const useGetWebAppAccessModeByCode = (code: string | null) => { | |||
| const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) | |||
| return useQuery({ | |||
| queryKey: [NAME_SPACE, 'appAccessMode', code], | |||
| queryFn: () => { | |||
| if (systemFeatures.webapp_auth.enabled === false) { | |||
| return { | |||
| accessMode: AccessMode.PUBLIC, | |||
| } | |||
| } | |||
| if (!code || code.length === 0) | |||
| return Promise.reject(new Error('App code is required to get access mode')) | |||