Преглед изворни кода

feat(graph_engine): Add NodeExecutionType.ROOT and auto mark skipped in Graph.init

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/2.0.0-beta.1
-LAN- пре 2 месеци
родитељ
комит
bfbb36756a
No account linked to committer's email address

+ 1
- 0
api/core/workflow/enums.py Прегледај датотеку

@@ -57,6 +57,7 @@ class NodeExecutionType(StrEnum):
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
ROOT = "root" # Nodes that can serve as execution entry points


class ErrorStrategy(StrEnum):

+ 70
- 1
api/core/workflow/graph/graph.py Прегледај датотеку

@@ -3,7 +3,7 @@ from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Protocol, cast

from core.workflow.enums import NodeType
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node

from .edge import Edge
@@ -186,6 +186,72 @@ class Graph:

return nodes

@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.

Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged

:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]

# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return

# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED

# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED

# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]

# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)

# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)

@classmethod
def init(
cls,
@@ -227,6 +293,9 @@ class Graph:
# Get root node instance
root_node = nodes[root_node_id]

# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

# Create and return the graph
return cls(
nodes=nodes,

+ 2
- 1
api/core/workflow/nodes/start/start_node.py Прегледај датотеку

@@ -2,7 +2,7 @@ from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
@@ -11,6 +11,7 @@ from core.workflow.nodes.start.entities import StartNodeData

class StartNode(Node):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT

_node_data: StartNodeData


+ 281
- 0
api/tests/unit_tests/core/workflow/graph/test_graph.py Прегледај датотеку

@@ -0,0 +1,281 @@
"""Unit tests for Graph class methods."""

from unittest.mock import Mock

from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.graph.edge import Edge
from core.workflow.graph.graph import Graph
from core.workflow.nodes.base.node import Node


def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
"""Create a mock node for testing."""
node = Mock(spec=Node)
node.id = node_id
node.execution_type = execution_type
node.state = state
node.node_type = NodeType.START
return node


class TestMarkInactiveRootBranches:
"""Test cases for _mark_inactive_root_branches method."""

def test_single_root_no_marking(self):
"""Test that single root graph doesn't mark anything as skipped."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
}

in_edges = {"child1": ["edge1"]}
out_edges = {"root1": ["edge1"]}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["child1"].state == NodeState.UNKNOWN
assert edges["edge1"].state == NodeState.UNKNOWN

def test_multiple_roots_mark_inactive(self):
"""Test marking inactive root branches with multiple root nodes."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
}

in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
out_edges = {"root1": ["edge1"], "root2": ["edge2"]}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["child1"].state == NodeState.UNKNOWN
assert nodes["child2"].state == NodeState.SKIPPED
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED

def test_shared_downstream_node(self):
"""Test that shared downstream nodes are not skipped if at least one path is active."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
"shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
"edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
"edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
}

in_edges = {
"child1": ["edge1"],
"child2": ["edge2"],
"shared": ["edge3", "edge4"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"child1": ["edge3"],
"child2": ["edge4"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["child1"].state == NodeState.UNKNOWN
assert nodes["child2"].state == NodeState.SKIPPED
assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED
assert edges["edge3"].state == NodeState.UNKNOWN
assert edges["edge4"].state == NodeState.SKIPPED

def test_deep_branch_marking(self):
"""Test marking deep branches with multiple levels."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
"level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
"level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
"level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
"level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
"edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
"edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
"edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
}

in_edges = {
"level1_a": ["edge1"],
"level1_b": ["edge2"],
"level2_a": ["edge3"],
"level2_b": ["edge4"],
"level3": ["edge5"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"level1_a": ["edge3"],
"level1_b": ["edge4"],
"level2_b": ["edge5"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["level1_a"].state == NodeState.UNKNOWN
assert nodes["level1_b"].state == NodeState.SKIPPED
assert nodes["level2_a"].state == NodeState.UNKNOWN
assert nodes["level2_b"].state == NodeState.SKIPPED
assert nodes["level3"].state == NodeState.SKIPPED
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED
assert edges["edge3"].state == NodeState.UNKNOWN
assert edges["edge4"].state == NodeState.SKIPPED
assert edges["edge5"].state == NodeState.SKIPPED

def test_non_root_execution_type(self):
"""Test that nodes with non-ROOT execution type are not treated as root nodes."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
}

in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped
assert nodes["child1"].state == NodeState.UNKNOWN
assert nodes["child2"].state == NodeState.UNKNOWN
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.UNKNOWN

def test_empty_graph(self):
"""Test handling of empty graph structures."""
nodes = {}
edges = {}
in_edges = {}
out_edges = {}

# Should not raise any errors
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")

def test_three_roots_mark_two_inactive(self):
"""Test with three root nodes where two should be marked inactive."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
"child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
"edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
}

in_edges = {
"child1": ["edge1"],
"child2": ["edge2"],
"child3": ["edge3"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"root3": ["edge3"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")

assert nodes["root1"].state == NodeState.SKIPPED
assert nodes["root2"].state == NodeState.UNKNOWN # Active root
assert nodes["root3"].state == NodeState.SKIPPED
assert nodes["child1"].state == NodeState.SKIPPED
assert nodes["child2"].state == NodeState.UNKNOWN
assert nodes["child3"].state == NodeState.SKIPPED
assert edges["edge1"].state == NodeState.SKIPPED
assert edges["edge2"].state == NodeState.UNKNOWN
assert edges["edge3"].state == NodeState.SKIPPED

def test_convergent_paths(self):
"""Test convergent paths where multiple inactive branches lead to same node."""
nodes = {
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
"mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
"mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
"convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
}

edges = {
"edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
"edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
"edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
"edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
"edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
}

in_edges = {
"mid1": ["edge1"],
"mid2": ["edge2"],
"convergent": ["edge3", "edge4", "edge5"],
}
out_edges = {
"root1": ["edge1"],
"root2": ["edge2"],
"root3": ["edge3"],
"mid1": ["edge4"],
"mid2": ["edge5"],
}

Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")

assert nodes["root1"].state == NodeState.UNKNOWN
assert nodes["root2"].state == NodeState.SKIPPED
assert nodes["root3"].state == NodeState.SKIPPED
assert nodes["mid1"].state == NodeState.UNKNOWN
assert nodes["mid2"].state == NodeState.SKIPPED
assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1
assert edges["edge1"].state == NodeState.UNKNOWN
assert edges["edge2"].state == NodeState.SKIPPED
assert edges["edge3"].state == NodeState.SKIPPED
assert edges["edge4"].state == NodeState.UNKNOWN
assert edges["edge5"].state == NodeState.SKIPPED

Loading…
Откажи
Сачувај