|
|
|
@@ -1,14 +1,20 @@ |
|
|
|
from unittest.mock import patch |
|
|
|
|
|
|
|
from core.app.entities.app_invoke_entities import InvokeFrom |
|
|
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult |
|
|
|
from core.workflow.enums import SystemVariableKey |
|
|
|
from core.workflow.graph_engine.entities.event import ( |
|
|
|
GraphRunPartialSucceededEvent, |
|
|
|
NodeRunExceptionEvent, |
|
|
|
NodeRunFailedEvent, |
|
|
|
NodeRunStreamChunkEvent, |
|
|
|
) |
|
|
|
from core.workflow.graph_engine.entities.graph import Graph |
|
|
|
from core.workflow.graph_engine.graph_engine import GraphEngine |
|
|
|
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent |
|
|
|
from core.workflow.nodes.llm.node import LLMNode |
|
|
|
from models.enums import UserFrom |
|
|
|
from models.workflow import WorkflowType |
|
|
|
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType |
|
|
|
|
|
|
|
|
|
|
|
class ContinueOnErrorTestHelper: |
|
|
|
@@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error(): |
|
|
|
"edges": FAIL_BRANCH_EDGES[:-1], |
|
|
|
"nodes": [ |
|
|
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, |
|
|
|
{ |
|
|
|
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, |
|
|
|
"id": "success", |
|
|
|
}, |
|
|
|
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"}, |
|
|
|
ContinueOnErrorTestHelper.get_http_node(), |
|
|
|
], |
|
|
|
} |
|
|
|
@@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error(): |
|
|
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events) |
|
|
|
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) |
|
|
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 |
|
|
|
|
|
|
|
|
|
|
|
def test_stream_output_with_fail_branch_continue_on_error(): |
|
|
|
"""Test stream output with fail-branch error strategy""" |
|
|
|
graph_config = { |
|
|
|
"edges": FAIL_BRANCH_EDGES, |
|
|
|
"nodes": [ |
|
|
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, |
|
|
|
{ |
|
|
|
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, |
|
|
|
"id": "success", |
|
|
|
}, |
|
|
|
{ |
|
|
|
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"}, |
|
|
|
"id": "error", |
|
|
|
}, |
|
|
|
ContinueOnErrorTestHelper.get_llm_node(), |
|
|
|
], |
|
|
|
} |
|
|
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) |
|
|
|
|
|
|
|
def llm_generator(self): |
|
|
|
contents = ["hi", "bye", "good morning"] |
|
|
|
|
|
|
|
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"]) |
|
|
|
|
|
|
|
yield RunCompletedEvent( |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, |
|
|
|
inputs={}, |
|
|
|
process_data={}, |
|
|
|
outputs={}, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: 1, |
|
|
|
NodeRunMetadataKey.TOTAL_PRICE: 1, |
|
|
|
NodeRunMetadataKey.CURRENCY: "USD", |
|
|
|
}, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
with patch.object(LLMNode, "_run", new=llm_generator): |
|
|
|
events = list(graph_engine.run()) |
|
|
|
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1 |
|
|
|
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events) |