
fix(graph_engine): error strategy fall. (#26078)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.9.0
-LAN- committed 1 month ago
commit 2e2c87c5a1

api/core/workflow/graph_engine/domain/graph_execution.py (+10 -2)

completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
- node_executions: list[NodeExecutionState] = Field(default_factory=list)
+ exceptions_count: int = Field(default=0)
+ node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])

def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:

completed: bool = False
aborted: bool = False
error: Exception | None = None
- node_executions: dict[str, NodeExecution] = field(default_factory=dict)
+ node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
+ exceptions_count: int = 0

def start(self) -> None:
    """Mark the graph execution as started."""

completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
+ exceptions_count=self.exceptions_count,
node_executions=node_states,
)

self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
+ self.exceptions_count = state.exceptions_count
self.node_executions = {
    item.node_id: NodeExecution(
        node_id=item.node_id,
    )
    for item in state.node_executions
}

+ def record_node_failure(self) -> None:
+     """Increment the count of node failures encountered during execution."""
+     self.exceptions_count += 1
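The new `exceptions_count` travels through the same dump/restore path as the other execution fields, and `record_node_failure()` is its only writer. A minimal sketch of that round trip, using simplified stand-in classes rather than the real `GraphExecutionState`/`GraphExecution`:

```python
# Minimal sketch of the exceptions_count round trip; the classes below are
# simplified stand-ins for GraphExecutionState / GraphExecution, not the real types.
from dataclasses import dataclass

from pydantic import BaseModel, Field


class ExecutionState(BaseModel):
    """Serialized snapshot of an execution."""

    completed: bool = Field(default=False)
    exceptions_count: int = Field(default=0)


@dataclass
class Execution:
    """Runtime aggregate that owns the failure counter."""

    completed: bool = False
    exceptions_count: int = 0

    def record_node_failure(self) -> None:
        # Mirrors the method added in the diff: one increment per recovered failure.
        self.exceptions_count += 1

    def dump(self) -> str:
        return ExecutionState(
            completed=self.completed,
            exceptions_count=self.exceptions_count,
        ).model_dump_json()

    def restore(self, payload: str) -> None:
        state = ExecutionState.model_validate_json(payload)
        self.completed = state.completed
        self.exceptions_count = state.exceptions_count


execution = Execution()
execution.record_node_failure()

resumed = Execution()
resumed.restore(execution.dump())
assert resumed.exceptions_count == 1
```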

api/core/workflow/graph_engine/event_management/event_handlers.py (+55 -11)

""" """


import logging import logging
from collections.abc import Mapping
from functools import singledispatchmethod from functools import singledispatchmethod
from typing import TYPE_CHECKING, final from typing import TYPE_CHECKING, final


from core.workflow.entities import GraphRuntimeState from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph from core.workflow.graph import Graph
from core.workflow.graph_events import ( from core.workflow.graph_events import (
GraphNodeEventBase, GraphNodeEventBase,
""" """
# Track execution in domain model # Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id) node_execution.mark_started(event.id)


# Track in response coordinator for stream ordering # Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id) self._response_coordinator.track_node_execution(event.node_id, event.id)


# Collect the event
self._event_collector.collect(event)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)


@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:

node_execution.mark_taken()

# Store outputs in variable pool
- self._store_node_outputs(event)
+ self._store_node_outputs(event.node_id, event.node_run_result.outputs)

# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)

# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
-     self._update_response_outputs(event)
+     self._update_response_outputs(event.node_run_result.outputs)

# Collect the event
self._event_collector.collect(event)

# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
+ self._graph_execution.record_node_failure()

result = self._error_handler.handle_node_failure(event)


Args:
    event: The node exception event
"""
- # Node continues via fail-branch, so it's technically "succeeded"
+ # Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()

+ # Persist outputs produced by the exception strategy (e.g. default values)
+ self._store_node_outputs(event.node_id, event.node_run_result.outputs)
+
+ node = self._graph.nodes[event.node_id]
+
+ if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
+     ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
+ elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
+     ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
+         event.node_id, event.node_run_result.edge_source_handle
+     )
+ else:
+     raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")
+
+ for edge_event in edge_streaming_events:
+     self._event_collector.collect(edge_event)
+
+ for node_id in ready_nodes:
+     self._state_manager.enqueue_node(node_id)
+     self._state_manager.start_execution(node_id)
+
+ # Update response outputs if applicable
+ if node.execution_type == NodeExecutionType.RESPONSE:
+     self._update_response_outputs(event.node_run_result.outputs)
+
+ self._state_manager.finish_execution(event.node_id)
+
+ # Collect the exception event for observers
+ self._event_collector.collect(event)

@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
    """
    node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
    node_execution.increment_retry()

+     # Finish the previous attempt before re-queuing the node
+     self._state_manager.finish_execution(event.node_id)
+
+     # Emit retry event for observers
+     self._event_collector.collect(event)
+
+     # Re-queue node for execution
+     self._state_manager.enqueue_node(event.node_id)
+     self._state_manager.start_execution(event.node_id)

- def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
+ def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
    """
    Store node outputs in the variable pool.

    Args:
        event: The node succeeded event containing outputs
    """
-     for variable_name, variable_value in event.node_run_result.outputs.items():
-         self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
+     for variable_name, variable_value in outputs.items():
+         self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)

- def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
+ def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
    """Update response outputs for response nodes."""
    # TODO: Design a mechanism for nodes to notify the engine about how to update outputs
    # in runtime state, rather than allowing nodes to directly access runtime state.
-     for key, value in event.node_run_result.outputs.items():
+     for key, value in outputs.items():
        if key == "answer":
            existing = self._graph_runtime_state.get_output("answer", "")
            if existing:
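The reworked exception handler is the core of this change: instead of unconditionally treating a recovered node as a plain success, it persists the strategy's outputs and then routes execution according to the node's configured error strategy, with DEFAULT_VALUE reusing the normal success path and FAIL_BRANCH following the fail branch. A standalone sketch of that dispatch, using hypothetical stand-ins for the node and edge processor rather than the real Dify classes:

```python
# Simplified sketch of the strategy dispatch added to the exception handler.
# Node, EdgeProcessor, and the return values are illustrative stand-ins only.
from dataclasses import dataclass
from enum import Enum


class ErrorStrategy(Enum):
    DEFAULT_VALUE = "default-value"
    FAIL_BRANCH = "fail-branch"


@dataclass
class Node:
    id: str
    error_strategy: ErrorStrategy


class EdgeProcessor:
    def process_node_success(self, node_id: str) -> tuple[list[str], list[str]]:
        # Follow the normal outgoing edges, as if the node had succeeded.
        return ([f"{node_id}-next"], [])

    def handle_branch_completion(self, node_id: str, handle: str | None) -> tuple[list[str], list[str]]:
        # Follow only the branch selected by the source handle (e.g. "fail-branch").
        return ([f"{node_id}-{handle}"], [])


def route_after_exception(node: Node, edge_processor: EdgeProcessor, edge_source_handle: str | None):
    if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
        return edge_processor.process_node_success(node.id)
    if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
        return edge_processor.handle_branch_completion(node.id, edge_source_handle)
    raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")


ready, _ = route_after_exception(Node("llm", ErrorStrategy.FAIL_BRANCH), EdgeProcessor(), "fail-branch")
assert ready == ["llm-fail-branch"]
```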

api/core/workflow/graph_engine/graph_engine.py (+16 -4)

GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
+ GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)

if self._graph_execution.error:
    raise self._graph_execution.error
else:
-     yield GraphRunSucceededEvent(
-         outputs=self._graph_runtime_state.outputs,
-     )
+     outputs = self._graph_runtime_state.outputs
+     exceptions_count = self._graph_execution.exceptions_count
+     if exceptions_count > 0:
+         yield GraphRunPartialSucceededEvent(
+             exceptions_count=exceptions_count,
+             outputs=outputs,
+         )
+     else:
+         yield GraphRunSucceededEvent(
+             outputs=outputs,
+         )

except Exception as e:
-     yield GraphRunFailedEvent(error=str(e))
+     yield GraphRunFailedEvent(
+         error=str(e),
+         exceptions_count=self._graph_execution.exceptions_count,
+     )
    raise

finally:
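With the failure counter available on the execution aggregate, the engine selects the terminal event at the end of the run: any recovered node failure downgrades a full success to a partial one. A condensed, self-contained sketch of that selection, with placeholder event classes standing in for the real graph events:

```python
# Condensed sketch of the terminal-event selection; RunSucceeded / RunPartialSucceeded
# are placeholders for GraphRunSucceededEvent / GraphRunPartialSucceededEvent.
from collections.abc import Mapping
from dataclasses import dataclass, field


@dataclass
class RunSucceeded:
    outputs: Mapping[str, object] = field(default_factory=dict)


@dataclass
class RunPartialSucceeded:
    exceptions_count: int = 0
    outputs: Mapping[str, object] = field(default_factory=dict)


def terminal_event(outputs: Mapping[str, object], exceptions_count: int) -> RunSucceeded | RunPartialSucceeded:
    # Any recovered node failure downgrades a full success to a partial success.
    if exceptions_count > 0:
        return RunPartialSucceeded(exceptions_count=exceptions_count, outputs=outputs)
    return RunSucceeded(outputs=outputs)


assert isinstance(terminal_event({"answer": "ok"}, exceptions_count=0), RunSucceeded)
assert isinstance(terminal_event({"answer": "fallback"}, exceptions_count=1), RunPartialSucceeded)
```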

api/core/workflow/graph_engine/layers/debug_logging.py (+8 -0)

GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
+ GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,

if self.include_outputs and event.outputs:
    self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))

+ elif isinstance(event, GraphRunPartialSucceededEvent):
+     self.logger.warning("⚠️ Graph run partially succeeded")
+     if event.exceptions_count > 0:
+         self.logger.warning(" Total exceptions: %s", event.exceptions_count)
+     if self.include_outputs and event.outputs:
+         self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
+
elif isinstance(event, GraphRunFailedEvent):
    self.logger.error("❌ Graph run failed: %s", event.error)
    if event.exceptions_count > 0:

api/core/workflow/nodes/iteration/iteration_node.py (+2 -1)

from core.workflow.graph_events import (
    GraphNodeEventBase,
    GraphRunFailedEvent,
+     GraphRunPartialSucceededEvent,
    GraphRunSucceededEvent,
)
from core.workflow.node_events import (

if isinstance(event, GraphNodeEventBase):
    self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
    yield event
- elif isinstance(event, GraphRunSucceededEvent):
+ elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
    result = variable_pool.get(self._node_data.output_selector)
    if result is None:
        outputs.append(None)

api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py (+120 -0)

"""Tests for graph engine event handlers."""

from __future__ import annotations

from datetime import datetime

from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig


class _StubEdgeProcessor:
"""Minimal edge processor stub for tests."""


class _StubErrorHandler:
"""Minimal error handler stub for tests."""


class _StubNode:
"""Simple node stub exposing the attributes needed by the state manager."""

def __init__(self, node_id: str) -> None:
self.id = node_id
self.state = NodeState.UNKNOWN
self.title = "Stub Node"
self.execution_type = NodeExecutionType.EXECUTABLE
self.error_strategy = None
self.retry_config = RetryConfig()
self.retry = False


def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
"""Construct an EventHandler with in-memory dependencies for testing."""

node = _StubNode(node_id)
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)

variable_pool = VariablePool()
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
graph_execution = GraphExecution(workflow_id="test-workflow")

event_manager = EventManager()
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)

handler = EventHandler(
graph=graph,
graph_runtime_state=runtime_state,
graph_execution=graph_execution,
response_coordinator=response_coordinator,
event_collector=event_manager,
edge_processor=_StubEdgeProcessor(),
state_manager=state_manager,
error_handler=_StubErrorHandler(),
)

return handler, event_manager, graph_execution


def test_retry_does_not_emit_additional_start_event() -> None:
"""Ensure retry attempts do not produce duplicate start events."""

node_id = "test-node"
handler, event_manager, graph_execution = _build_event_handler(node_id)

execution_id = "exec-1"
node_type = NodeType.CODE
start_time = datetime.utcnow()

start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(start_event)

retry_event = NodeRunRetryEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
error="boom",
retry_index=1,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="boom",
error_type="TestError",
),
)
handler.dispatch(retry_event)

# Simulate the node starting execution again after retry
second_start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(second_start_event)

collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]

assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]

node_execution = graph_execution.get_or_create_node_execution(node_id)
assert node_execution.retry_count == 1

api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py (+44 -1)

from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

+ from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
- from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
+ from core.workflow.graph_events import (
+     GraphRunPartialSucceededEvent,
+     GraphRunStartedEvent,
+     GraphRunSucceededEvent,
+ )
+ from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType

# Import the test framework from the new module
+ from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase

else:
    assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"


def test_graph_run_emits_partial_success_when_node_failure_recovered():
    runner = TableTestRunner()

    fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
    mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()

    graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
        fixture_data=fixture_data,
        query="hello",
        use_mock_factory=True,
        mock_config=mock_config,
    )

    llm_node = graph.nodes["llm"]
    base_node_data = llm_node.get_base_node_data()
    base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
    base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]

    engine = GraphEngine(
        workflow_id="test_workflow",
        graph=graph,
        graph_runtime_state=graph_runtime_state,
        command_channel=InMemoryChannel(),
    )

    events = list(engine.run())

    assert isinstance(events[-1], GraphRunPartialSucceededEvent)

    partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
    assert partial_event.exceptions_count == 1
    assert partial_event.outputs.get("answer") == "fallback response"

    assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)

api/tests/unit_tests/core/workflow/nodes/test_retry.py (+0 -65)

import pytest

pytest.skip(
    "Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
    allow_module_level=True,
)

DEFAULT_VALUE_EDGE = [
    {
        "id": "start-source-node-target",
        "source": "start",
        "target": "node",
        "sourceHandle": "source",
    },
    {
        "id": "node-source-answer-target",
        "source": "node",
        "target": "answer",
        "sourceHandle": "source",
    },
]


def test_retry_default_value_partial_success():
    """retry default value node with partial success status"""
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            ContinueOnErrorTestHelper.get_http_node(
                "default-value",
                [{"key": "result", "type": "string", "value": "http node got error response"}],
                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
            ),
        ],
    }

    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())
    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
    assert events[-1].outputs == {"answer": "http node got error response"}
    assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
    assert len(events) == 11


def test_retry_failed():
    """retry failed with success status"""
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            ContinueOnErrorTestHelper.get_http_node(
                None,
                None,
                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
            ),
        ],
    }
    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())
    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
    assert any(isinstance(e, GraphRunFailedEvent) for e in events)
    assert len(events) == 8
