
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

tags/2.0.0-beta.1
jyong 2 months ago
parent commit 1db04aa729
71 files changed, 801 additions and 2326 deletions
  1. +6 -8  api/controllers/console/app/message.py
  2. +6 -5  api/core/tools/utils/message_transformer.py
  3. +4 -6  api/core/tools/workflow_as_tool/tool.py
  4. +1 -0  api/core/workflow/enums.py
  5. +77 -8  api/core/workflow/graph/graph.py
  6. +2 -0  api/core/workflow/graph_engine/command_channels/in_memory_channel.py
  7. +3 -2  api/core/workflow/graph_engine/command_channels/redis_channel.py
  8. +2 -0  api/core/workflow/graph_engine/command_processing/command_handlers.py
  9. +2 -1  api/core/workflow/graph_engine/command_processing/command_processor.py
  10. +1 -2  api/core/workflow/graph_engine/domain/graph_execution.py
  11. +2 -3  api/core/workflow/graph_engine/domain/node_execution.py
  12. +3 -3  api/core/workflow/graph_engine/entities/commands.py
  13. +0 -2  api/core/workflow/graph_engine/error_handling/__init__.py
  14. +3 -2  api/core/workflow/graph_engine/error_handling/abort_strategy.py
  15. +3 -2  api/core/workflow/graph_engine/error_handling/default_value_strategy.py
  16. +19 -20  api/core/workflow/graph_engine/error_handling/error_handler.py
  17. +3 -2  api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py
  18. +3 -2  api/core/workflow/graph_engine/error_handling/retry_strategy.py
  19. +86 -6  api/core/workflow/graph_engine/event_management/event_collector.py
  20. +2 -0  api/core/workflow/graph_engine/event_management/event_emitter.py
  21. +66 -91  api/core/workflow/graph_engine/event_management/event_handlers.py
  22. +17 -15  api/core/workflow/graph_engine/graph_engine.py
  23. +9 -4  api/core/workflow/graph_engine/graph_traversal/branch_handler.py
  24. +19 -10  api/core/workflow/graph_engine/graph_traversal/edge_processor.py
  25. +4 -1  api/core/workflow/graph_engine/graph_traversal/node_readiness.py
  26. +8 -6  api/core/workflow/graph_engine/graph_traversal/skip_propagator.py
  27. +3 -4  api/core/workflow/graph_engine/layers/base.py
  28. +3 -2  api/core/workflow/graph_engine/layers/debug_logging.py
  29. +4 -3  api/core/workflow/graph_engine/layers/execution_limits.py
  30. +3 -2  api/core/workflow/graph_engine/manager.py
  31. +8 -5  api/core/workflow/graph_engine/orchestration/dispatcher.py
  32. +2 -1  api/core/workflow/graph_engine/orchestration/execution_coordinator.py
  33. +4 -3  api/core/workflow/graph_engine/output_registry/registry.py
  34. +3 -2  api/core/workflow/graph_engine/output_registry/stream.py
  35. +2 -2  api/core/workflow/graph_engine/protocols/error_strategy.py
  36. +3 -2  api/core/workflow/graph_engine/response_coordinator/coordinator.py
  37. +6 -4  api/core/workflow/graph_engine/state_management/edge_state_manager.py
  38. +2 -0  api/core/workflow/graph_engine/state_management/execution_tracker.py
  39. +2 -0  api/core/workflow/graph_engine/state_management/node_state_manager.py
  40. +6 -5  api/core/workflow/graph_engine/worker.py
  41. +2 -0  api/core/workflow/graph_engine/worker_management/activity_tracker.py
  42. +3 -0  api/core/workflow/graph_engine/worker_management/dynamic_scaler.py
  43. +5 -4  api/core/workflow/graph_engine/worker_management/worker_factory.py
  44. +2 -0  api/core/workflow/graph_engine/worker_management/worker_pool.py
  45. +2 -1  api/core/workflow/nodes/start/start_node.py
  46. +1 -1  api/extensions/ext_storage.py
  47. +1 -1  api/extensions/storage/storage_type.py
  48. +0 -4  api/factories/variable_factory.py
  49. +5 -5  api/libs/helper.py
  50. +2 -2  api/models/model.py
  51. +8 -7  api/models/workflow.py
  52. +3 -5  api/services/dataset_service.py
  53. +17 -10  api/services/tools/builtin_tools_manage_service.py
  54. +5 -6  api/services/workflow_service.py
  55. +3 -2  api/tasks/annotation/disable_annotation_reply_task.py
  56. +11 -11  api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py
  57. +2 -2  api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py
  58. +281 -0  api/tests/unit_tests/core/workflow/graph/test_graph.py
  59. +1 -1  api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py
  60. +7 -3  api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
  61. +6 -7  api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py
  62. +0 -353  api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
  63. +0 -0  api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py
  64. +0 -909  api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
  65. +0 -624  api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
  66. +0 -116  api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
  67. +18 -2  sdks/nodejs-client/index.d.ts
  68. +1 -0  web/app/components/share/utils.ts
  69. +4 -2  web/context/web-app-context.tsx
  70. +9 -4  web/service/base.ts
  71. +0 -8  web/service/use-share.ts

+6 -8  api/controllers/console/app/message.py

@@ -3,6 +3,7 @@ import logging
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

from controllers.console import api
@@ -94,21 +95,18 @@ class ChatMessageListApi(Resource):
.all()
)

has_more = False
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
- rest_count = (
- db.session.query(Message)
- .where(
+ has_more = db.session.scalar(
+ select(
+ exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
- .count()
- )
- if rest_count > 0:
- has_more = True
+ )
+ )

history_messages = list(reversed(history_messages))
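
The change above replaces a full COUNT over the remaining messages with an EXISTS probe, which lets the database stop at the first matching row instead of counting every older message. A minimal sketch of the same SQLAlchemy pattern, with a hypothetical helper name and assuming the Message model from this codebase:

from sqlalchemy import exists, select

def older_messages_exist(session, conversation_id, boundary_message) -> bool:
    # EXISTS short-circuits on the first hit; query(...).count() scans them all.
    return bool(
        session.scalar(
            select(
                exists().where(
                    Message.conversation_id == conversation_id,
                    Message.created_at < boundary_message.created_at,
                    Message.id != boundary_message.id,
                )
            )
        )
    )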


+6 -5  api/core/tools/utils/message_transformer.py

@@ -8,20 +8,21 @@ from uuid import UUID

import numpy as np
import pytz
from flask_login import current_user

from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models.account import Account

logger = logging.getLogger(__name__)


def safe_json_value(v):
if isinstance(v, datetime):
- tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
- if not tz_name:
- tz_name = "UTC"
+ tz_name = "UTC"
+ if isinstance(current_user, Account) and current_user.timezone is not None:
+ tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
@@ -46,7 +47,7 @@ def safe_json_value(v):
return v


def safe_json_dict(d):
def safe_json_dict(d: dict):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}

+4 -6  api/core/tools/workflow_as_tool/tool.py

@@ -3,8 +3,6 @@ import logging
from collections.abc import Generator
from typing import Any, Optional, cast

from flask_login import current_user

from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
@@ -17,8 +15,8 @@ from core.tools.entities.tool_entities import (
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from models.account import Account
from models.model import App, EndUser
from libs.login import current_user
from models.model import App
from models.workflow import Workflow

logger = logging.getLogger(__name__)
@@ -79,11 +77,11 @@ class WorkflowTool(Tool):
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
assert current_user is not None
result = generator.generate(
app_model=app,
workflow=workflow,
user=cast("Account | EndUser", current_user),
user=current_user,
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,

+1 -0  api/core/workflow/enums.py

@@ -66,6 +66,7 @@ class NodeExecutionType(StrEnum):
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
ROOT = "root" # Nodes that can serve as execution entry points


class ErrorStrategy(StrEnum):

+77 -8  api/core/workflow/graph/graph.py

@@ -1,9 +1,9 @@
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, Protocol, cast
from typing import Any, Protocol, cast

from core.workflow.enums import NodeType
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node

from .edge import Edge
@@ -36,10 +36,10 @@ class Graph:
def __init__(
self,
*,
nodes: Optional[dict[str, Node]] = None,
edges: Optional[dict[str, Edge]] = None,
in_edges: Optional[dict[str, list[str]]] = None,
out_edges: Optional[dict[str, list[str]]] = None,
nodes: dict[str, Node] | None = None,
edges: dict[str, Edge] | None = None,
in_edges: dict[str, list[str]] | None = None,
out_edges: dict[str, list[str]] | None = None,
root_node: Node,
):
"""
@@ -81,7 +81,7 @@ class Graph:
cls,
node_configs_map: dict[str, dict[str, Any]],
edge_configs: list[dict[str, Any]],
root_node_id: Optional[str] = None,
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
@@ -186,13 +186,79 @@ class Graph:

return nodes

@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.

Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged

:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]

# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return

# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED

# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED

# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]

# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)

# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)

@classmethod
def init(
cls,
*,
graph_config: Mapping[str, Any],
node_factory: "NodeFactory",
root_node_id: Optional[str] = None,
root_node_id: str | None = None,
) -> "Graph":
"""
Initialize graph
@@ -227,6 +293,9 @@ class Graph:
# Get root node instance
root_node = nodes[root_node_id]

# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

# Create and return the graph
return cls(
nodes=nodes,
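
A toy rendition of the skip-marking algorithm added above, using plain dicts instead of the real Node/Edge objects (all names here are illustrative, not from the diff):

def mark_inactive_roots(
    node_state: dict[str, str],       # node id -> "unknown" | "skipped"
    edge_state: dict[str, str],       # edge id -> "unknown" | "skipped"
    edge_head: dict[str, str],        # edge id -> target node id
    in_edges: dict[str, list[str]],   # node id -> incoming edge ids
    out_edges: dict[str, list[str]],  # node id -> outgoing edge ids
    roots: list[str],
    active_root: str,
) -> None:
    if len(roots) <= 1 or active_root not in roots:
        return

    def mark_downstream(node_id: str) -> None:
        for eid in out_edges.get(node_id, []):
            edge_state[eid] = "skipped"
            target = edge_head[eid]
            # A target is skipped only once *every* incoming edge is skipped.
            if all(edge_state[e] == "skipped" for e in in_edges.get(target, [])):
                node_state[target] = "skipped"
                mark_downstream(target)

    for root in roots:
        if root != active_root:
            node_state[root] = "skipped"
            mark_downstream(root)

# Two roots feeding a shared node: only the inactive branch gets skipped,
# because the join still has one live incoming edge.
node_state = {"a": "unknown", "b": "unknown", "join": "unknown"}
edge_state = {"e1": "unknown", "e2": "unknown"}
mark_inactive_roots(
    node_state, edge_state,
    edge_head={"e1": "join", "e2": "join"},
    in_edges={"join": ["e1", "e2"]},
    out_edges={"a": ["e1"], "b": ["e2"]},
    roots=["a", "b"],
    active_root="a",
)
assert node_state == {"a": "unknown", "b": "skipped", "join": "unknown"}
assert edge_state == {"e1": "unknown", "e2": "skipped"}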

+2 -0  api/core/workflow/graph_engine/command_channels/in_memory_channel.py

@@ -6,10 +6,12 @@ within a single process. Each instance handles commands for one workflow executi
"""

from queue import Queue
from typing import final

from ..entities.commands import GraphEngineCommand


@final
class InMemoryChannel:
"""
In-memory command channel implementation using a thread-safe queue.

+3 -2  api/core/workflow/graph_engine/command_channels/redis_channel.py

@@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
"""

import json
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final

from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand

@@ -15,6 +15,7 @@ if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper


@final
class RedisChannel:
"""
Redis-based command channel implementation for distributed systems.
@@ -86,7 +87,7 @@ class RedisChannel:
pipe.expire(self._key, self._command_ttl)
pipe.execute()

def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]:
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.


+2 -0  api/core/workflow/graph_engine/command_processing/command_handlers.py

@@ -3,6 +3,7 @@ Command handler implementations.
"""

import logging
from typing import final

from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
@@ -11,6 +12,7 @@ from .command_processor import CommandHandler
logger = logging.getLogger(__name__)


@final
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""


+2 -1  api/core/workflow/graph_engine/command_processing/command_processor.py

@@ -3,7 +3,7 @@ Main command processor for handling external commands.
"""

import logging
from typing import Protocol
from typing import Protocol, final

from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
@@ -18,6 +18,7 @@ class CommandHandler(Protocol):
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...


@final
class CommandProcessor:
"""
Processes external commands sent to the engine.

+1 -2  api/core/workflow/graph_engine/domain/graph_execution.py

@@ -3,7 +3,6 @@ GraphExecution aggregate root managing the overall graph execution state.
"""

from dataclasses import dataclass, field
from typing import Optional

from .node_execution import NodeExecution

@@ -21,7 +20,7 @@ class GraphExecution:
started: bool = False
completed: bool = False
aborted: bool = False
error: Optional[Exception] = None
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict)

def start(self) -> None:

+2 -3  api/core/workflow/graph_engine/domain/node_execution.py

@@ -3,7 +3,6 @@ NodeExecution entity representing a node's execution state.
"""

from dataclasses import dataclass
from typing import Optional

from core.workflow.enums import NodeState

@@ -20,8 +19,8 @@ class NodeExecution:
node_id: str
state: NodeState = NodeState.UNKNOWN
retry_count: int = 0
execution_id: Optional[str] = None
error: Optional[str] = None
execution_id: str | None = None
error: str | None = None

def mark_started(self, execution_id: str) -> None:
"""Mark the node as started with an execution ID."""

+3 -3  api/core/workflow/graph_engine/entities/commands.py

@@ -6,7 +6,7 @@ instance to control its execution flow.
"""

from enum import Enum
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel, Field

@@ -23,11 +23,11 @@ class GraphEngineCommand(BaseModel):
"""Base class for all GraphEngine commands."""

command_type: CommandType = Field(..., description="Type of command")
payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload")
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")


class AbortCommand(GraphEngineCommand):
"""Command to abort a running workflow execution."""

command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: Optional[str] = Field(default=None, description="Optional reason for abort")
reason: str | None = Field(default=None, description="Optional reason for abort")
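
Since the commands are plain Pydantic models, a channel can round-trip them as JSON; a small sketch using standard Pydantic v2 calls (the round-trip itself is not shown in this diff):

cmd = AbortCommand(reason="user cancelled")
wire = cmd.model_dump_json()                       # e.g. pushed onto the Redis list
restored = AbortCommand.model_validate_json(wire)
assert restored.command_type == CommandType.ABORT
assert restored.reason == "user cancelled"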

+0 -2  api/core/workflow/graph_engine/error_handling/__init__.py

@@ -8,7 +8,6 @@ the Strategy pattern for clean separation of concerns.
from .abort_strategy import AbortStrategy
from .default_value_strategy import DefaultValueStrategy
from .error_handler import ErrorHandler
from .error_strategy import ErrorStrategy
from .fail_branch_strategy import FailBranchStrategy
from .retry_strategy import RetryStrategy

@@ -16,7 +15,6 @@ __all__ = [
"AbortStrategy",
"DefaultValueStrategy",
"ErrorHandler",
"ErrorStrategy",
"FailBranchStrategy",
"RetryStrategy",
]

+3 -2  api/core/workflow/graph_engine/error_handling/abort_strategy.py

@@ -3,7 +3,7 @@ Abort error strategy implementation.
"""

import logging
from typing import Optional
from typing import final

from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
@@ -11,6 +11,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
logger = logging.getLogger(__name__)


@final
class AbortStrategy:
"""
Error strategy that aborts execution on failure.
@@ -19,7 +20,7 @@ class AbortStrategy:
It stops the entire graph execution when a node fails.
"""

def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by aborting execution.


+3 -2  api/core/workflow/graph_engine/error_handling/default_value_strategy.py

@@ -2,7 +2,7 @@
Default value error strategy implementation.
"""

from typing import Optional
from typing import final

from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
from core.workflow.node_events import NodeRunResult


@final
class DefaultValueStrategy:
"""
Error strategy that uses default values on failure.
@@ -18,7 +19,7 @@ class DefaultValueStrategy:
predefined default output values.
"""

def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by using default values.


+19 -20  api/core/workflow/graph_engine/error_handling/error_handler.py

@@ -2,7 +2,7 @@
Main error handler that coordinates error strategies.
"""

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final

from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
from core.workflow.graph import Graph
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
from ..domain import GraphExecution


@final
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
@@ -34,16 +35,16 @@ class ErrorHandler:
graph: The workflow graph
graph_execution: The graph execution state
"""
self.graph = graph
self.graph_execution = graph_execution
self._graph = graph
self._graph_execution = graph_execution

# Initialize strategies
self.abort_strategy = AbortStrategy()
self.retry_strategy = RetryStrategy()
self.fail_branch_strategy = FailBranchStrategy()
self.default_value_strategy = DefaultValueStrategy()
self._abort_strategy = AbortStrategy()
self._retry_strategy = RetryStrategy()
self._fail_branch_strategy = FailBranchStrategy()
self._default_value_strategy = DefaultValueStrategy()

def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
"""
Handle a node failure event.

@@ -56,14 +57,14 @@ class ErrorHandler:
Returns:
Optional new event to process, or None to abort
"""
node = self.graph.nodes[event.node_id]
node = self._graph.nodes[event.node_id]
# Get retry count from NodeExecution
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
retry_count = node_execution.retry_count

# First check if retry is configured and not exhausted
if node.retry and retry_count < node.retry_config.max_retries:
result = self.retry_strategy.handle_error(event, self.graph, retry_count)
result = self._retry_strategy.handle_error(event, self._graph, retry_count)
if result:
# Retry count will be incremented when NodeRunRetryEvent is handled
return result
@@ -71,12 +72,10 @@ class ErrorHandler:
# Apply configured error strategy
strategy = node.error_strategy

if strategy is None:
return self.abort_strategy.handle_error(event, self.graph, retry_count)
elif strategy == ErrorStrategyEnum.FAIL_BRANCH:
return self.fail_branch_strategy.handle_error(event, self.graph, retry_count)
elif strategy == ErrorStrategyEnum.DEFAULT_VALUE:
return self.default_value_strategy.handle_error(event, self.graph, retry_count)
else:
# Unknown strategy, default to abort
return self.abort_strategy.handle_error(event, self.graph, retry_count)
match strategy:
case None:
return self._abort_strategy.handle_error(event, self._graph, retry_count)
case ErrorStrategyEnum.FAIL_BRANCH:
return self._fail_branch_strategy.handle_error(event, self._graph, retry_count)
case ErrorStrategyEnum.DEFAULT_VALUE:
return self._default_value_strategy.handle_error(event, self._graph, retry_count)

+3 -2  api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py

@@ -2,7 +2,7 @@
Fail branch error strategy implementation.
"""

from typing import Optional
from typing import final

from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
from core.workflow.node_events import NodeRunResult


@final
class FailBranchStrategy:
"""
Error strategy that continues execution via a fail branch.
@@ -18,7 +19,7 @@ class FailBranchStrategy:
through a designated fail-branch edge.
"""

def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by taking the fail branch.


+3 -2  api/core/workflow/graph_engine/error_handling/retry_strategy.py

@@ -3,12 +3,13 @@ Retry error strategy implementation.
"""

import time
from typing import Optional
from typing import final

from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent


@final
class RetryStrategy:
"""
Error strategy that retries failed nodes.
@@ -17,7 +18,7 @@ class RetryStrategy:
maximum number of retries with configurable intervals.
"""

def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle error by retrying the node.


+86 -6  api/core/workflow/graph_engine/event_management/event_collector.py

@@ -3,12 +3,92 @@ Event collector for buffering and managing events.
"""

import threading
from typing import final

from core.workflow.graph_events import GraphEngineEvent

from ..layers.base import Layer


@final
class ReadWriteLock:
"""
A read-write lock implementation that allows multiple concurrent readers
but only one writer at a time.
"""

def __init__(self) -> None:
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0

def acquire_read(self) -> None:
"""Acquire a read lock."""
self._read_ready.acquire()
try:
self._readers += 1
finally:
self._read_ready.release()

def release_read(self) -> None:
"""Release a read lock."""
self._read_ready.acquire()
try:
self._readers -= 1
if self._readers == 0:
self._read_ready.notify_all()
finally:
self._read_ready.release()

def acquire_write(self) -> None:
"""Acquire a write lock."""
self._read_ready.acquire()
while self._readers > 0:
self._read_ready.wait()

def release_write(self) -> None:
"""Release a write lock."""
self._read_ready.release()

def read_lock(self) -> "ReadLockContext":
"""Return a context manager for read locking."""
return ReadLockContext(self)

def write_lock(self) -> "WriteLockContext":
"""Return a context manager for write locking."""
return WriteLockContext(self)


@final
class ReadLockContext:
"""Context manager for read locks."""

def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock

def __enter__(self) -> "ReadLockContext":
self._lock.acquire_read()
return self

def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_read()


@final
class WriteLockContext:
"""Context manager for write locks."""

def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock

def __enter__(self) -> "WriteLockContext":
self._lock.acquire_write()
return self

def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_write()


@final
class EventCollector:
"""
Collects and buffers events for later retrieval.
@@ -20,7 +100,7 @@ class EventCollector:
def __init__(self) -> None:
"""Initialize the event collector."""
self._events: list[GraphEngineEvent] = []
self._lock = threading.Lock()
self._lock = ReadWriteLock()
self._layers: list[Layer] = []

def set_layers(self, layers: list[Layer]) -> None:
@@ -39,7 +119,7 @@ class EventCollector:
Args:
event: The event to collect
"""
with self._lock:
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)

@@ -50,7 +130,7 @@ class EventCollector:
Returns:
List of collected events
"""
with self._lock:
with self._lock.read_lock():
return list(self._events)

def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
@@ -63,7 +143,7 @@ class EventCollector:
Returns:
List of new events
"""
with self._lock:
with self._lock.read_lock():
return list(self._events[start_index:])

def event_count(self) -> int:
@@ -73,12 +153,12 @@ class EventCollector:
Returns:
Number of collected events
"""
with self._lock:
with self._lock.read_lock():
return len(self._events)

def clear(self) -> None:
"""Clear all collected events."""
with self._lock:
with self._lock.write_lock():
self._events.clear()

def _notify_layers(self, event: GraphEngineEvent) -> None:
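
The swap from a plain threading.Lock to the read-write lock above means get_events(), get_new_events(), and event_count() readers no longer serialize against each other; only collect() and clear() take the exclusive write side. A usage sketch of the lock on its own:

lock = ReadWriteLock()
buffer: list[str] = []

def snapshot() -> list[str]:
    with lock.read_lock():    # many readers may hold this concurrently
        return list(buffer)

def publish(item: str) -> None:
    with lock.write_lock():   # writer waits until all readers drain, then excludes everyone
        buffer.append(item)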

+2 -0  api/core/workflow/graph_engine/event_management/event_emitter.py

@@ -5,12 +5,14 @@ Event emitter for yielding events to external consumers.
import threading
import time
from collections.abc import Generator
from typing import final

from core.workflow.graph_events import GraphEngineEvent

from .event_collector import EventCollector


@final
class EventEmitter:
"""
Emits collected events as a generator for external consumption.

+66 -91  api/core/workflow/graph_engine/event_management/event_handlers.py

@@ -3,7 +3,7 @@ Event handler implementations for different event types.
"""

import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final

from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


@final
class EventHandlerRegistry:
"""
Registry of event handlers for different event types.
@@ -52,12 +53,12 @@ class EventHandlerRegistry:
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: Optional["EventCollector"] = None,
branch_handler: Optional["BranchHandler"] = None,
edge_processor: Optional["EdgeProcessor"] = None,
node_state_manager: Optional["NodeStateManager"] = None,
execution_tracker: Optional["ExecutionTracker"] = None,
error_handler: Optional["ErrorHandler"] = None,
event_collector: "EventCollector",
branch_handler: "BranchHandler",
edge_processor: "EdgeProcessor",
node_state_manager: "NodeStateManager",
execution_tracker: "ExecutionTracker",
error_handler: "ErrorHandler",
) -> None:
"""
Initialize the event handler registry.
@@ -67,23 +68,23 @@ class EventHandlerRegistry:
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Optional event collector for collecting events
branch_handler: Optional branch handler for branch node processing
edge_processor: Optional edge processor for edge traversal
node_state_manager: Optional node state manager
execution_tracker: Optional execution tracker
error_handler: Optional error handler
event_collector: Event collector for collecting events
branch_handler: Branch handler for branch node processing
edge_processor: Edge processor for edge traversal
node_state_manager: Node state manager
execution_tracker: Execution tracker
error_handler: Error handler
"""
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.graph_execution = graph_execution
self.response_coordinator = response_coordinator
self.event_collector = event_collector
self.branch_handler = branch_handler
self.edge_processor = edge_processor
self.node_state_manager = node_state_manager
self.execution_tracker = execution_tracker
self.error_handler = error_handler
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._branch_handler = branch_handler
self._edge_processor = edge_processor
self._node_state_manager = node_state_manager
self._execution_tracker = execution_tracker
self._error_handler = error_handler

def handle_event(self, event: GraphNodeEventBase) -> None:
"""
@@ -93,9 +94,8 @@ class EventHandlerRegistry:
event: The event to handle
"""
# Events in loops or iterations are always collected
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
if self.event_collector:
self.event_collector.collect(event)
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return

# Handle specific event types
@@ -125,12 +125,10 @@ class EventHandlerRegistry:
),
):
# Iteration and loop events are collected directly
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)
else:
# Collect unhandled events
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)

def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
@@ -141,15 +139,14 @@ class EventHandlerRegistry:
event: The node started event
"""
# Track execution in domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_started(event.id)

# Track in response coordinator for stream ordering
self.response_coordinator.track_node_execution(event.node_id, event.id)
self._response_coordinator.track_node_execution(event.node_id, event.id)

# Collect the event
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)

def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
@@ -159,12 +156,11 @@ class EventHandlerRegistry:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self.response_coordinator.intercept_event(event))
streaming_events = list(self._response_coordinator.intercept_event(event))

# Collect all events
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)

def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
@@ -177,55 +173,44 @@ class EventHandlerRegistry:
event: The node succeeded event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()

# Store outputs in variable pool
self._store_node_outputs(event)

# Forward to response coordinator and emit streaming events
streaming_events = list(self.response_coordinator.intercept_event(event))
if self.event_collector:
for stream_event in streaming_events:
self.event_collector.collect(stream_event)
streaming_events = self._response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)

# Process edges and get ready nodes
node = self.graph.nodes[event.node_id]
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
if self.branch_handler:
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = [], []
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
if self.edge_processor:
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
else:
ready_nodes, edge_streaming_events = [], []
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)

# Collect streaming events from edge processing
if self.event_collector:
for edge_event in edge_streaming_events:
self.event_collector.collect(edge_event)
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)

# Enqueue ready nodes
if self.node_state_manager and self.execution_tracker:
for node_id in ready_nodes:
self.node_state_manager.enqueue_node(node_id)
self.execution_tracker.add(node_id)
for node_id in ready_nodes:
self._node_state_manager.enqueue_node(node_id)
self._execution_tracker.add(node_id)

# Update execution tracking
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
self._execution_tracker.remove(event.node_id)

# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)

# Collect the event
if self.event_collector:
self.event_collector.collect(event)
self._event_collector.collect(event)

def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
@@ -235,29 +220,19 @@ class EventHandlerRegistry:
event: The node failed event
"""
# Update domain model
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)

if self.error_handler:
result = self.error_handler.handle_node_failure(event)
result = self._error_handler.handle_node_failure(event)

if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Abort execution
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
if result:
# Process the resulting event (retry, exception, etc.)
self.handle_event(result)
else:
# Without error handler, just fail
self.graph_execution.fail(RuntimeError(event.error))
if self.event_collector:
self.event_collector.collect(event)
if self.execution_tracker:
self.execution_tracker.remove(event.node_id)
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._execution_tracker.remove(event.node_id)

def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
"""
@@ -267,7 +242,7 @@ class EventHandlerRegistry:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()

def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
@@ -277,7 +252,7 @@ class EventHandlerRegistry:
Args:
event: The node retry event
"""
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()

def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
@@ -288,16 +263,16 @@ class EventHandlerRegistry:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)

def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self.graph_runtime_state.outputs.get("answer", "")
existing = self._graph_runtime_state.outputs.get("answer", "")
if existing:
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
else:
self.graph_runtime_state.outputs["answer"] = value
self._graph_runtime_state.outputs["answer"] = value
else:
self.graph_runtime_state.outputs[key] = value
self._graph_runtime_state.outputs[key] = value

+17 -15  api/core/workflow/graph_engine/graph_engine.py

@@ -9,7 +9,7 @@ import contextvars
import logging
import queue
from collections.abc import Generator, Mapping
from typing import Any, Optional
from typing import final

from flask import Flask, current_app

@@ -20,6 +20,7 @@ from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
@@ -44,6 +45,7 @@ from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, Wo
logger = logging.getLogger(__name__)


@final
class GraphEngine:
"""
Queue-based graph execution engine.
@@ -62,7 +64,7 @@ class GraphEngine:
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
@@ -103,7 +105,7 @@ class GraphEngine:

# Initialize queues
self.ready_queue: queue.Queue[str] = queue.Queue()
self.event_queue: queue.Queue = queue.Queue()
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

# Initialize subsystems
self._initialize_subsystems()
@@ -185,7 +187,7 @@ class GraphEngine:
event_handler=self.event_handler_registry,
event_collector=self.event_collector,
command_processor=self.command_processor,
worker_pool=self.worker_pool,
worker_pool=self._worker_pool,
)

self.dispatcher = Dispatcher(
@@ -209,7 +211,7 @@ class GraphEngine:
def _setup_worker_management(self) -> None:
"""Initialize worker management subsystem."""
# Capture context for workers
flask_app: Optional[Flask] = None
flask_app: Flask | None = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
@@ -218,8 +220,8 @@ class GraphEngine:
context_vars = contextvars.copy_context()

# Create worker management components
self.activity_tracker = ActivityTracker()
self.dynamic_scaler = DynamicScaler(
self._activity_tracker = ActivityTracker()
self._dynamic_scaler = DynamicScaler(
min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
scale_up_threshold=(
@@ -233,15 +235,15 @@ class GraphEngine:
else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
),
)
self.worker_factory = WorkerFactory(flask_app, context_vars)
self._worker_factory = WorkerFactory(flask_app, context_vars)

self.worker_pool = WorkerPool(
self._worker_pool = WorkerPool(
ready_queue=self.ready_queue,
event_queue=self.event_queue,
graph=self.graph,
worker_factory=self.worker_factory,
dynamic_scaler=self.dynamic_scaler,
activity_tracker=self.activity_tracker,
worker_factory=self._worker_factory,
dynamic_scaler=self._dynamic_scaler,
activity_tracker=self._activity_tracker,
)

def _validate_graph_state_consistency(self) -> None:
@@ -319,10 +321,10 @@ class GraphEngine:
def _start_execution(self) -> None:
"""Start execution subsystems."""
# Calculate initial worker count
initial_workers = self.dynamic_scaler.calculate_initial_workers(self.graph)
initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)

# Start worker pool
self.worker_pool.start(initial_workers)
self._worker_pool.start(initial_workers)

# Register response nodes
for node in self.graph.nodes.values():
@@ -340,7 +342,7 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self.dispatcher.stop()
self.worker_pool.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it

# Notify layers

+9 -4  api/core/workflow/graph_engine/graph_traversal/branch_handler.py

@@ -2,15 +2,18 @@
Branch node handling for graph traversal.
"""

from typing import Optional
from collections.abc import Sequence
from typing import final

from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunStreamChunkEvent

from ..state_management import EdgeStateManager
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator


@final
class BranchHandler:
"""
Handles branch node logic during graph traversal.
@@ -40,7 +43,9 @@ class BranchHandler:
self.skip_propagator = skip_propagator
self.edge_state_manager = edge_state_manager

def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.

@@ -58,10 +63,10 @@ class BranchHandler:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")

# Categorize edges into selected and unselected
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)

# Skip all unselected paths
self.skip_propagator.skip_branch_paths(node_id, unselected_edges)
self.skip_propagator.skip_branch_paths(unselected_edges)

# Process selected edges and get ready nodes and streaming events
return self.edge_processor.process_node_success(node_id, selected_handle)

+19 -10  api/core/workflow/graph_engine/graph_traversal/edge_processor.py

@@ -2,13 +2,18 @@
Edge processing logic for graph traversal.
"""

from collections.abc import Sequence
from typing import final

from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent

from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager


@final
class EdgeProcessor:
"""
Processes edges during graph execution.
@@ -38,7 +43,9 @@ class EdgeProcessor:
self.node_state_manager = node_state_manager
self.response_coordinator = response_coordinator

def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.

@@ -56,7 +63,7 @@ class EdgeProcessor:
else:
return self._process_non_branch_node_edges(node_id)

def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).

@@ -66,8 +73,8 @@ class EdgeProcessor:
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes = []
all_streaming_events = []
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self.graph.get_outgoing_edges(node_id)

for edge in outgoing_edges:
@@ -77,7 +84,9 @@ class EdgeProcessor:

return ready_nodes, all_streaming_events

def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.

@@ -94,8 +103,8 @@ class EdgeProcessor:
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")

ready_nodes = []
all_streaming_events = []
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []

# Categorize edges
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
@@ -112,7 +121,7 @@ class EdgeProcessor:

return ready_nodes, all_streaming_events

def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.

@@ -129,11 +138,11 @@ class EdgeProcessor:
streaming_events = self.response_coordinator.on_edge_taken(edge.id)

# Check if downstream node is ready
ready_nodes = []
ready_nodes: list[str] = []
if self.node_state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)

return ready_nodes, list(streaming_events)
return ready_nodes, streaming_events

def _process_skipped_edge(self, edge: Edge) -> None:
"""

+4 -1  api/core/workflow/graph_engine/graph_traversal/node_readiness.py

@@ -2,10 +2,13 @@
Node readiness checking for execution.
"""

from typing import final

from core.workflow.enums import NodeState
from core.workflow.graph import Graph


@final
class NodeReadinessChecker:
"""
Checks if nodes are ready for execution based on their dependencies.
@@ -71,7 +74,7 @@ class NodeReadinessChecker:
Returns:
List of node IDs that are now ready
"""
ready_nodes = []
ready_nodes: list[str] = []
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)

for edge in outgoing_edges:

+8 -6  api/core/workflow/graph_engine/graph_traversal/skip_propagator.py

@@ -2,11 +2,15 @@
Skip state propagation through the graph.
"""

from core.workflow.graph import Graph
from collections.abc import Sequence
from typing import final

from core.workflow.graph import Edge, Graph

from ..state_management import EdgeStateManager, NodeStateManager


@final
class SkipPropagator:
"""
Propagates skip states through the graph.
@@ -57,9 +61,8 @@ class SkipPropagator:

# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Check if node is ready and enqueue if so
if self.node_state_manager.is_node_ready(downstream_node_id):
self.node_state_manager.enqueue_node(downstream_node_id)
# Enqueue node
self.node_state_manager.enqueue_node(downstream_node_id)
return

# All edges are skipped, propagate skip to this node
@@ -83,12 +86,11 @@ class SkipPropagator:
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)

def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None:
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
"""
Skip all paths from unselected branch edges.

Args:
node_id: The ID of the branch node
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:

+3 -4  api/core/workflow/graph_engine/layers/base.py

@@ -6,7 +6,6 @@ intercept and respond to GraphEngine events.
"""

from abc import ABC, abstractmethod
from typing import Optional

from core.workflow.entities import GraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
@@ -28,8 +27,8 @@ class Layer(ABC):

def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: Optional[GraphRuntimeState] = None
self.command_channel: Optional[CommandChannel] = None
self.graph_runtime_state: GraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None

def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
"""
@@ -73,7 +72,7 @@ class Layer(ABC):
pass

@abstractmethod
def on_graph_end(self, error: Optional[Exception]) -> None:
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.


+3 -2  api/core/workflow/graph_engine/layers/debug_logging.py

@@ -7,7 +7,7 @@ graph execution for debugging purposes.

import logging
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, final

from core.workflow.graph_events import (
GraphEngineEvent,
@@ -34,6 +34,7 @@ from core.workflow.graph_events import (
from .base import Layer


@final
class DebugLoggingLayer(Layer):
"""
A layer that provides comprehensive logging of GraphEngine execution.
@@ -221,7 +222,7 @@ class DebugLoggingLayer(Layer):
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)

def on_graph_end(self, error: Optional[Exception]) -> None:
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)


+4 -3  api/core/workflow/graph_engine/layers/execution_limits.py

@@ -11,7 +11,7 @@ When limits are exceeded, the layer automatically aborts execution.
import logging
import time
from enum import Enum
from typing import Optional
from typing import final

from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer
@@ -29,6 +29,7 @@ class LimitType(Enum):
TIME_LIMIT = "time_limit"


@final
class ExecutionLimitsLayer(Layer):
"""
Layer that enforces execution limits for workflows.
@@ -53,7 +54,7 @@ class ExecutionLimitsLayer(Layer):
self.max_time = max_time

# Runtime tracking
self.start_time: Optional[float] = None
self.start_time: float | None = None
self.step_count = 0
self.logger = logging.getLogger(__name__)

@@ -94,7 +95,7 @@ class ExecutionLimitsLayer(Layer):
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)

def on_graph_end(self, error: Optional[Exception]) -> None:
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:
self._execution_ended = True

+3 -2  api/core/workflow/graph_engine/manager.py

@@ -6,13 +6,14 @@ using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""

from typing import Optional
from typing import final

from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from extensions.ext_redis import redis_client


@final
class GraphEngineManager:
"""
Manager for sending control commands to GraphEngine instances.
@@ -23,7 +24,7 @@ class GraphEngineManager:
"""

@staticmethod
def send_stop_command(task_id: str, reason: Optional[str] = None) -> None:
def send_stop_command(task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
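
A hypothetical call site, given the signature above (the task id is illustrative):

# Aborts the workflow identified by this task id; the reason is optional.
GraphEngineManager.send_stop_command("task-123", reason="user requested stop")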


+8 -5  api/core/workflow/graph_engine/orchestration/dispatcher.py

@@ -6,7 +6,9 @@ import logging
import queue
import threading
import time
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final

from core.workflow.graph_events.base import GraphNodeEventBase

from ..event_management import EventCollector, EventEmitter
from .execution_coordinator import ExecutionCoordinator
@@ -17,6 +19,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


@final
class Dispatcher:
"""
Main dispatcher that processes events from the event queue.
@@ -27,12 +30,12 @@ class Dispatcher:

def __init__(
self,
event_queue: queue.Queue,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
execution_coordinator: ExecutionCoordinator,
max_execution_time: int,
event_emitter: Optional[EventEmitter] = None,
event_emitter: EventEmitter | None = None,
) -> None:
"""
Initialize the dispatcher.
@@ -52,9 +55,9 @@ class Dispatcher:
self.max_execution_time = max_execution_time
self.event_emitter = event_emitter

self._thread: Optional[threading.Thread] = None
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._start_time: Optional[float] = None
self._start_time: float | None = None

def start(self) -> None:
"""Start the dispatcher thread."""

+2 -1  api/core/workflow/graph_engine/orchestration/execution_coordinator.py

@@ -2,7 +2,7 @@
Execution coordinator for managing overall workflow execution.
"""

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, final

from ..command_processing import CommandProcessor
from ..domain import GraphExecution
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
from ..event_management import EventHandlerRegistry


@final
class ExecutionCoordinator:
"""
Coordinates overall execution flow between subsystems.

+4 -3  api/core/workflow/graph_engine/output_registry/registry.py

@@ -7,7 +7,7 @@ thread-safe storage for node outputs.

from collections.abc import Sequence
from threading import RLock
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Union, final

from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent


@final
class OutputRegistry:
"""
Thread-safe registry for storing and retrieving node outputs.
@@ -47,7 +48,7 @@ class OutputRegistry:
with self._lock:
self._scalars.add(selector, value)

def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]:
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
"""
Get a scalar value for the given selector.

@@ -81,7 +82,7 @@ class OutputRegistry:
except ValueError:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")

def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]:
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.


+3 -2  api/core/workflow/graph_engine/output_registry/stream.py

@@ -5,12 +5,13 @@ This module contains the private Stream class used internally by OutputRegistry
to manage streaming data chunks.
"""

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, final

if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent


@final
class Stream:
"""
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
@@ -41,7 +42,7 @@ class Stream:
raise ValueError("Cannot append to a closed stream")
self.events.append(event)

def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]:
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.
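Note: judging by the docstrings, Stream pairs an append-only event list with a read cursor, and pop_next returns None once the cursor catches up rather than blocking. A sketch of that cursor pattern under those assumptions:

class MiniStream:
    def __init__(self) -> None:
        self._events: list[str] = []
        self._read = 0  # index of the next unread event

    def append(self, event: str) -> None:
        self._events.append(event)

    def pop_next(self) -> str | None:
        if self._read >= len(self._events):
            return None  # nothing unread yet
        event = self._events[self._read]
        self._read += 1
        return event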


api/core/workflow/graph_engine/error_handling/error_strategy.py → api/core/workflow/graph_engine/protocols/error_strategy.py View File

@@ -2,7 +2,7 @@
Base error strategy protocol.
"""

from typing import Optional, Protocol
from typing import Protocol

from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
@@ -16,7 +16,7 @@ class ErrorStrategy(Protocol):
node execution failures.
"""

def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
"""
Handle a node failure event.
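Note: ErrorStrategy is a typing.Protocol, so the concrete strategies (retry, abort, fail-branch, default-value) need no common base class; any object with a matching handle_error signature satisfies it structurally. A compact illustration with simplified types:

from typing import Protocol

class Handler(Protocol):
    def handle_error(self, event: str) -> str | None: ...

class Retry:
    def handle_error(self, event: str) -> str | None:
        return f"retrying after {event}"

def dispatch(handler: Handler, event: str) -> str | None:
    return handler.handle_error(event)  # structural match, no inheritance

print(dispatch(Retry(), "timeout"))  # retrying after timeout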


+ 3
- 2
api/core/workflow/graph_engine/response_coordinator/coordinator.py View File

@@ -9,7 +9,7 @@ import logging
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import Optional, TypeAlias
from typing import TypeAlias, final
from uuid import uuid4

from core.workflow.enums import NodeExecutionType, NodeState
@@ -28,6 +28,7 @@ NodeID: TypeAlias = str
EdgeID: TypeAlias = str


@final
class ResponseStreamCoordinator:
"""
Manages response streaming sessions without relying on global state.
@@ -45,7 +46,7 @@ class ResponseStreamCoordinator:
"""
self.registry = registry
self.graph = graph
self.active_session: Optional[ResponseSession] = None
self.active_session: ResponseSession | None = None
self.waiting_sessions: deque[ResponseSession] = deque()
self.lock = RLock()


+ 6
- 4
api/core/workflow/graph_engine/state_management/edge_state_manager.py View File

@@ -3,7 +3,8 @@ Manager for edge states during graph execution.
"""

import threading
from typing import TypedDict
from collections.abc import Sequence
from typing import TypedDict, final

from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
@@ -17,6 +18,7 @@ class EdgeStateAnalysis(TypedDict):
all_skipped: bool


@final
class EdgeStateManager:
"""
Manages edge states and transitions during graph execution.
@@ -87,7 +89,7 @@ class EdgeStateManager:
with self._lock:
return self.graph.edges[edge_id].state

def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.

@@ -100,8 +102,8 @@ class EdgeStateManager:
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges = []
unselected_edges = []
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []

for edge in outgoing_edges:
if edge.source_handle == selected_handle:
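Note: the return type widens from list[Edge] to Sequence[Edge], telling callers the result is read-only while the implementation still builds plain lists internally. The same idea in miniature:

from collections.abc import Sequence

def split_evens(values: Sequence[int]) -> tuple[Sequence[int], Sequence[int]]:
    evens: list[int] = []
    odds: list[int] = []
    for v in values:
        (evens if v % 2 == 0 else odds).append(v)
    return evens, odds  # exposed as Sequence: callers should not mutate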

+ 2
- 0
api/core/workflow/graph_engine/state_management/execution_tracker.py View File

@@ -3,8 +3,10 @@ Tracker for currently executing nodes.
"""

import threading
from typing import final


@final
class ExecutionTracker:
"""
Tracks nodes that are currently being executed.

+ 2
- 0
api/core/workflow/graph_engine/state_management/node_state_manager.py View File

@@ -4,11 +4,13 @@ Manager for node states during graph execution.

import queue
import threading
from typing import final

from core.workflow.enums import NodeState
from core.workflow.graph import Graph


@final
class NodeStateManager:
"""
Manages node states and the ready queue for execution.

+ 6
- 5
api/core/workflow/graph_engine/worker.py View File

@@ -11,7 +11,7 @@ import threading
import time
from collections.abc import Callable
from datetime import datetime
from typing import Optional
from typing import final
from uuid import uuid4

from flask import Flask
@@ -23,6 +23,7 @@ from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts


@final
class Worker(threading.Thread):
"""
Worker thread that executes nodes from the ready queue.
@@ -38,10 +39,10 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_id: int = 0,
flask_app: Optional[Flask] = None,
context_vars: Optional[contextvars.Context] = None,
on_idle_callback: Optional[Callable[[int], None]] = None,
on_active_callback: Optional[Callable[[int], None]] = None,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> None:
"""
Initialize worker thread.
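Note: the idle/active callbacks let the pool observe worker state transitions without polling each thread. A stripped-down sketch of that shape (MiniWorker is illustrative, not the real Worker):

import queue
import threading
from collections.abc import Callable

class MiniWorker(threading.Thread):
    def __init__(
        self,
        tasks: queue.Queue[str],
        worker_id: int = 0,
        on_idle: Callable[[int], None] | None = None,
        on_active: Callable[[int], None] | None = None,
    ) -> None:
        super().__init__(daemon=True)
        self._tasks = tasks
        self._id = worker_id
        self._on_idle = on_idle
        self._on_active = on_active
        self._stop = threading.Event()

    def run(self) -> None:
        while not self._stop.is_set():
            try:
                task = self._tasks.get(timeout=0.1)
            except queue.Empty:
                if self._on_idle:
                    self._on_idle(self._id)  # pool may scale down
                continue
            if self._on_active:
                self._on_active(self._id)  # pool sees work in flight
            _ = task  # node execution would happen here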

+ 2
- 0
api/core/workflow/graph_engine/worker_management/activity_tracker.py View File

@@ -4,8 +4,10 @@ Activity tracker for monitoring worker activity.

import threading
import time
from typing import final


@final
class ActivityTracker:
"""
Tracks worker activity for scaling decisions.

+ 3
- 0
api/core/workflow/graph_engine/worker_management/dynamic_scaler.py View File

@@ -2,9 +2,12 @@
Dynamic scaler for worker pool sizing.
"""

from typing import final

from core.workflow.graph import Graph


@final
class DynamicScaler:
"""
Manages dynamic scaling decisions for the worker pool.

+ 5
- 4
api/core/workflow/graph_engine/worker_management/worker_factory.py View File

@@ -5,7 +5,7 @@ Factory for creating worker instances.
import contextvars
import queue
from collections.abc import Callable
from typing import Optional
from typing import final

from flask import Flask

@@ -14,6 +14,7 @@ from core.workflow.graph import Graph
from ..worker import Worker


@final
class WorkerFactory:
"""
Factory for creating worker instances with proper context.
@@ -24,7 +25,7 @@ class WorkerFactory:

def __init__(
self,
flask_app: Optional[Flask],
flask_app: Flask | None,
context_vars: contextvars.Context,
) -> None:
"""
@@ -43,8 +44,8 @@ class WorkerFactory:
ready_queue: queue.Queue[str],
event_queue: queue.Queue,
graph: Graph,
on_idle_callback: Optional[Callable[[int], None]] = None,
on_active_callback: Optional[Callable[[int], None]] = None,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> Worker:
"""
Create a new worker instance.

+ 2
- 0
api/core/workflow/graph_engine/worker_management/worker_pool.py View File

@@ -4,6 +4,7 @@ Worker pool management.

import queue
import threading
from typing import final

from core.workflow.graph import Graph

@@ -13,6 +14,7 @@ from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory


@final
class WorkerPool:
"""
Manages a pool of worker threads for executing nodes.

+ 2
- 1
api/core/workflow/nodes/start/start_node.py View File

@@ -2,7 +2,7 @@ from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
@@ -11,6 +11,7 @@ from core.workflow.nodes.start.entities import StartNodeData

class StartNode(Node):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT

_node_data: StartNodeData


+ 1
- 1
api/extensions/ext_storage.py View File

@@ -65,7 +65,7 @@ class Storage:
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage

return VolcengineTosStorage
case StorageType.SUPBASE:
case StorageType.SUPABASE:
from extensions.storage.supabase_storage import SupabaseStorage

return SupabaseStorage

+ 1
- 1
api/extensions/storage/storage_type.py View File

@@ -14,4 +14,4 @@ class StorageType(StrEnum):
S3 = "s3"
TENCENT_COS = "tencent-cos"
VOLCENGINE_TOS = "volcengine-tos"
SUPBASE = "supabase"
SUPABASE = "supabase"
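Note: only the member name changes (SUPBASE → SUPABASE); the value stays "supabase", so existing configuration strings still resolve and only Python references to the member (such as the match arm in ext_storage.py above) needed updating. Quick check:

from enum import StrEnum

class StorageType(StrEnum):
    SUPABASE = "supabase"  # member name fixed; value unchanged

assert StorageType("supabase") is StorageType.SUPABASE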

+ 0
- 4
api/factories/variable_factory.py View File

@@ -137,10 +137,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
return cast(Variable, result)


def infer_segment_type_from_value(value: Any, /) -> SegmentType:
return build_segment(value).value_type


def build_segment(value: Any, /) -> Segment:
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
# below

+ 5
- 5
api/libs/helper.py View File

@@ -301,8 +301,8 @@ class TokenManager:
if expiry_minutes is None:
raise ValueError(f"Expiry minutes for {token_type} token is not set")
token_key = cls._get_token_key(token, token_type)
expiry_time = int(expiry_minutes * 60)
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(token_key, expiry_seconds, json.dumps(token_data))

if account_id:
cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)
@@ -336,11 +336,11 @@ class TokenManager:

@classmethod
def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
):
key = cls._get_account_token_key(account_id, token_type)
expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(key, expiry_time, token)
expiry_seconds = int(expiry_minutes * 60)
redis_client.setex(key, expiry_seconds, token)

@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
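Note: both renames make the unit explicit, since redis setex takes its TTL in seconds. The old _set_current_token_for_account parameter was named expiry_hours while its caller passed minutes, so the computed TTL (value * 60 * 60) came out sixty times too long; after the fix both paths convert minutes to seconds exactly once. As arithmetic:

def ttl_seconds(expiry_minutes: float) -> int:
    return int(expiry_minutes * 60)

assert ttl_seconds(30) == 1800      # 30 minutes
assert int(30 * 60 * 60) == 108000  # the old hour-based math: 30 hours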

+ 2
- 2
api/models/model.py View File

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column

from configs import dify_config
@@ -1556,7 +1556,7 @@ class ApiToken(Base):
def generate_api_key(prefix, n):
while True:
result = prefix + generate_string(n)
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
if db.session.scalar(select(exists().where(ApiToken.token == result))):
continue
return result
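Note: this is the first of several hunks in the merge swapping COUNT(*) > 0 for SELECT EXISTS(...), which lets the database stop at the first matching row instead of counting every one. A self-contained sketch against in-memory SQLite (Token here is a throwaway model, not one of Dify's):

from sqlalchemy import String, create_engine, exists, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Token(Base):
    __tablename__ = "tokens"
    id: Mapped[int] = mapped_column(primary_key=True)
    token: Mapped[str] = mapped_column(String(64))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Token(token="abc"))
    session.commit()
    stmt = select(exists().where(Token.token == "abc"))
    assert session.scalar(stmt) is True  # short-circuits at the first hit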


+ 8
- 7
api/models/workflow.py View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4

import sqlalchemy as sa
from sqlalchemy import DateTime, orm
from sqlalchemy import DateTime, exists, orm, select

from core.file.constants import maybe_file_object
from core.file.models import File
@@ -348,12 +348,13 @@ class Workflow(Base):
"""
from models.tools import WorkflowToolProvider

return (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.count()
> 0
stmt = select(
exists().where(
WorkflowToolProvider.tenant_id == self.tenant_id,
WorkflowToolProvider.app_id == self.app_id,
)
)
return db.session.execute(stmt).scalar_one()

@property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
@@ -952,7 +953,7 @@ def _naive_utc_datetime():

class WorkflowDraftVariable(Base):
"""`WorkflowDraftVariable` record variables and outputs generated during
debugging worfklow or chatflow.
debugging workflow or chatflow.

IMPORTANT: This model maintains multiple invariant rules that must be preserved.
Do not instantiate this class directly with the constructor.

+ 3
- 5
api/services/dataset_service.py View File

@@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Literal, Optional

from flask_login import current_user
from sqlalchemy import func, select
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

@@ -845,10 +845,8 @@ class DatasetService:

@staticmethod
def dataset_use_check(dataset_id) -> bool:
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0:
return True
return False
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
return db.session.execute(stmt).scalar_one()

@staticmethod
def check_dataset_permission(dataset, user):

+ 17
- 10
api/services/tools/builtin_tools_manage_service.py View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional

from sqlalchemy import exists, select
from sqlalchemy.orm import Session

from configs import dify_config
@@ -190,11 +191,14 @@ class BuiltinToolManageService:
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")

@@ -246,11 +250,14 @@ class BuiltinToolManageService:
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
if session.scalar(
select(
exists().where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.name == name,
)
)
):
raise ValueError(f"the credential name '{name}' is already used")


+ 5
- 6
api/services/workflow_service.py View File

@@ -4,7 +4,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker

from core.app.app_config.entities import VariableEntityType
@@ -83,15 +83,14 @@ class WorkflowService:
)

def is_workflow_exist(self, app_model: App) -> bool:
return (
db.session.query(Workflow)
.where(
stmt = select(
exists().where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.count()
) > 0
)
return db.session.execute(stmt).scalar_one()

def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
"""

+ 3
- 2
api/tasks/annotation/disable_annotation_reply_task.py View File

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import exists, select

from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count()
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
@@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
)

try:
if annotations_count > 0:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:

+ 11
- 11
api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py View File

@@ -39,7 +39,7 @@ def test_page_result(text, cursor, maxlen, expected):
# Tests: get_url
# ---------------------------
@pytest.fixture
def stub_support_types(monkeypatch):
def stub_support_types(monkeypatch: pytest.MonkeyPatch):
"""Stub supported content types list."""
import core.tools.utils.web_reader_tool as mod

@@ -48,7 +48,7 @@ def stub_support_types(monkeypatch):
return mod


def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types):
# HEAD 200 but content-type not supported and not text/html
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
return FakeResponse(
@@ -62,7 +62,7 @@ def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
assert result == "Unsupported content-type [image/png] of URL."


def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types):
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""
When content-type is in SUPPORT_URL_CONTENT_TYPES,
should call ExtractProcessor.load_from_url and return its text.
@@ -88,7 +88,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_
assert result == "PDF extracted text"


def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types):
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""200 + text/html → GET, chardet detects encoding, readability returns article which is templated."""

def fake_head(url, headers=None, follow_redirects=True, timeout=None):
@@ -121,7 +121,7 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_suppor
assert "Hello world" in out


def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types):
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""If readability returns no text, should return empty string."""

def fake_head(url, headers=None, follow_redirects=True, timeout=None):
@@ -142,7 +142,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_su
assert out == ""


def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""

def fake_head(url, headers=None, follow_redirects=True, timeout=None):
@@ -175,7 +175,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
assert "X" in out


def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""HEAD returns non-200 and non-403 → should directly return code message."""

def fake_head(url, headers=None, follow_redirects=True, timeout=None):
@@ -189,7 +189,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
assert out == "URL returned status code 500."


def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types):
def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""
If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type,
it should route to ExtractProcessor.load_from_url.
@@ -213,7 +213,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch, stub_suppor
assert out == "From ExtractProcessor via filename"


def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types):
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types):
"""
If chardet returns an encoding but content.decode raises, should fallback to response.text.
"""
@@ -250,7 +250,7 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_supp
# ---------------------------


def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch):
# stub readabilipy.simple_json_from_html_string
def fake_simple_json_from_html_string(html, use_readability=True):
return {
@@ -271,7 +271,7 @@ def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
assert article.text[0]["text"] == "world"


def test_extract_using_readabilipy_defaults_when_missing(monkeypatch):
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch):
def fake_simple_json_from_html_string(html, use_readability=True):
return {} # all missing
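Note: annotating the fixture as pytest.MonkeyPatch (public since pytest 6.2) changes nothing at runtime; it just gives editors and type checkers the signatures of setattr, setenv, and friends. Minimal example:

import os

import pytest

def test_env(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("MODE", "test")  # now type-checked
    assert os.environ["MODE"] == "test"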


+ 2
- 2
api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py View File

@@ -8,7 +8,7 @@ from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool


def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
`WorkflowAppGenerator.generate` returns a result with `error` key inside
the `data` element.
@@ -40,7 +40,7 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
lambda *args, **kwargs: {"data": {"error": "oops"}},
)
monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)

with pytest.raises(ToolInvokeError) as exc_info:
# WorkflowTool always returns a generator, so we need to iterate to
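Note: the retarget from flask_login.current_user to libs.login.current_user follows the standard monkeypatching rule: patch the name where it is looked up, not where it is defined. A tiny illustration with two hypothetical modules:

import types
from unittest import mock

provider = types.ModuleType("provider")
provider.value = "real"
consumer = types.ModuleType("consumer")
consumer.value = provider.value  # equivalent to: from provider import value

with mock.patch.object(provider, "value", "patched"):
    assert consumer.value == "real"      # wrong target: consumer unaffected
with mock.patch.object(consumer, "value", "patched"):
    assert consumer.value == "patched"   # patch where the name is used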

+ 281
- 0
api/tests/unit_tests/core/workflow/graph/test_graph.py View File

@@ -0,0 +1,281 @@
"""Unit tests for Graph class methods."""

from unittest.mock import Mock

from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.graph.edge import Edge
from core.workflow.graph.graph import Graph
from core.workflow.nodes.base.node import Node


def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
"""Create a mock node for testing."""
node = Mock(spec=Node)
node.id = node_id
node.execution_type = execution_type
node.state = state
node.node_type = NodeType.START
return node


class TestMarkInactiveRootBranches:
"""Test cases for _mark_inactive_root_branches method."""

def test_single_root_no_marking(self):
"""Test that single root graph doesn't mark anything as skipped."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
}

in_edges = {"child1": ["edge1"]}
out_edges = {"root1": ["edge1"]}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["child1"].state == NodeState.UNKNOWN
assert edges["edge1"].state == NodeState.UNKNOWN

def test_multiple_roots_mark_inactive(self):
"""Test marking inactive root branches with multiple root nodes."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
}

in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
out_edges = {"root1": ["edge1"], "root2": ["edge2"]}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["child1"].state == NodeState.UNKNOWN
assert nodes["child2"].state == NodeState.SKIPPED
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED

def test_shared_downstream_node(self):
"""Test that shared downstream nodes are not skipped if at least one path is active."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
"shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
"edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
"edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
}

in_edges = {
"child1": ["edge1"],
"child2": ["edge2"],
"shared": ["edge3", "edge4"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"child1": ["edge3"],
"child2": ["edge4"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["child1"].state == NodeState.UNKNOWN
assert nodes["child2"].state == NodeState.SKIPPED
assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED
assert edges["edge3"].state == NodeState.UNKNOWN
assert edges["edge4"].state == NodeState.SKIPPED

def test_deep_branch_marking(self):
"""Test marking deep branches with multiple levels."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
"level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
"level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
"level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
"level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
"edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
"edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
"edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
}

in_edges = {
"level1_a": ["edge1"],
"level1_b": ["edge2"],
"level2_a": ["edge3"],
"level2_b": ["edge4"],
"level3": ["edge5"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"level1_a": ["edge3"],
"level1_b": ["edge4"],
"level2_b": ["edge5"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["level1_a"].state == NodeState.UNKNOWN
assert nodes["level1_b"].state == NodeState.SKIPPED
assert nodes["level2_a"].state == NodeState.UNKNOWN
assert nodes["level2_b"].state == NodeState.SKIPPED
assert nodes["level3"].state == NodeState.SKIPPED
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED
assert edges["edge3"].state == NodeState.UNKNOWN
assert edges["edge4"].state == NodeState.SKIPPED
assert edges["edge5"].state == NodeState.SKIPPED

def test_non_root_execution_type(self):
"""Test that nodes with non-ROOT execution type are not treated as root nodes."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
}

in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped
assert nodes["child1"].state == NodeState.UNKNOWN
assert nodes["child2"].state == NodeState.UNKNOWN
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.UNKNOWN

def test_empty_graph(self):
"""Test handling of empty graph structures."""
nodes = {}
edges = {}
in_edges = {}
out_edges = {}

# Should not raise any errors
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")

def test_three_roots_mark_two_inactive(self):
"""Test with three root nodes where two should be marked inactive."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
"child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
"edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
}

in_edges = {
"child1": ["edge1"],
"child2": ["edge2"],
"child3": ["edge3"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"root3": ["edge3"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")

assert nodes["root1"].state == NodeState.SKIPPED
assert nodes["root2"].state == NodeState.UNKNOWN # Active root
assert nodes["root3"].state == NodeState.SKIPPED
assert nodes["child1"].state == NodeState.SKIPPED
assert nodes["child2"].state == NodeState.UNKNOWN
assert nodes["child3"].state == NodeState.SKIPPED
assert edges["edge1"].state == NodeState.SKIPPED
assert edges["edge2"].state == NodeState.UNKNOWN
assert edges["edge3"].state == NodeState.SKIPPED

def test_convergent_paths(self):
"""Test convergent paths where multiple inactive branches lead to same node."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
"mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
"mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
"convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
"edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
"edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
"edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
}

in_edges = {
"mid1": ["edge1"],
"mid2": ["edge2"],
"convergent": ["edge3", "edge4", "edge5"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"root3": ["edge3"],
"mid1": ["edge4"],
"mid2": ["edge5"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["root3"].state == NodeState.SKIPPED
assert nodes["mid1"].state == NodeState.UNKNOWN
assert nodes["mid2"].state == NodeState.SKIPPED
assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED
assert edges["edge3"].state == NodeState.SKIPPED
assert edges["edge4"].state == NodeState.UNKNOWN
assert edges["edge5"].state == NodeState.SKIPPED

+ 1
- 1
api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py View File

@@ -21,7 +21,6 @@ from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowTestCase


@pytest.mark.skip
class TestComplexBranchWorkflow:
"""Test suite for complex branch workflow with parallel execution."""

@@ -30,6 +29,7 @@ class TestComplexBranchWorkflow:
self.runner = TableTestRunner()
self.fixture_path = "test_complex_branch"

@pytest.mark.skip(reason="output in this workflow can be random")
def test_hello_branch_with_llm(self):
"""
Test when query contains 'hello' - should trigger true branch.
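Note: narrowing the skip from the class to the one nondeterministic case keeps the rest of the suite running, and the reason string surfaces in pytest's report. The pattern:

import pytest

class TestBranches:
    def test_deterministic(self) -> None:
        assert 1 + 1 == 2  # still runs

    @pytest.mark.skip(reason="output in this workflow can be random")
    def test_flaky(self) -> None:
        ...  # only this case is skipped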

+ 7
- 3
api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py View File

@@ -12,7 +12,7 @@ This module provides a robust table-driven testing framework with support for:

import logging
import time
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
@@ -34,7 +34,11 @@ from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@@ -57,7 +61,7 @@ class WorkflowTestCase:
timeout: float = 30.0
mock_config: Optional[MockConfig] = None
use_auto_mock: bool = False
expected_event_sequence: Optional[list[type[GraphEngineEvent]]] = None
expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None
tags: list[str] = field(default_factory=list)
skip: bool = False
skip_reason: str = ""

+ 6
- 7
api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py View File

@@ -9,13 +9,6 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ
from .test_table_runner import TableTestRunner, WorkflowTestCase


def mock_template_transform_run(self):
"""Mock the TemplateTransformNode._run() method to return results based on node title."""
title = self._node_data.title
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title})


@pytest.mark.skip
class TestVariableAggregator:
"""Test cases for the variable aggregator workflow."""

@@ -37,6 +30,12 @@ class TestVariableAggregator:
description: str,
) -> None:
"""Test all four combinations of switch1 and switch2."""

def mock_template_transform_run(self):
"""Mock the TemplateTransformNode._run() method to return results based on node title."""
title = self._node_data.title
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title})

with patch.object(
TemplateTransformNode,
"_run",

+ 0
- 353
api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py View File

@@ -1,353 +0,0 @@
import httpx
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileVariable, FileVariable
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.entities import EndStreamParam
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNode,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom


@pytest.mark.skip(
reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - "
"needs rewrite for new architecture"
)
def test_http_request_node_binary_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/post",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="binary",
data=[
BodyData(
key="file",
type="file",
value="",
file=["1111", "file"],
)
],
),
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
["1111", "file"],
FileVariable(
name="file",
value=File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)

node_config = {
"id": "1",
"data": data.model_dump(),
}

node = HttpRequestNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)

# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
)
monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == "test"


@pytest.mark.skip(
reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - "
"needs rewrite for new architecture"
)
def test_http_request_node_form_with_file(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/post",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="form-data",
data=[
BodyData(
key="file",
type="file",
file=["1111", "file"],
),
BodyData(
key="name",
type="text",
value="test",
),
],
),
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
["1111", "file"],
FileVariable(
name="file",
value=File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)

node_config = {
"id": "1",
"data": data.model_dump(),
}

node = HttpRequestNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)

# Initialize node data
node.init_node_data(node_config["data"])

monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
)

def attr_checker(*args, **kwargs):
assert kwargs["data"] == {"name": "test"}
assert kwargs["files"] == [("file", (None, b"test", "application/octet-stream"))]
return httpx.Response(200, content=b"")

monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
attr_checker,
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == ""


@pytest.mark.skip(
reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - "
"needs rewrite for new architecture"
)
def test_http_request_node_form_with_multiple_files(monkeypatch):
data = HttpRequestNodeData(
title="test",
method="post",
url="http://example.org/upload",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="",
params="",
body=HttpRequestNodeBody(
type="form-data",
data=[
BodyData(
key="files",
type="file",
file=["1111", "files"],
),
BodyData(
key="name",
type="text",
value="test",
),
],
),
)

variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
)

files = [
File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="image1.jpg",
mime_type="image/jpeg",
storage_key="",
),
File(
tenant_id="1",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file2",
filename="document.pdf",
mime_type="application/pdf",
storage_key="",
),
]

variable_pool.add(
["1111", "files"],
ArrayFileVariable(
name="files",
value=files,
),
)

node_config = {
"id": "1",
"data": data.model_dump(),
}

node = HttpRequestNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)

# Initialize node data
node.init_node_data(node_config["data"])

monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",
)

def attr_checker(*args, **kwargs):
assert kwargs["data"] == {"name": "test"}

assert len(kwargs["files"]) == 2
assert kwargs["files"][0][0] == "files"
assert kwargs["files"][1][0] == "files"

file_tuples = [f[1] for f in kwargs["files"]]
file_contents = [f[1] for f in file_tuples]
file_types = [f[2] for f in file_tuples]

assert b"test_image_data" in file_contents
assert b"test_pdf_data" in file_contents
assert "image/jpeg" in file_types
assert "application/pdf" in file_types

return httpx.Response(200, content=b'{"status":"success"}')

monkeypatch.setattr(
"core.helper.ssrf_proxy.post",
attr_checker,
)

result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == '{"status":"success"}'
print(result.outputs["body"])

+ 0
- 0
api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py View File


+ 0
- 909
api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py View File

@@ -1,909 +0,0 @@
import time
import uuid
from unittest.mock import patch

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom


@pytest.mark.skip(
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
)
def test_run():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "answer-4",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"answer": "{{#tt.output#}}",
"iteration_id": "iteration-1",
"title": "answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "hi",
"variable_selector": ["sys", "query"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
"id": "answer-4",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)

# construct variable pool
pool = VariablePool(
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])

graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}

iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
)

# Initialize node data
iteration_node.init_node_data(node_config["data"])

def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"iterator_selector": "dify"},
outputs={"output": "dify 123"},
)

with patch.object(TemplateTransformNode, "_run", new=tt_generator):
# execute node
result = iteration_node._run()

count = 0
for item in result:
# print(type(item), item)
count += 1
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}

assert count == 20


@pytest.mark.skip(
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
)
def test_run_parallel():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "iteration-start-source-tt-target",
"source": "iteration-start",
"target": "tt",
},
{
"id": "iteration-start-source-tt-2-target",
"source": "iteration-start",
"target": "tt-2",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
"target": "if-else",
},
{
"id": "tt-2-source-if-else-target",
"source": "tt-2",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "answer-4",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"answer": "{{#tt.output#}}",
"iteration_id": "iteration-1",
"title": "answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"iteration_id": "iteration-1",
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 321",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt-2",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "hi",
"variable_selector": ["sys", "query"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
"id": "answer-4",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)

# construct variable pool
pool = VariablePool(
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)

graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])

node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}

iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
)

# Initialize node data
iteration_node.init_node_data(node_config["data"])

def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"iterator_selector": "dify"},
outputs={"output": "dify 123"},
)

with patch.object(TemplateTransformNode, "_run", new=tt_generator):
# execute node
result = iteration_node._run()

count = 0
for item in result:
count += 1
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}

assert count == 32


@pytest.mark.skip(
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
)
def test_iteration_run_in_parallel_mode():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "iteration-start-source-tt-target",
"source": "iteration-start",
"target": "tt",
},
{
"id": "iteration-start-source-tt-2-target",
"source": "iteration-start",
"target": "tt-2",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
"target": "if-else",
},
{
"id": "tt-2-source-if-else-target",
"source": "tt-2",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "answer-4",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"answer": "{{#tt.output#}}",
"iteration_id": "iteration-1",
"title": "answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"iteration_id": "iteration-1",
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 321",
"title": "template transform",
"type": "template-transform",
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
},
"id": "tt-2",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "hi",
"variable_selector": ["sys", "query"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
"id": "answer-4",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)

# construct variable pool
pool = VariablePool(
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)

graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])

parallel_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}

parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=parallel_node_config,
)

# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}

sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=sequential_node_config,
)

# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])

def tt_generator(self):
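# Stub for TemplateTransformNode._run: every template-transform execution returns the same succeeded result.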
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"iterator_selector": "dify"},
outputs={"output": "dify 123"},
)

with patch.object(TemplateTransformNode, "_run", new=tt_generator):
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
for item in parallel_result:
count += 1
parallel_arr.append(item)
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 32

for item in sequential_result:
sequential_arr.append(item)
count += 1
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
assert count == 64
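
The assertions above pin the parallel defaults: `parallel_nums` is 10 and `error_handle_mode` is TERMINATED. As a rough mental model (not the engine's implementation), `is_parallel` switches the iteration body from a plain loop to a bounded worker pool; `run_iteration`, `run_item`, and the literals below are illustrative only:

from concurrent.futures import ThreadPoolExecutor

def run_iteration(items, run_item, is_parallel=False, parallel_nums=10):
    """Illustrative sketch: run one iteration body per item, optionally in parallel."""
    if not is_parallel:
        return [run_item(item) for item in items]
    # Bounded fan-out; pool.map still returns results in input order.
    with ThreadPoolExecutor(max_workers=parallel_nums) as pool:
        return list(pool.map(run_item, items))

print(run_iteration(["dify-1", "dify-2"], lambda s: f"{s} 123", is_parallel=True))
# ['dify-1 123', 'dify-2 123']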


@pytest.mark.skip(
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
)
def test_iteration_run_error_handle():
graph_config = {
"edges": [
{
"id": "start-source-pe-target",
"source": "start",
"target": "pe",
},
{
"id": "iteration-1-source-answer-3-target",
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "tt-source-if-else-target",
"source": "iteration-start",
"target": "if-else",
},
{
"id": "if-else-true-answer-2-target",
"source": "if-else",
"sourceHandle": "true",
"target": "tt",
},
{
"id": "if-else-false-answer-4-target",
"source": "if-else",
"sourceHandle": "false",
"target": "tt2",
},
{
"id": "pe-source-iteration-1-target",
"source": "pe",
"target": "iteration-1",
},
],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt2", "output"],
"output_type": "array[string]",
"start_node_id": "if-else",
"title": "iteration",
"type": "iteration",
},
"id": "iteration-1",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1.split(arg2) }}",
"title": "template transform",
"type": "template-transform",
"variables": [
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
{"value_selector": ["iteration-1", "index"], "variable": "arg2"},
],
},
"id": "tt",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }}",
"title": "template transform",
"type": "template-transform",
"variables": [
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
],
},
"id": "tt2",
},
{
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
"id": "answer-3",
},
{
"data": {
"iteration_id": "iteration-1",
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"conditions": [
{
"comparison_operator": "is",
"id": "1721916275284",
"value": "1",
"variable_selector": ["iteration-1", "item"],
}
],
"iteration_id": "iteration-1",
"logical_operator": "and",
"title": "if",
"type": "if-else",
},
"id": "if-else",
},
{
"data": {
"instruction": "test1",
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"parameters": [
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
],
"query": ["sys", "query"],
"reasoning_mode": "prompt",
"title": "pe",
"type": "parameter-extractor",
},
"id": "pe",
},
],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)

# construct variable pool
pool = VariablePool(
system_variables=SystemVariable(
user_id="1",
files=[],
query="dify",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
)

graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
pool.add(["pe", "list_output"], ["1", "1"])
error_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
}

iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=error_node_config,
)

# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
count = 0
for item in result:
result_arr.append(item)
count += 1
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}

assert count == 14
# execute remove abnormal output
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:
count += 1
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[])}
assert count == 14
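
The two asserted outputs encode the ErrorHandleMode semantics this test exercises: CONTINUE_ON_ERROR keeps a None placeholder per failed item (hence [None, None]), REMOVE_ABNORMAL_OUTPUT drops failed items entirely (hence []), and TERMINATED, asserted as the default in the previous test, aborts the run. A minimal sketch of that policy, with illustrative names and string tags standing in for the enum values:

def apply_error_handle_mode(results, mode):
    """Illustrative sketch: 'results' is a list of (ok, value) pairs, one per item."""
    outputs = []
    for ok, value in results:
        if ok:
            outputs.append(value)
        elif mode == "terminated":
            # TERMINATED: the whole iteration fails on the first bad item.
            raise RuntimeError("iteration aborted on first failed item")
        elif mode == "continue-on-error":
            # CONTINUE_ON_ERROR: keep a placeholder for the failed item.
            outputs.append(None)
        elif mode == "remove-abnormal-output":
            # REMOVE_ABNORMAL_OUTPUT: drop the failed item.
            continue
    return outputs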

+ 0
- 624
api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py

@@ -1,624 +0,0 @@
import time
from unittest.mock import patch

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom


class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
"data": {
"outputs": {"result": {"type": "number"}},
"error_strategy": error_strategy,
"title": "code",
"variables": [],
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
**retry_config,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_http_node(
error_strategy: str = "fail-branch",
default_value: dict | None = None,
authorization_success: bool = False,
retry_config: dict = {},
):
"""Helper method to create a http node configuration"""
authorization = (
{
"type": "api-key",
"config": {
"type": "basic",
"api_key": "ak-xxx",
"header": "api-key",
},
}
if authorization_success
else {
"type": "api-key",
# missing config field
}
)
node = {
"id": "node",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": authorization,
"headers": "X-Header:123",
"params": "A:b",
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
**retry_config,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a http node configuration"""
node = {
"id": "node",
"data": {
"type": "http-request",
"title": "HTTP Request",
"desc": "",
"variables": [],
"method": "get",
"url": "https://api.github.com/issues",
"authorization": {"type": "no-auth", "config": None},
"headers": "",
"params": "",
"body": {"type": "none", "data": []},
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a tool node configuration"""
node = {
"id": "node",
"data": {
"title": "a",
"desc": "a",
"provider_id": "maths",
"provider_type": "builtin",
"provider_name": "maths",
"tool_name": "eval_expression",
"tool_label": "eval_expression",
"tool_configurations": {},
"tool_parameters": {
"expression": {
"type": "variable",
"value": ["1", "123", "args1"],
}
},
"type": "tool",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a llm node configuration"""
node = {
"id": "node",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node

@staticmethod
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
# Create graph initialization parameters
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)

variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="aaa",
files=[],
query="clear",
conversation_id="abababa",
),
user_inputs=user_inputs or {"uid": "takato"},
)

graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(init_params, graph_runtime_state)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

return GraphEngine(
tenant_id="111",
app_id="222",
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
command_channel=InMemoryChannel(),
)


DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]

FAIL_BRANCH_EDGES = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-true-success-target",
"source": "node",
"target": "success",
"sourceHandle": "source",
},
{
"id": "node-false-error-target",
"source": "node",
"target": "error",
"sourceHandle": "fail-branch",
},
]
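
# DEFAULT_VALUE_EDGE wires start -> node -> answer, so the node's default_value feeds the answer.
# FAIL_BRANCH_EDGES routes the node's normal "source" handle to the success answer and its
# "fail-branch" handle to the error answer; the fail-branch tests below assert on both paths.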


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_code_default_value_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""

graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_code_fail_branch_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""

graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_code_node(error_code),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
)


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_http_node_default_value_continue_on_error():
"""Test HTTP node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


# def test_tool_node_default_value_continue_on_error():
# """Test tool node with default value error strategy"""
# graph_config = {
# "edges": DEFAULT_VALUE_EDGE,
# "nodes": [
# {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
# {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
# ContinueOnErrorTestHelper.get_tool_node(
# "default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
# ),
# ],
# }

# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
# events = list(graph_engine.run())

# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
# assert any(
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events # noqa: E501
# )
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


# def test_tool_node_fail_branch_continue_on_error():
# """Test HTTP node with fail-branch error strategy"""
# graph_config = {
# "edges": FAIL_BRANCH_EDGES,
# "nodes": [
# {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
# {
# "data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
# "id": "success",
# },
# {
# "data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
# "id": "error",
# },
# ContinueOnErrorTestHelper.get_tool_node(),
# ],
# }

# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
# events = list(graph_engine.run())

# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
# assert any(
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events # noqa: E501
# )
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_llm_node_default_value_continue_on_error():
"""Test LLM node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_llm_node(
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_llm_node_fail_branch_continue_on_error():
"""Test LLM node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_status_code_error_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_variable_pool_error_type_variable():
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
list(graph_engine.run())
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
assert error_message is not None
assert error_type.value == "HTTPResponseCodeError"


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_no_node_in_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
ContinueOnErrorTestHelper.get_http_node(),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())

assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0


@pytest.mark.skip(
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
"not fully implemented in MVP of queue-based engine"
)
def test_stream_output_with_fail_branch_continue_on_error():
"""Test stream output with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)

def llm_generator(self):
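# Stub for LLMNode._run: yield a single stream chunk, then a completed event carrying usage metadata, so the run emits exactly one NodeRunStreamChunkEvent and no failure events.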
contents = ["hi", "bye", "good morning"]

yield NodeRunStreamChunkEvent(
node_id=self.node_id,
node_type=self._node_type,
selector=[self.node_id, "text"],
chunk=contents[0],
is_final=False,
)

yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 1,
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
},
)
)

with patch.object(LLMNode, "_run", new=llm_generator):
events = list(graph_engine.run())
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)

+ 0
- 116
api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py

@@ -1,116 +0,0 @@
from collections.abc import Generator

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.entities import EndStreamParam
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.system_variable import SystemVariable
from models import UserFrom


def _create_tool_node():
data = ToolNodeData(
title="Test Tool",
tool_parameters={},
provider_id="test_tool",
provider_type=ToolProviderType.WORKFLOW,
provider_name="test tool",
tool_name="test tool",
tool_label="test tool",
tool_configurations={},
plugin_unique_identifier=None,
desc="Exception handling test tool",
error_strategy=ErrorStrategy.FAIL_BRANCH,
version="1",
)
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = ToolNode(
id="1",
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node


class MockToolRuntime:
def get_merged_runtime_parameters(self):
pass


def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
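# "yield from []" makes this a generator, so the ToolInvokeError surfaces only when the stream is first advanced, not when the function is called.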
yield from []
raise ToolInvokeError("oops")


@pytest.mark.skip(
reason="Tool node test uses old Graph constructor incompatible with new queue-based engine - "
"needs rewrite for new architecture"
)
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
"""Ensure that ToolNode can handle ToolInvokeError when transforming
messages generated by ToolEngine.generic_invoke.
"""
tool_node = _create_tool_node()

# Need to patch ToolManager and ToolEngine so that we don't
# have to set up a database.
monkeypatch.setattr(
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
)
monkeypatch.setattr(
"core.tools.tool_engine.ToolEngine.generic_invoke",
lambda *args, **kwargs: mock_message_stream(),
)

streams = list(tool_node._run())
assert len(streams) == 1
stream = streams[0]
assert isinstance(stream, StreamCompletedEvent)
result = stream.node_run_result
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "oops" in result.error
assert "Failed to invoke tool" in result.error
assert result.error_type == "ToolInvokeError"

+ 18
- 2
sdks/nodejs-client/index.d.ts

@@ -14,6 +14,22 @@ interface HeaderParams {
interface User {
}

interface DifyFileBase {
type: "image"
}

export interface DifyRemoteFile extends DifyFileBase {
transfer_method: "remote_url"
url: string
}

export interface DifyLocalFile extends DifyFileBase {
transfer_method: "local_file"
upload_file_id: string
}

export type DifyFile = DifyRemoteFile | DifyLocalFile;

export declare class DifyClient {
constructor(apiKey: string, baseUrl?: string);

@@ -44,7 +60,7 @@ export declare class CompletionClient extends DifyClient {
inputs: any,
user: User,
stream?: boolean,
files?: File[] | null
files?: DifyFile[] | null
): Promise<any>;
}

@@ -55,7 +71,7 @@ export declare class ChatClient extends DifyClient {
user: User,
stream?: boolean,
conversation_id?: string | null,
files?: File[] | null
files?: DifyFile[] | null
): Promise<any>;

getSuggested(message_id: string, user: User): Promise<any>;

+ 1
- 0
web/app/components/share/utils.ts

@@ -32,6 +32,7 @@ export const checkOrSetAccessToken = async (appCode?: string | null) => {
[userId || 'DEFAULT']: res.access_token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson))
localStorage.removeItem(CONVERSATION_ID_INFO)
}
}


+ 4
- 2
web/context/web-app-context.tsx

@@ -11,6 +11,7 @@ import type { FC, PropsWithChildren } from 'react'
import { useEffect } from 'react'
import { useState } from 'react'
import { create } from 'zustand'
import { useGlobalPublicStore } from './global-public-context'

type WebAppStore = {
shareCode: string | null
@@ -56,6 +57,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => {
}

const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
const updateShareCode = useWebAppStore(state => state.updateShareCode)
const pathname = usePathname()
@@ -69,7 +71,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
}, [shareCode, updateShareCode])

const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode)
const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(false)
const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true)

useEffect(() => {
if (accessModeResult?.accessMode) {
@@ -86,7 +88,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
}
}, [accessModeResult, updateWebAppAccessMode, shareCode])

if (isFetching || isFetchingAccessToken) {
if (isGlobalPending || isFetching || isFetchingAccessToken) {
return <div className='flex h-full w-full items-center justify-center'>
<Loading />
</div>

+ 9
- 4
web/service/base.ts

@@ -430,9 +430,7 @@ export const ssePost = async (
.then((res) => {
if (!/^[23]\d{2}$/.test(String(res.status))) {
if (res.status === 401) {
refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
ssePost(url, fetchOptions, otherOptions)
}).catch(() => {
if (isPublicAPI) {
res.json().then((data: any) => {
if (isPublicAPI) {
if (data.code === 'web_app_access_denied')
@@ -449,7 +447,14 @@ export const ssePost = async (
}
}
})
})
}
else {
refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
ssePost(url, fetchOptions, otherOptions)
}).catch((err) => {
console.error(err)
})
}
}
else {
res.json().then((data) => {

+ 0
- 8
web/service/use-share.ts

@@ -1,20 +1,12 @@
import { useGlobalPublicStore } from '@/context/global-public-context'
import { AccessMode } from '@/models/access-control'
import { useQuery } from '@tanstack/react-query'
import { fetchAppInfo, fetchAppMeta, fetchAppParams, getAppAccessModeByAppCode } from './share'

const NAME_SPACE = 'webapp'

export const useGetWebAppAccessModeByCode = (code: string | null) => {
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
return useQuery({
queryKey: [NAME_SPACE, 'appAccessMode', code],
queryFn: () => {
if (systemFeatures.webapp_auth.enabled === false) {
return {
accessMode: AccessMode.PUBLIC,
}
}
if (!code || code.length === 0)
return Promise.reject(new Error('App code is required to get access mode'))

