
Feat: Retry on node execution errors (#11871)

Co-authored-by: Novice Lee <novicelee@NoviPro.local>
tags/0.14.2
Novice 10 months ago
commit 7abc7fa573
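This commit lets individual workflow nodes retry automatically when they fail: nodes gain a retry_config, the graph engine re-runs failing LLM, tool, and HTTP request nodes, retry events flow through the queue and stream pipeline, and each attempt is recorded on the node execution. As a minimal, hedged sketch of the configuration this enables (field names follow RetryConfig in api/core/workflow/nodes/base/entities.py, values mirror the unit tests at the end of this diff, and the surrounding node structure is simplified):

# Hedged sketch, not taken verbatim from the commit: a node's data carrying retry settings.
http_node_data = {
    "title": "http",
    "type": "http-request",
    "retry_config": {
        "retry_enabled": True,   # opt in to retries
        "max_retries": 2,        # up to 2 retries after the first failure
        "retry_interval": 1000,  # milliseconds between attempts
    },
}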

api/core/app/apps/advanced_chat/generate_task_pipeline.py (+17, -0)

QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
workflow_node_execution=workflow_node_execution,
)


if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)

response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):

api/core/app/apps/workflow/generate_task_pipeline.py (+18, -1)

QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if node_failed_response:
yield node_failed_response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)

response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
yield response

elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

api/core/app/apps/workflow_app_runner.py (+32, -0)

QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
start_index=event.start_index,
)
)


def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

api/core/app/entities/queue_entities.py (+32, -0)

ERROR = "error" ERROR = "error"
PING = "ping" PING = "ping"
STOP = "stop" STOP = "stop"
RETRY = "retry"




class AppQueueEvent(BaseModel):
iteration_duration_map: Optional[dict[str, float]] = None




class QueueNodeRetryEvent(AppQueueEvent):
"""QueueNodeRetryEvent entity"""

event: QueueEvent = QueueEvent.RETRY

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime

inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str
retry_index: int # retry index
start_index: int # start index


class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity

api/core/app/entities/task_entities.py (+70, -0)

WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
}




class NodeRetryStreamResponse(StreamResponse):
"""
NodeRetryStreamResponse entity
"""

class Data(BaseModel):
"""
Data entity
"""

id: str
node_id: str
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
retry_index: int = 0

event: StreamEvent = StreamEvent.NODE_RETRY
workflow_run_id: str
data: Data

def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
"workflow_run_id": self.workflow_run_id,
"data": {
"id": self.data.id,
"node_id": self.data.node_id,
"node_type": self.data.node_type,
"title": self.data.title,
"index": self.data.index,
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"process_data": None,
"outputs": None,
"status": self.data.status,
"error": None,
"elapsed_time": self.data.elapsed_time,
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"retry_index": self.data.retry_index,
},
}
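For context, a hedged illustration of what one node_retry stream payload shaped by this model might look like; values are placeholders, and real payloads are produced by _workflow_node_retry_to_stream_response later in this diff:

# Placeholder values only; keys follow NodeRetryStreamResponse.Data above.
node_retry_payload = {
    "event": "node_retry",
    "task_id": "task-id",
    "workflow_run_id": "run-id",
    "data": {
        "id": "node-execution-id",
        "node_id": "http-1",
        "node_type": "http-request",
        "title": "HTTP Request",
        "index": 3,
        "status": "retry",
        "error": "Connection timed out",
        "elapsed_time": 0.42,
        "created_at": 1734300000,
        "finished_at": 1734300001,
        "retry_index": 1,
    },
}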


class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity

api/core/app/task_pipeline/workflow_cycle_manage.py (+93, -0)

QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,


return workflow_node_execution


def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution retried
:param event: queue node retry event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)

workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.index = event.start_index

db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)

return workflow_node_execution

#################################################
# to stream responses #
#################################################
),
)


def _workflow_node_retry_to_stream_response(
self,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeRetryStreamResponse]:
"""
Workflow node retry to stream response.
:param event: queue node retry event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None

return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)

def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

api/core/helper/ssrf_proxy.py (+0, -1)

)


retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:

api/core/workflow/entities/node_entities.py (+3, -0)



error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed

# single step node run retry
retry_index: int = 0

api/core/workflow/graph_engine/entities/event.py (+7, -0)

error: str = Field(..., description="error")




class NodeRunRetryEvent(BaseNodeEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
start_index: int = Field(..., description="retry start index")


###########################################
# Parallel Branch Events
###########################################

api/core/workflow/graph_engine/graph_engine.py (+175, -135)

from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast


from flask import Flask, current_app
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,


def _run_node(
self,
node_instance: BaseNode,
node_instance: BaseNode[BaseNodeData],
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
)


db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
try:
# run node
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id

yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
retries += 1
self.graph_runtime_state.node_run_steps += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
error=run_result.error,
retry_index=retries,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
start_at=retry_start_at,
start_index=self.graph_runtime_state.node_run_steps,
)
time.sleep(retry_interval)
continue
route_node_state.set_finished(run_result=run_result)

if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)


try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage


yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)

if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",

# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}

run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)

yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False


elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)

if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage

# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)

# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}

run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)

yield NodeRunSucceededEvent(
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)

break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
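Because the diff above interleaves the old and new versions of _run_node, the control flow is easier to see in isolation. A hedged, standalone sketch of the retry loop this change introduces; run_node and emit are placeholders rather than Dify APIs, and the interval follows the RetryConfig semantics defined later in this diff:

import time

def run_with_retry(run_node, emit, max_retries: int, retry_interval_seconds: float):
    # Run the node; while it keeps failing and attempts remain, emit a retry
    # event and sleep before trying again. max_retries=2 yields at most 3 runs.
    retries = 0
    while retries <= max_retries:
        result = run_node()
        if result["status"] != "failed" or retries == max_retries:
            return result
        retries += 1
        emit({"event": "retry", "retry_index": retries, "error": result.get("error")})
        time.sleep(retry_interval_seconds)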


def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

api/core/workflow/nodes/base/entities.py (+13, -0)

return self




class RetryConfig(BaseModel):
"""node retry config"""

max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled

@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
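A small usage sketch of the config above; the values mirror the retry_config used in the tests at the end of this diff:

# Hedged example: retry_interval is stored in milliseconds and exposed in
# seconds via retry_interval_seconds so callers can pass it to time.sleep().
config = RetryConfig(max_retries=2, retry_interval=1000, retry_enabled=True)
assert config.retry_interval_seconds == 1.0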


class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()


@property
def default_value_dict(self):

api/core/workflow/nodes/base/node.py (+10, -1)

from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast


from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus


bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE

@property
def should_retry(self) -> bool:
"""judge if should retry

Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE

api/core/workflow/nodes/enums.py (+1, -0)





CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
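Taken together with BaseNode.should_retry above, only LLM, tool, and HTTP request nodes retry, and only when retry is enabled. A hedged illustration of that gate; node_would_retry is a standalone helper for this write-up, not a Dify API:

from core.workflow.nodes.enums import RETRY_ON_ERROR_NODE_TYPE, NodeType

def node_would_retry(retry_enabled: bool, node_type: NodeType) -> bool:
    # Mirrors the should_retry property shown in base/node.py above.
    return retry_enabled and node_type in RETRY_ON_ERROR_NODE_TYPE

assert node_would_retry(True, NodeType.HTTP_REQUEST)
assert not node_would_retry(True, NodeType.CODE)  # code nodes are not in the retry list
assert not node_would_retry(False, NodeType.LLM)  # retry must be enabled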

api/core/workflow/nodes/event/__init__.py (+8, -1)

from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent


__all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

api/core/workflow/nodes/event/event.py (+25, -0)

from datetime import datetime

from pydantic import BaseModel, Field


from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus




class RunCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None


class RunRetryEvent(BaseModel):
"""Node Run Retry event"""

error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")


class SingleStepRetryEvent(BaseModel):
"""Single step retry event"""

status: str = WorkflowNodeExecutionStatus.RETRY.value

inputs: dict | None = Field(..., description="input")
error: str = Field(..., description="error")
outputs: dict = Field(..., description="output")
retry_index: int = Field(..., description="Retry attempt number")
elapsed_time: float = Field(..., description="elapsed time")
execution_metadata: dict | None = Field(..., description="execution metadata")
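A hedged example of recording one attempt with the event above during single-step debugging; all values are placeholders:

# Fields follow SingleStepRetryEvent as defined above.
retry_event = SingleStepRetryEvent(
    inputs={"query": "ping"},
    error="Connection timed out",
    outputs={},
    retry_index=1,
    elapsed_time=0.42,
    execution_metadata=None,
)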

api/core/workflow/nodes/http_request/executor.py (+4, -0)

headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
max_retries: int


boundary: str


node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
self.files = None
self.data = None
self.json = None
self.max_retries = max_retries


# init template
self.variable_pool = variable_pool
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:

api/core/workflow/nodes/http_request/node.py (+7, -1)

"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
}, },
}, },
"retry_config": {
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
"retry_interval": 0.5 * (2**2),
"retry_enabled": True,
},
} }


def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
node_data=self.node_data, node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data), timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
) )
process_data["request"] = http_executor.to_log() process_data["request"] = http_executor.to_log()


response = http_executor.invoke() response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response) files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
outputs={ outputs={

api/fields/workflow_run_fields.py (+14, -0)

"created_at": TimestampField, "created_at": TimestampField,
"finished_at": TimestampField, "finished_at": TimestampField,
"exceptions_count": fields.Integer, "exceptions_count": fields.Integer,
"retry_index": fields.Integer,
} }


advanced_chat_workflow_run_for_list_fields = {
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
"retry_index": fields.Integer,
} }


advanced_chat_workflow_run_pagination_fields = {
"exceptions_count": fields.Integer,
}


retry_event_field = {
"error": fields.String,
"retry_index": fields.Integer,
"inputs": fields.Raw(attribute="inputs"),
"elapsed_time": fields.Float,
"execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
"status": fields.String,
"outputs": fields.Raw(attribute="outputs"),
}
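For reference, a hedged sketch of one retry event as these fields would serialize it; the values are placeholders, not an actual API response:

# Keys follow retry_event_field above.
serialized_retry_event = {
    "error": "Connection timed out",
    "retry_index": 1,
    "inputs": {"url": "https://example.com"},
    "elapsed_time": 0.42,
    "execution_metadata": None,
    "status": "retry",
    "outputs": None,
}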


workflow_run_node_execution_fields = {
"id": fields.String,
"index": fields.Integer,
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField,
"retry_events": fields.List(fields.Nested(retry_event_field)),
} }


workflow_run_node_execution_list_fields = {

api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py (+33, -0)

"""add retry_index field to node-execution model

Revision ID: 348cb0a93d53
Revises: cf8f4fc45278
Create Date: 2024-12-16 01:23:13.093432

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '348cb0a93d53'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')

# ### end Alembic commands ###

api/models/workflow.py (+2, -0)

SUCCEEDED = "succeeded" SUCCEEDED = "succeeded"
FAILED = "failed" FAILED = "failed"
EXCEPTION = "exception" EXCEPTION = "exception"
RETRY = "retry"


@classmethod @classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
finished_at = db.Column(db.DateTime) finished_at = db.Column(db.DateTime)
retry_index = db.Column(db.Integer, server_default=db.text("0"))


@property @property
def created_by_account(self): def created_by_account(self):

api/services/workflow_service.py (+93, -48)

from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.event import SingleStepRetryEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated


# run draft workflow node
start_at = time.perf_counter()
retries = 0
max_retries = 0
should_retry = True
retry_events = []


try:
node_instance, generator = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result

# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break

if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
while retries <= max_retries and should_retry:
retry_start_at = time.perf_counter()
node_instance, generator = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
max_retries = (
node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
)
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result

# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break

if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and node_run_result.outputs
and not node_instance.should_continue_on_error
):
node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
should_retry = False
else:
if node_instance.should_retry:
node_run_result.status = WorkflowNodeExecutionStatus.RETRY
retries += 1
node_run_result.retry_index = retries
retry_events.append(
SingleStepRetryEvent(
inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
if node_run_result.inputs
else None,
error=node_run_result.error,
outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
if node_run_result.outputs
else None,
retry_index=node_run_result.retry_index,
elapsed_time=time.perf_counter() - retry_start_at,
execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
if node_run_result.metadata
else None,
)
)
time.sleep(retry_interval)
else:
should_retry = False
if node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
run_succeeded = False


db.session.add(workflow_node_execution)
db.session.commit()
workflow_node_execution.retry_events = retry_events


return workflow_node_execution



api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py (+9, -3)

from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)


class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
def get_code_node(
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
):
"""Helper method to create a code node configuration""" """Helper method to create a code node configuration"""
node = { node = {
"id": "node", "id": "node",
"code_language": "python3", "code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]), "code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code", "type": "code",
**retry_config,
},
}
if default_value:


@staticmethod
def get_http_node(
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
error_strategy: str = "fail-branch",
default_value: dict | None = None,
authorization_success: bool = False,
retry_config: dict = {},
): ):
"""Helper method to create a http node configuration""" """Helper method to create a http node configuration"""
authorization = ( authorization = (
"body": None, "body": None,
"type": "http-request", "type": "http-request",
"error_strategy": error_strategy, "error_strategy": error_strategy,
**retry_config,
},
}
if default_value:

api/tests/unit_tests/core/workflow/nodes/test_retry.py (+73, -0)

from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunRetryEvent,
)
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper

DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]


def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}

graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11


def test_retry_failed():
"""retry failed with success status"""
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""

graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8
