Pārlūkot izejas kodu

fix(graph_engine): error strategy fall. (#26078)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.9.0
-LAN- pirms 1 mēnesi
vecāks
revīzija
2e2c87c5a1
Revīzijas autora e-pasta adrese nav piesaistīta nevienam kontam

+ 10
- 2
api/core/workflow/graph_engine/domain/graph_execution.py Parādīt failu

@@ -41,7 +41,8 @@ class GraphExecutionState(BaseModel):
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])


def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
@@ -103,7 +104,8 @@ class GraphExecution:
completed: bool = False
aborted: bool = False
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0

def start(self) -> None:
"""Mark the graph execution as started."""
@@ -172,6 +174,7 @@ class GraphExecution:
completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
)

@@ -195,6 +198,7 @@ class GraphExecution:
self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
@@ -205,3 +209,7 @@ class GraphExecution:
)
for item in state.node_executions
}

def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1

+ 55
- 11
api/core/workflow/graph_engine/event_management/event_handlers.py Parādīt failu

@@ -3,11 +3,12 @@ Event handler implementations for different event types.
"""

import logging
from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final

from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -122,13 +123,15 @@ class EventHandler:
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)

# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)

# Collect the event
self._event_collector.collect(event)
# Collect the event only for the first attempt; retries remain silent
if is_initial_attempt:
self._event_collector.collect(event)

@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
@@ -161,7 +164,7 @@ class EventHandler:
node_execution.mark_taken()

# Store outputs in variable pool
self._store_node_outputs(event)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)

# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
@@ -191,7 +194,7 @@ class EventHandler:

# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
self._update_response_outputs(event.node_run_result.outputs)

# Collect the event
self._event_collector.collect(event)
@@ -207,6 +210,7 @@ class EventHandler:
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()

result = self._error_handler.handle_node_failure(event)

@@ -227,10 +231,40 @@ class EventHandler:
Args:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
# Node continues via fail-branch/default-value, treat as completion
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()

# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)

node = self._graph.nodes[event.node_id]

if node.error_strategy == ErrorStrategy.DEFAULT_VALUE:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
elif node.error_strategy == ErrorStrategy.FAIL_BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}")

for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)

for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)

# Update response outputs if applicable
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event.node_run_result.outputs)

self._state_manager.finish_execution(event.node_id)

# Collect the exception event for observers
self._event_collector.collect(event)

@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
@@ -242,21 +276,31 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()

def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
# Finish the previous attempt before re-queuing the node
self._state_manager.finish_execution(event.node_id)

# Emit retry event for observers
self._event_collector.collect(event)

# Re-queue node for execution
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)

def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.

Args:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
for variable_name, variable_value in outputs.items():
self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)

def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
def _update_response_outputs(self, outputs: Mapping[str, object]) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items():
for key, value in outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:

+ 16
- 4
api/core/workflow/graph_engine/graph_engine.py Parādīt failu

@@ -23,6 +23,7 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@@ -260,12 +261,23 @@ class GraphEngine:
if self._graph_execution.error:
raise self._graph_execution.error
else:
yield GraphRunSucceededEvent(
outputs=self._graph_runtime_state.outputs,
)
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
else:
yield GraphRunSucceededEvent(
outputs=outputs,
)

except Exception as e:
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
raise

finally:

+ 8
- 0
api/core/workflow/graph_engine/layers/debug_logging.py Parādīt failu

@@ -15,6 +15,7 @@ from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
@@ -127,6 +128,13 @@ class DebugLoggingLayer(GraphEngineLayer):
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))

elif isinstance(event, GraphRunPartialSucceededEvent):
self.logger.warning("⚠️ Graph run partially succeeded")
if event.exceptions_count > 0:
self.logger.warning(" Total exceptions: %s", event.exceptions_count)
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))

elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:

+ 2
- 1
api/core/workflow/nodes/iteration/iteration_node.py Parādīt failu

@@ -19,6 +19,7 @@ from core.workflow.enums import (
from core.workflow.graph_events import (
GraphNodeEventBase,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
)
from core.workflow.node_events import (
@@ -456,7 +457,7 @@ class IterationNode(Node):
if isinstance(event, GraphNodeEventBase):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, GraphRunSucceededEvent):
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector)
if result is None:
outputs.append(None)

+ 120
- 0
api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py Parādīt failu

@@ -0,0 +1,120 @@
"""Tests for graph engine event handlers."""

from __future__ import annotations

from datetime import datetime

from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.event_management.event_handlers import EventHandler
from core.workflow.graph_engine.event_management.event_manager import EventManager
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig


class _StubEdgeProcessor:
"""Minimal edge processor stub for tests."""


class _StubErrorHandler:
"""Minimal error handler stub for tests."""


class _StubNode:
"""Simple node stub exposing the attributes needed by the state manager."""

def __init__(self, node_id: str) -> None:
self.id = node_id
self.state = NodeState.UNKNOWN
self.title = "Stub Node"
self.execution_type = NodeExecutionType.EXECUTABLE
self.error_strategy = None
self.retry_config = RetryConfig()
self.retry = False


def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
"""Construct an EventHandler with in-memory dependencies for testing."""

node = _StubNode(node_id)
graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node)

variable_pool = VariablePool()
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
graph_execution = GraphExecution(workflow_id="test-workflow")

event_manager = EventManager()
state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue())
response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)

handler = EventHandler(
graph=graph,
graph_runtime_state=runtime_state,
graph_execution=graph_execution,
response_coordinator=response_coordinator,
event_collector=event_manager,
edge_processor=_StubEdgeProcessor(),
state_manager=state_manager,
error_handler=_StubErrorHandler(),
)

return handler, event_manager, graph_execution


def test_retry_does_not_emit_additional_start_event() -> None:
"""Ensure retry attempts do not produce duplicate start events."""

node_id = "test-node"
handler, event_manager, graph_execution = _build_event_handler(node_id)

execution_id = "exec-1"
node_type = NodeType.CODE
start_time = datetime.utcnow()

start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(start_event)

retry_event = NodeRunRetryEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
error="boom",
retry_index=1,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="boom",
error_type="TestError",
),
)
handler.dispatch(retry_event)

# Simulate the node starting execution again after retry
second_start_event = NodeRunStartedEvent(
id=execution_id,
node_id=node_id,
node_type=node_type,
node_title="Stub Node",
start_at=start_time,
)
handler.dispatch(second_start_event)

collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined]

assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent]

node_execution = graph_execution.get_or_create_node_execution(node_id)
assert node_execution.retry_count == 1

+ 44
- 1
api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py Parādīt failu

@@ -10,11 +10,18 @@ import time
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

from core.workflow.enums import ErrorStrategy
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
from core.workflow.graph_events import (
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType

# Import the test framework from the new module
from .test_mock_config import MockConfigBuilder
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase


@@ -721,3 +728,39 @@ def test_event_sequence_validation_with_table_tests():
else:
assert result.event_sequence_match is True
assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}"


def test_graph_run_emits_partial_success_when_node_failure_recovered():
runner = TableTestRunner()

fixture_data = runner.workflow_runner.load_fixture("basic_chatflow")
mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build()

graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture(
fixture_data=fixture_data,
query="hello",
use_mock_factory=True,
mock_config=mock_config,
)

llm_node = graph.nodes["llm"]
base_node_data = llm_node.get_base_node_data()
base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE
base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)]

engine = GraphEngine(
workflow_id="test_workflow",
graph=graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
)

events = list(engine.run())

assert isinstance(events[-1], GraphRunPartialSucceededEvent)

partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent))
assert partial_event.exceptions_count == 1
assert partial_event.outputs.get("answer") == "fallback response"

assert not any(isinstance(event, GraphRunSucceededEvent) for event in events)

+ 0
- 65
api/tests/unit_tests/core/workflow/nodes/test_retry.py Parādīt failu

@@ -1,65 +0,0 @@
import pytest

pytest.skip(
"Retry functionality is part of Phase 2 enhanced error handling - not implemented in MVP of queue-based engine",
allow_module_level=True,
)

DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]


def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11


def test_retry_failed():
"""retry failed with success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8

Notiek ielāde…
Atcelt
Saglabāt