
fix(fail-branch): prevent streaming output in exception branches (#17153)

tags/1.3.0
Novice committed 6 months ago
parent commit c91045a9d0

api/core/workflow/nodes/answer/answer_stream_processor.py  +21 -2

@@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
         for answer_node_id, route_position in self.route_position.items():
             if answer_node_id not in self.rest_node_ids:
                 continue
-            # exclude current node id
+            # Remove current node id from answer dependencies to support stream output if it is a success branch
             answer_dependencies = self.generate_routes.answer_dependencies
-            if event.node_id in answer_dependencies[answer_node_id]:
+            edge_mapping = self.graph.edge_mapping.get(event.node_id)
+            success_edge = (
+                next(
+                    (
+                        edge
+                        for edge in edge_mapping
+                        if edge.run_condition
+                        and edge.run_condition.type == "branch_identify"
+                        and edge.run_condition.branch_identify == "success-branch"
+                    ),
+                    None,
+                )
+                if edge_mapping
+                else None
+            )
+            if (
+                event.node_id in answer_dependencies[answer_node_id]
+                and success_edge
+                and success_edge.target_node_id == answer_node_id
+            ):
                 answer_dependencies[answer_node_id].remove(event.node_id)
             answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
             # all depends on answer node id not in rest node ids
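Aside (not part of the commit): the new guard is easier to follow in isolation. Below is a minimal, self-contained sketch of the success-edge lookup. Edge and RunCondition are simplified stand-ins modeled on the names in the diff, not Dify's actual entities, and the "fail-branch" identifier in the example data is an assumption; only the shape of the next(...) lookup mirrors the committed code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class RunCondition:
    type: str
    branch_identify: Optional[str] = None


@dataclass
class Edge:
    target_node_id: str
    run_condition: Optional[RunCondition] = None


def find_success_edge(edge_mapping: Optional[list[Edge]]) -> Optional[Edge]:
    # Same shape as the committed lookup: scan the node's outgoing edges
    # for the one whose run condition identifies the success branch.
    if not edge_mapping:
        return None
    return next(
        (
            edge
            for edge in edge_mapping
            if edge.run_condition
            and edge.run_condition.type == "branch_identify"
            and edge.run_condition.branch_identify == "success-branch"
        ),
        None,
    )


# A node with two conditional edges ("fail-branch" is an assumed identifier):
edges = [
    Edge("success", RunCondition("branch_identify", "success-branch")),
    Edge("error", RunCondition("branch_identify", "fail-branch")),
]
assert find_success_edge(edges).target_node_id == "success"
assert find_success_edge(None) is None
assert find_success_edge([Edge("error")]) is None

Under this guard, a node whose success edge does not point at the answer node that is about to stream stays in that answer node's dependency list, so its chunks are never forwarded. That is exactly the behavior the commit message describes for exception branches.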

api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py  +52 -5

@@ -1,14 +1,20 @@
+from unittest.mock import patch
+
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
     NodeRunExceptionEvent,
+    NodeRunFailedEvent,
     NodeRunStreamChunkEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
-from models.workflow import WorkflowType
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
 class ContinueOnErrorTestHelper:
@@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
         "edges": FAIL_BRANCH_EDGES[:-1],
         "nodes": [
             {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
-            {
-                "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
-                "id": "success",
-            },
+            {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
             ContinueOnErrorTestHelper.get_http_node(),
         ],
     }
@@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
     assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
     assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
     assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
+
+
+def test_stream_output_with_fail_branch_continue_on_error():
+    """Test stream output with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_llm_node(),
+        ],
+    }
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+
+    def llm_generator(self):
+        contents = ["hi", "bye", "good morning"]
+
+        yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs={},
+                process_data={},
+                outputs={},
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: 1,
+                    NodeRunMetadataKey.TOTAL_PRICE: 1,
+                    NodeRunMetadataKey.CURRENCY: "USD",
+                },
+            )
+        )
+
+    with patch.object(LLMNode, "_run", new=llm_generator):
+        events = list(graph_engine.run())
+        assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
+        assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
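For context (again, not part of the commit): combining the new processor guard with the test graph above shows why exactly one stream chunk is expected. In the standalone sketch below, the node ids "node", "success", and "error" come from the test; may_stream is a hypothetical condensation of the new if-guard, and the stand-in types and the "fail-branch" identifier are assumptions rather than Dify's real entities.

from dataclasses import dataclass
from typing import Optional


@dataclass
class RunCondition:
    type: str
    branch_identify: str


@dataclass
class Edge:
    target_node_id: str
    run_condition: Optional[RunCondition] = None


def may_stream(
    finished_node_id: str,
    answer_node_id: str,
    answer_dependencies: dict[str, list[str]],
    edge_mapping: dict[str, list[Edge]],
) -> bool:
    # Condenses the new guard: the finished node is pruned from the answer
    # node's dependencies (allowing its chunks to stream) only when its
    # success-branch edge targets that answer node.
    success_edge = next(
        (
            e
            for e in edge_mapping.get(finished_node_id, [])
            if e.run_condition
            and e.run_condition.type == "branch_identify"
            and e.run_condition.branch_identify == "success-branch"
        ),
        None,
    )
    return (
        finished_node_id in answer_dependencies.get(answer_node_id, [])
        and success_edge is not None
        and success_edge.target_node_id == answer_node_id
    )


# Layout of the test graph: the LLM node "node" branches to the "success"
# answer node on success and to the "error" answer node on failure.
deps = {"success": ["node"], "error": ["node"]}
edges = {
    "node": [
        Edge("success", RunCondition("branch_identify", "success-branch")),
        Edge("error", RunCondition("branch_identify", "fail-branch")),
    ]
}

assert may_stream("node", "success", deps, edges)    # streams on the success branch
assert not may_stream("node", "error", deps, edges)  # never to the fail branch

Since the mocked LLM run succeeds, only the "success" answer node may stream, so the test counts exactly one NodeRunStreamChunkEvent and no failure or exception events.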
