Co-authored-by: Novice Lee <novicelee@NoviPro.local>tags/0.14.2
| @@ -22,6 +22,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response: | |||
| yield response | |||
| elif isinstance( | |||
| event, | |||
| QueueNodeRetryEvent, | |||
| ): | |||
| workflow_node_execution = self._handle_workflow_node_execution_retried( | |||
| workflow_run=workflow_run, event=event | |||
| ) | |||
| response = self._workflow_node_retry_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueParallelBranchRunStartedEvent): | |||
| @@ -18,6 +18,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if node_failed_response: | |||
| yield node_failed_response | |||
| elif isinstance( | |||
| event, | |||
| QueueNodeRetryEvent, | |||
| ): | |||
| workflow_node_execution = self._handle_workflow_node_execution_retried( | |||
| workflow_run=workflow_run, event=event | |||
| ) | |||
| response = self._workflow_node_retry_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueParallelBranchRunStartedEvent): | |||
| if not workflow_run: | |||
| raise Exception("Workflow run not initialized.") | |||
| @@ -11,6 +11,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import ( | |||
| NodeRunExceptionEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunRetryEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| @@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| error=event.error if isinstance(event, IterationRunFailedEvent) else None, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunRetryEvent): | |||
| self._publish_event( | |||
| QueueNodeRetryEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| error=event.error, | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| in_iteration_id=event.in_iteration_id, | |||
| retry_index=event.retry_index, | |||
| start_index=event.start_index, | |||
| ) | |||
| ) | |||
| def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| @@ -43,6 +43,7 @@ class QueueEvent(StrEnum): | |||
| ERROR = "error" | |||
| PING = "ping" | |||
| STOP = "stop" | |||
| RETRY = "retry" | |||
| class AppQueueEvent(BaseModel): | |||
| @@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| iteration_duration_map: Optional[dict[str, float]] = None | |||
| class QueueNodeRetryEvent(AppQueueEvent): | |||
| """QueueNodeRetryEvent entity""" | |||
| event: QueueEvent = QueueEvent.RETRY | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| inputs: Optional[dict[str, Any]] = None | |||
| process_data: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None | |||
| error: str | |||
| retry_index: int # retry index | |||
| start_index: int # start index | |||
| class QueueNodeInIterationFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeInIterationFailedEvent entity | |||
| @@ -52,6 +52,7 @@ class StreamEvent(Enum): | |||
| WORKFLOW_FINISHED = "workflow_finished" | |||
| NODE_STARTED = "node_started" | |||
| NODE_FINISHED = "node_finished" | |||
| NODE_RETRY = "node_retry" | |||
| PARALLEL_BRANCH_STARTED = "parallel_branch_started" | |||
| PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" | |||
| ITERATION_STARTED = "iteration_started" | |||
| @@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse): | |||
| } | |||
| class NodeRetryStreamResponse(StreamResponse): | |||
| """ | |||
| NodeFinishStreamResponse entity | |||
| """ | |||
| class Data(BaseModel): | |||
| """ | |||
| Data entity | |||
| """ | |||
| id: str | |||
| node_id: str | |||
| node_type: str | |||
| title: str | |||
| index: int | |||
| predecessor_node_id: Optional[str] = None | |||
| inputs: Optional[dict] = None | |||
| process_data: Optional[dict] = None | |||
| outputs: Optional[dict] = None | |||
| status: str | |||
| error: Optional[str] = None | |||
| elapsed_time: float | |||
| execution_metadata: Optional[dict] = None | |||
| created_at: int | |||
| finished_at: int | |||
| files: Optional[Sequence[Mapping[str, Any]]] = [] | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| retry_index: int = 0 | |||
| event: StreamEvent = StreamEvent.NODE_RETRY | |||
| workflow_run_id: str | |||
| data: Data | |||
| def to_ignore_detail_dict(self): | |||
| return { | |||
| "event": self.event.value, | |||
| "task_id": self.task_id, | |||
| "workflow_run_id": self.workflow_run_id, | |||
| "data": { | |||
| "id": self.data.id, | |||
| "node_id": self.data.node_id, | |||
| "node_type": self.data.node_type, | |||
| "title": self.data.title, | |||
| "index": self.data.index, | |||
| "predecessor_node_id": self.data.predecessor_node_id, | |||
| "inputs": None, | |||
| "process_data": None, | |||
| "outputs": None, | |||
| "status": self.data.status, | |||
| "error": None, | |||
| "elapsed_time": self.data.elapsed_time, | |||
| "execution_metadata": None, | |||
| "created_at": self.data.created_at, | |||
| "finished_at": self.data.finished_at, | |||
| "files": [], | |||
| "parallel_id": self.data.parallel_id, | |||
| "parallel_start_node_id": self.data.parallel_start_node_id, | |||
| "parent_parallel_id": self.data.parent_parallel_id, | |||
| "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, | |||
| "iteration_id": self.data.iteration_id, | |||
| "retry_index": self.data.retry_index, | |||
| }, | |||
| } | |||
| class ParallelBranchStartStreamResponse(StreamResponse): | |||
| """ | |||
| ParallelBranchStartStreamResponse entity | |||
| @@ -15,6 +15,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeRetryEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -26,6 +27,7 @@ from core.app.entities.task_entities import ( | |||
| IterationNodeNextStreamResponse, | |||
| IterationNodeStartStreamResponse, | |||
| NodeFinishStreamResponse, | |||
| NodeRetryStreamResponse, | |||
| NodeStartStreamResponse, | |||
| ParallelBranchFinishedStreamResponse, | |||
| ParallelBranchStartStreamResponse, | |||
| @@ -423,6 +425,52 @@ class WorkflowCycleManage: | |||
| return workflow_node_execution | |||
| def _handle_workflow_node_execution_retried( | |||
| self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent | |||
| ) -> WorkflowNodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| :param event: queue node failed event | |||
| :return: | |||
| """ | |||
| created_at = event.start_at | |||
| finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - created_at).total_seconds() | |||
| inputs = WorkflowEntry.handle_special_values(event.inputs) | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| workflow_node_execution = WorkflowNodeExecution() | |||
| workflow_node_execution.tenant_id = workflow_run.tenant_id | |||
| workflow_node_execution.app_id = workflow_run.app_id | |||
| workflow_node_execution.workflow_id = workflow_run.workflow_id | |||
| workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value | |||
| workflow_node_execution.workflow_run_id = workflow_run.id | |||
| workflow_node_execution.node_execution_id = event.node_execution_id | |||
| workflow_node_execution.node_id = event.node_id | |||
| workflow_node_execution.node_type = event.node_type.value | |||
| workflow_node_execution.title = event.node_data.title | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value | |||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||
| workflow_node_execution.created_by = workflow_run.created_by | |||
| workflow_node_execution.created_at = created_at | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution.error = event.error | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.execution_metadata = json.dumps( | |||
| { | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| } | |||
| ) | |||
| workflow_node_execution.index = event.start_index | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| db.session.refresh(workflow_node_execution) | |||
| return workflow_node_execution | |||
| ################################################# | |||
| # to stream responses # | |||
| ################################################# | |||
| @@ -587,6 +635,51 @@ class WorkflowCycleManage: | |||
| ), | |||
| ) | |||
| def _workflow_node_retry_to_stream_response( | |||
| self, | |||
| event: QueueNodeRetryEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| """ | |||
| Workflow node finish to stream response. | |||
| :param event: queue node succeeded or failed event | |||
| :param task_id: task id | |||
| :param workflow_node_execution: workflow node execution | |||
| :return: | |||
| """ | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||
| return None | |||
| return NodeRetryStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_node_execution.workflow_run_id, | |||
| data=NodeRetryStreamResponse.Data( | |||
| id=workflow_node_execution.id, | |||
| node_id=workflow_node_execution.node_id, | |||
| node_type=workflow_node_execution.node_type, | |||
| index=workflow_node_execution.index, | |||
| title=workflow_node_execution.title, | |||
| predecessor_node_id=workflow_node_execution.predecessor_node_id, | |||
| inputs=workflow_node_execution.inputs_dict, | |||
| process_data=workflow_node_execution.process_data_dict, | |||
| outputs=workflow_node_execution.outputs_dict, | |||
| status=workflow_node_execution.status, | |||
| error=workflow_node_execution.error, | |||
| elapsed_time=workflow_node_execution.elapsed_time, | |||
| execution_metadata=workflow_node_execution.execution_metadata_dict, | |||
| created_at=int(workflow_node_execution.created_at.timestamp()), | |||
| finished_at=int(workflow_node_execution.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| retry_index=event.retry_index, | |||
| ), | |||
| ) | |||
| def _workflow_parallel_branch_start_to_stream_response( | |||
| self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent | |||
| ) -> ParallelBranchStartStreamResponse: | |||
| @@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| ) | |||
| retries = 0 | |||
| stream = kwargs.pop("stream", False) | |||
| while retries <= max_retries: | |||
| try: | |||
| if dify_config.SSRF_PROXY_ALL_URL: | |||
| @@ -45,3 +45,6 @@ class NodeRunResult(BaseModel): | |||
| error: Optional[str] = None # error message if status is failed | |||
| error_type: Optional[str] = None # error type if status is failed | |||
| # single step node run retry | |||
| retry_index: int = 0 | |||
| @@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| class NodeRunRetryEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| retry_index: int = Field(..., description="which retry attempt is about to be performed") | |||
| start_at: datetime = Field(..., description="retry start time") | |||
| start_index: int = Field(..., description="retry start index") | |||
| ########################################### | |||
| # Parallel Branch Events | |||
| ########################################### | |||
| @@ -5,6 +5,7 @@ import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from concurrent.futures import ThreadPoolExecutor, wait | |||
| from copy import copy, deepcopy | |||
| from datetime import UTC, datetime | |||
| from typing import Any, Optional, cast | |||
| from flask import Flask, current_app | |||
| @@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import ( | |||
| NodeRunExceptionEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunRetryEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| @@ -581,7 +583,7 @@ class GraphEngine: | |||
| def _run_node( | |||
| self, | |||
| node_instance: BaseNode, | |||
| node_instance: BaseNode[BaseNodeData], | |||
| route_node_state: RouteNodeState, | |||
| parallel_id: Optional[str] = None, | |||
| parallel_start_node_id: Optional[str] = None, | |||
| @@ -607,36 +609,121 @@ class GraphEngine: | |||
| ) | |||
| db.session.close() | |||
| max_retries = node_instance.node_data.retry_config.max_retries | |||
| retry_interval = node_instance.node_data.retry_config.retry_interval_seconds | |||
| retries = 0 | |||
| shoudl_continue_retry = True | |||
| while shoudl_continue_retry and retries <= max_retries: | |||
| try: | |||
| # run node | |||
| retry_start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| generator = node_instance.run() | |||
| for item in generator: | |||
| if isinstance(item, GraphEngineEvent): | |||
| if isinstance(item, BaseIterationEvent): | |||
| # add parallel info to iteration event | |||
| item.parallel_id = parallel_id | |||
| item.parallel_start_node_id = parallel_start_node_id | |||
| item.parent_parallel_id = parent_parallel_id | |||
| item.parent_parallel_start_node_id = parent_parallel_start_node_id | |||
| yield item | |||
| else: | |||
| if isinstance(item, RunCompletedEvent): | |||
| run_result = item.run_result | |||
| if run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| if ( | |||
| retries == max_retries | |||
| and node_instance.node_type == NodeType.HTTP_REQUEST | |||
| and run_result.outputs | |||
| and not node_instance.should_continue_on_error | |||
| ): | |||
| run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED | |||
| if node_instance.should_retry and retries < max_retries: | |||
| retries += 1 | |||
| self.graph_runtime_state.node_run_steps += 1 | |||
| route_node_state.node_run_result = run_result | |||
| yield NodeRunRetryEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| error=run_result.error, | |||
| retry_index=retries, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| start_at=retry_start_at, | |||
| start_index=self.graph_runtime_state.node_run_steps, | |||
| ) | |||
| time.sleep(retry_interval) | |||
| continue | |||
| route_node_state.set_finished(run_result=run_result) | |||
| if run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| if node_instance.should_continue_on_error: | |||
| # if run failed, handle error | |||
| run_result = self._handle_continue_on_error( | |||
| node_instance, | |||
| item.run_result, | |||
| self.graph_runtime_state.variable_pool, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| route_node_state.node_run_result = run_result | |||
| route_node_state.status = RouteNodeState.Status.EXCEPTION | |||
| if run_result.outputs: | |||
| for variable_key, variable_value in run_result.outputs.items(): | |||
| # append variables to variable pool recursively | |||
| self._append_variables_recursively( | |||
| node_id=node_instance.node_id, | |||
| variable_key_list=[variable_key], | |||
| variable_value=variable_value, | |||
| ) | |||
| yield NodeRunExceptionEvent( | |||
| error=run_result.error or "System Error", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| shoudl_continue_retry = False | |||
| else: | |||
| yield NodeRunFailedEvent( | |||
| error=route_node_state.failed_reason or "Unknown error.", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| shoudl_continue_retry = False | |||
| elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| if node_instance.should_continue_on_error and self.graph.edge_mapping.get( | |||
| node_instance.node_id | |||
| ): | |||
| run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS | |||
| if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| # plus state total_tokens | |||
| self.graph_runtime_state.total_tokens += int( | |||
| run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] | |||
| ) | |||
| try: | |||
| # run node | |||
| generator = node_instance.run() | |||
| for item in generator: | |||
| if isinstance(item, GraphEngineEvent): | |||
| if isinstance(item, BaseIterationEvent): | |||
| # add parallel info to iteration event | |||
| item.parallel_id = parallel_id | |||
| item.parallel_start_node_id = parallel_start_node_id | |||
| item.parent_parallel_id = parent_parallel_id | |||
| item.parent_parallel_start_node_id = parent_parallel_start_node_id | |||
| if run_result.llm_usage: | |||
| # use the latest usage | |||
| self.graph_runtime_state.llm_usage += run_result.llm_usage | |||
| yield item | |||
| else: | |||
| if isinstance(item, RunCompletedEvent): | |||
| run_result = item.run_result | |||
| route_node_state.set_finished(run_result=run_result) | |||
| if run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| if node_instance.should_continue_on_error: | |||
| # if run failed, handle error | |||
| run_result = self._handle_continue_on_error( | |||
| node_instance, | |||
| item.run_result, | |||
| self.graph_runtime_state.variable_pool, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| route_node_state.node_run_result = run_result | |||
| route_node_state.status = RouteNodeState.Status.EXCEPTION | |||
| # append node output variables to variable pool | |||
| if run_result.outputs: | |||
| for variable_key, variable_value in run_result.outputs.items(): | |||
| # append variables to variable pool recursively | |||
| @@ -645,21 +732,23 @@ class GraphEngine: | |||
| variable_key_list=[variable_key], | |||
| variable_value=variable_value, | |||
| ) | |||
| yield NodeRunExceptionEvent( | |||
| error=run_result.error or "System Error", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| else: | |||
| yield NodeRunFailedEvent( | |||
| error=route_node_state.failed_reason or "Unknown error.", | |||
| # add parallel info to run result metadata | |||
| if parallel_id and parallel_start_node_id: | |||
| if not run_result.metadata: | |||
| run_result.metadata = {} | |||
| run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id | |||
| run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = ( | |||
| parallel_start_node_id | |||
| ) | |||
| if parent_parallel_id and parent_parallel_start_node_id: | |||
| run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id | |||
| run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( | |||
| parent_parallel_start_node_id | |||
| ) | |||
| yield NodeRunSucceededEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| @@ -670,108 +759,59 @@ class GraphEngine: | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| shoudl_continue_retry = False | |||
| elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| if node_instance.should_continue_on_error and self.graph.edge_mapping.get( | |||
| node_instance.node_id | |||
| ): | |||
| run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS | |||
| if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| # plus state total_tokens | |||
| self.graph_runtime_state.total_tokens += int( | |||
| run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] | |||
| ) | |||
| if run_result.llm_usage: | |||
| # use the latest usage | |||
| self.graph_runtime_state.llm_usage += run_result.llm_usage | |||
| # append node output variables to variable pool | |||
| if run_result.outputs: | |||
| for variable_key, variable_value in run_result.outputs.items(): | |||
| # append variables to variable pool recursively | |||
| self._append_variables_recursively( | |||
| node_id=node_instance.node_id, | |||
| variable_key_list=[variable_key], | |||
| variable_value=variable_value, | |||
| ) | |||
| # add parallel info to run result metadata | |||
| if parallel_id and parallel_start_node_id: | |||
| if not run_result.metadata: | |||
| run_result.metadata = {} | |||
| run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id | |||
| run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id | |||
| if parent_parallel_id and parent_parallel_start_node_id: | |||
| run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id | |||
| run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( | |||
| parent_parallel_start_node_id | |||
| ) | |||
| yield NodeRunSucceededEvent( | |||
| break | |||
| elif isinstance(item, RunStreamChunkEvent): | |||
| yield NodeRunStreamChunkEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| chunk_content=item.chunk_content, | |||
| from_variable_selector=item.from_variable_selector, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| break | |||
| elif isinstance(item, RunStreamChunkEvent): | |||
| yield NodeRunStreamChunkEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| chunk_content=item.chunk_content, | |||
| from_variable_selector=item.from_variable_selector, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| elif isinstance(item, RunRetrieverResourceEvent): | |||
| yield NodeRunRetrieverResourceEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| retriever_resources=item.retriever_resources, | |||
| context=item.context, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| except GenerateTaskStoppedError: | |||
| # trigger node run failed event | |||
| route_node_state.status = RouteNodeState.Status.FAILED | |||
| route_node_state.failed_reason = "Workflow stopped." | |||
| yield NodeRunFailedEvent( | |||
| error="Workflow stopped.", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| return | |||
| except Exception as e: | |||
| logger.exception(f"Node {node_instance.node_data.title} run failed") | |||
| raise e | |||
| finally: | |||
| db.session.close() | |||
| elif isinstance(item, RunRetrieverResourceEvent): | |||
| yield NodeRunRetrieverResourceEvent( | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| retriever_resources=item.retriever_resources, | |||
| context=item.context, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| except GenerateTaskStoppedError: | |||
| # trigger node run failed event | |||
| route_node_state.status = RouteNodeState.Status.FAILED | |||
| route_node_state.failed_reason = "Workflow stopped." | |||
| yield NodeRunFailedEvent( | |||
| error="Workflow stopped.", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| return | |||
| except Exception as e: | |||
| logger.exception(f"Node {node_instance.node_data.title} run failed") | |||
| raise e | |||
| finally: | |||
| db.session.close() | |||
| def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): | |||
| """ | |||
| @@ -106,12 +106,25 @@ class DefaultValue(BaseModel): | |||
| return self | |||
| class RetryConfig(BaseModel): | |||
| """node retry config""" | |||
| max_retries: int = 0 # max retry times | |||
| retry_interval: int = 0 # retry interval in milliseconds | |||
| retry_enabled: bool = False # whether retry is enabled | |||
| @property | |||
| def retry_interval_seconds(self) -> float: | |||
| return self.retry_interval / 1000 | |||
| class BaseNodeData(ABC, BaseModel): | |||
| title: str | |||
| desc: Optional[str] = None | |||
| error_strategy: Optional[ErrorStrategy] = None | |||
| default_value: Optional[list[DefaultValue]] = None | |||
| version: str = "1" | |||
| retry_config: RetryConfig = RetryConfig() | |||
| @property | |||
| def default_value_dict(self): | |||
| @@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType | |||
| from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]): | |||
| bool: if should continue on error | |||
| """ | |||
| return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE | |||
| @property | |||
| def should_retry(self) -> bool: | |||
| """judge if should retry | |||
| Returns: | |||
| bool: if should retry | |||
| """ | |||
| return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE | |||
| @@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum): | |||
| CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] | |||
| RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST] | |||
| @@ -1,4 +1,10 @@ | |||
| from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from .event import ( | |||
| ModelInvokeCompletedEvent, | |||
| RunCompletedEvent, | |||
| RunRetrieverResourceEvent, | |||
| RunRetryEvent, | |||
| RunStreamChunkEvent, | |||
| ) | |||
| from .types import NodeEvent | |||
| __all__ = [ | |||
| @@ -6,5 +12,6 @@ __all__ = [ | |||
| "NodeEvent", | |||
| "RunCompletedEvent", | |||
| "RunRetrieverResourceEvent", | |||
| "RunRetryEvent", | |||
| "RunStreamChunkEvent", | |||
| ] | |||
| @@ -1,7 +1,10 @@ | |||
| from datetime import datetime | |||
| from pydantic import BaseModel, Field | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| class RunCompletedEvent(BaseModel): | |||
| @@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel): | |||
| text: str | |||
| usage: LLMUsage | |||
| finish_reason: str | None = None | |||
| class RunRetryEvent(BaseModel): | |||
| """Node Run Retry event""" | |||
| error: str = Field(..., description="error") | |||
| retry_index: int = Field(..., description="Retry attempt number") | |||
| start_at: datetime = Field(..., description="Retry start time") | |||
| class SingleStepRetryEvent(BaseModel): | |||
| """Single step retry event""" | |||
| status: str = WorkflowNodeExecutionStatus.RETRY.value | |||
| inputs: dict | None = Field(..., description="input") | |||
| error: str = Field(..., description="error") | |||
| outputs: dict = Field(..., description="output") | |||
| retry_index: int = Field(..., description="Retry attempt number") | |||
| error: str = Field(..., description="error") | |||
| elapsed_time: float = Field(..., description="elapsed time") | |||
| execution_metadata: dict | None = Field(..., description="execution metadata") | |||
| @@ -45,6 +45,7 @@ class Executor: | |||
| headers: dict[str, str] | |||
| auth: HttpRequestNodeAuthorization | |||
| timeout: HttpRequestNodeTimeout | |||
| max_retries: int | |||
| boundary: str | |||
| @@ -54,6 +55,7 @@ class Executor: | |||
| node_data: HttpRequestNodeData, | |||
| timeout: HttpRequestNodeTimeout, | |||
| variable_pool: VariablePool, | |||
| max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, | |||
| ): | |||
| # If authorization API key is present, convert the API key using the variable pool | |||
| if node_data.authorization.type == "api-key": | |||
| @@ -73,6 +75,7 @@ class Executor: | |||
| self.files = None | |||
| self.data = None | |||
| self.json = None | |||
| self.max_retries = max_retries | |||
| # init template | |||
| self.variable_pool = variable_pool | |||
| @@ -241,6 +244,7 @@ class Executor: | |||
| "params": self.params, | |||
| "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), | |||
| "follow_redirects": True, | |||
| "max_retries": self.max_retries, | |||
| } | |||
| # request_args = {k: v for k, v in request_args.items() if v is not None} | |||
| try: | |||
| @@ -52,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, | |||
| }, | |||
| }, | |||
| "retry_config": { | |||
| "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES, | |||
| "retry_interval": 0.5 * (2**2), | |||
| "retry_enabled": True, | |||
| }, | |||
| } | |||
| def _run(self) -> NodeRunResult: | |||
| @@ -61,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| node_data=self.node_data, | |||
| timeout=self._get_request_timeout(self.node_data), | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| max_retries=0, | |||
| ) | |||
| process_data["request"] = http_executor.to_log() | |||
| response = http_executor.invoke() | |||
| files = self.extract_files(url=http_executor.url, response=response) | |||
| if not response.response.is_success and self.should_continue_on_error: | |||
| if not response.response.is_success and (self.should_continue_on_error or self.should_retry): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| outputs={ | |||
| @@ -29,6 +29,7 @@ workflow_run_for_list_fields = { | |||
| "created_at": TimestampField, | |||
| "finished_at": TimestampField, | |||
| "exceptions_count": fields.Integer, | |||
| "retry_index": fields.Integer, | |||
| } | |||
| advanced_chat_workflow_run_for_list_fields = { | |||
| @@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = { | |||
| "created_at": TimestampField, | |||
| "finished_at": TimestampField, | |||
| "exceptions_count": fields.Integer, | |||
| "retry_index": fields.Integer, | |||
| } | |||
| advanced_chat_workflow_run_pagination_fields = { | |||
| @@ -79,6 +81,17 @@ workflow_run_detail_fields = { | |||
| "exceptions_count": fields.Integer, | |||
| } | |||
| retry_event_field = { | |||
| "error": fields.String, | |||
| "retry_index": fields.Integer, | |||
| "inputs": fields.Raw(attribute="inputs"), | |||
| "elapsed_time": fields.Float, | |||
| "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), | |||
| "status": fields.String, | |||
| "outputs": fields.Raw(attribute="outputs"), | |||
| } | |||
| workflow_run_node_execution_fields = { | |||
| "id": fields.String, | |||
| "index": fields.Integer, | |||
| @@ -99,6 +112,7 @@ workflow_run_node_execution_fields = { | |||
| "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), | |||
| "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), | |||
| "finished_at": TimestampField, | |||
| "retry_events": fields.List(fields.Nested(retry_event_field)), | |||
| } | |||
| workflow_run_node_execution_list_fields = { | |||
| @@ -0,0 +1,33 @@ | |||
| """add retry_index field to node-execution model | |||
| Revision ID: 348cb0a93d53 | |||
| Revises: cf8f4fc45278 | |||
| Create Date: 2024-12-16 01:23:13.093432 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = '348cb0a93d53' | |||
| down_revision = 'cf8f4fc45278' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: | |||
| batch_op.drop_column('retry_index') | |||
| # ### end Alembic commands ### | |||
| @@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum): | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| EXCEPTION = "exception" | |||
| RETRY = "retry" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": | |||
| @@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model): | |||
| created_by_role = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| finished_at = db.Column(db.DateTime) | |||
| retry_index = db.Column(db.Integer, server_default=db.text("0")) | |||
| @property | |||
| def created_by_account(self): | |||
| @@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData | |||
| from core.workflow.nodes.base.node import BaseNode | |||
| from core.workflow.nodes.enums import ErrorStrategy | |||
| from core.workflow.nodes.event import RunCompletedEvent | |||
| from core.workflow.nodes.event.event import SingleStepRetryEvent | |||
| from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated | |||
| @@ -220,56 +221,99 @@ class WorkflowService: | |||
| # run draft workflow node | |||
| start_at = time.perf_counter() | |||
| retries = 0 | |||
| max_retries = 0 | |||
| should_retry = True | |||
| retry_events = [] | |||
| try: | |||
| node_instance, generator = WorkflowEntry.single_step_run( | |||
| workflow=draft_workflow, | |||
| node_id=node_id, | |||
| user_inputs=user_inputs, | |||
| user_id=account.id, | |||
| ) | |||
| node_instance = cast(BaseNode[BaseNodeData], node_instance) | |||
| node_run_result: NodeRunResult | None = None | |||
| for event in generator: | |||
| if isinstance(event, RunCompletedEvent): | |||
| node_run_result = event.run_result | |||
| # sign output files | |||
| node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) | |||
| break | |||
| if not node_run_result: | |||
| raise ValueError("Node run failed with no run result") | |||
| # single step debug mode error handling return | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: | |||
| node_error_args = { | |||
| "status": WorkflowNodeExecutionStatus.EXCEPTION, | |||
| "error": node_run_result.error, | |||
| "inputs": node_run_result.inputs, | |||
| "metadata": {"error_strategy": node_instance.node_data.error_strategy}, | |||
| } | |||
| if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: | |||
| node_run_result = NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| **node_instance.node_data.default_value_dict, | |||
| "error_message": node_run_result.error, | |||
| "error_type": node_run_result.error_type, | |||
| }, | |||
| ) | |||
| else: | |||
| node_run_result = NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| "error_message": node_run_result.error, | |||
| "error_type": node_run_result.error_type, | |||
| }, | |||
| ) | |||
| run_succeeded = node_run_result.status in ( | |||
| WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| WorkflowNodeExecutionStatus.EXCEPTION, | |||
| ) | |||
| error = node_run_result.error if not run_succeeded else None | |||
| while retries <= max_retries and should_retry: | |||
| retry_start_at = time.perf_counter() | |||
| node_instance, generator = WorkflowEntry.single_step_run( | |||
| workflow=draft_workflow, | |||
| node_id=node_id, | |||
| user_inputs=user_inputs, | |||
| user_id=account.id, | |||
| ) | |||
| node_instance = cast(BaseNode[BaseNodeData], node_instance) | |||
| max_retries = ( | |||
| node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0 | |||
| ) | |||
| retry_interval = node_instance.node_data.retry_config.retry_interval_seconds | |||
| node_run_result: NodeRunResult | None = None | |||
| for event in generator: | |||
| if isinstance(event, RunCompletedEvent): | |||
| node_run_result = event.run_result | |||
| # sign output files | |||
| node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) | |||
| break | |||
| if not node_run_result: | |||
| raise ValueError("Node run failed with no run result") | |||
| # single step debug mode error handling return | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| if ( | |||
| retries == max_retries | |||
| and node_instance.node_type == NodeType.HTTP_REQUEST | |||
| and node_run_result.outputs | |||
| and not node_instance.should_continue_on_error | |||
| ): | |||
| node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED | |||
| should_retry = False | |||
| else: | |||
| if node_instance.should_retry: | |||
| node_run_result.status = WorkflowNodeExecutionStatus.RETRY | |||
| retries += 1 | |||
| node_run_result.retry_index = retries | |||
| retry_events.append( | |||
| SingleStepRetryEvent( | |||
| inputs=WorkflowEntry.handle_special_values(node_run_result.inputs) | |||
| if node_run_result.inputs | |||
| else None, | |||
| error=node_run_result.error, | |||
| outputs=WorkflowEntry.handle_special_values(node_run_result.outputs) | |||
| if node_run_result.outputs | |||
| else None, | |||
| retry_index=node_run_result.retry_index, | |||
| elapsed_time=time.perf_counter() - retry_start_at, | |||
| execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata) | |||
| if node_run_result.metadata | |||
| else None, | |||
| ) | |||
| ) | |||
| time.sleep(retry_interval) | |||
| else: | |||
| should_retry = False | |||
| if node_instance.should_continue_on_error: | |||
| node_error_args = { | |||
| "status": WorkflowNodeExecutionStatus.EXCEPTION, | |||
| "error": node_run_result.error, | |||
| "inputs": node_run_result.inputs, | |||
| "metadata": {"error_strategy": node_instance.node_data.error_strategy}, | |||
| } | |||
| if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: | |||
| node_run_result = NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| **node_instance.node_data.default_value_dict, | |||
| "error_message": node_run_result.error, | |||
| "error_type": node_run_result.error_type, | |||
| }, | |||
| ) | |||
| else: | |||
| node_run_result = NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| "error_message": node_run_result.error, | |||
| "error_type": node_run_result.error_type, | |||
| }, | |||
| ) | |||
| run_succeeded = node_run_result.status in ( | |||
| WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| WorkflowNodeExecutionStatus.EXCEPTION, | |||
| ) | |||
| error = node_run_result.error if not run_succeeded else None | |||
| except WorkflowNodeRunFailedError as e: | |||
| node_instance = e.node_instance | |||
| run_succeeded = False | |||
| @@ -318,6 +362,7 @@ class WorkflowService: | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| workflow_node_execution.retry_events = retry_events | |||
| return workflow_node_execution | |||
| @@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunSucceededEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunStreamChunkEvent, | |||
| ) | |||
| @@ -14,7 +13,9 @@ from models.workflow import WorkflowType | |||
| class ContinueOnErrorTestHelper: | |||
| @staticmethod | |||
| def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| def get_code_node( | |||
| code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} | |||
| ): | |||
| """Helper method to create a code node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| @@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper: | |||
| "code_language": "python3", | |||
| "code": "\n".join([line[4:] for line in code.split("\n")]), | |||
| "type": "code", | |||
| **retry_config, | |||
| }, | |||
| } | |||
| if default_value: | |||
| @@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper: | |||
| @staticmethod | |||
| def get_http_node( | |||
| error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False | |||
| error_strategy: str = "fail-branch", | |||
| default_value: dict | None = None, | |||
| authorization_success: bool = False, | |||
| retry_config: dict = {}, | |||
| ): | |||
| """Helper method to create a http node configuration""" | |||
| authorization = ( | |||
| @@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper: | |||
| "body": None, | |||
| "type": "http-request", | |||
| "error_strategy": error_strategy, | |||
| **retry_config, | |||
| }, | |||
| } | |||
| if default_value: | |||
| @@ -0,0 +1,73 @@ | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunSucceededEvent, | |||
| NodeRunRetryEvent, | |||
| ) | |||
| from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper | |||
| DEFAULT_VALUE_EDGE = [ | |||
| { | |||
| "id": "start-source-node-target", | |||
| "source": "start", | |||
| "target": "node", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-source-answer-target", | |||
| "source": "node", | |||
| "target": "answer", | |||
| "sourceHandle": "source", | |||
| }, | |||
| ] | |||
| def test_retry_default_value_partial_success(): | |||
| """retry default value node with partial success status""" | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_http_node( | |||
| "default-value", | |||
| [{"key": "result", "type": "string", "value": "http node got error response"}], | |||
| retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 | |||
| assert events[-1].outputs == {"answer": "http node got error response"} | |||
| assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) | |||
| assert len(events) == 11 | |||
| def test_retry_failed(): | |||
| """retry failed with success status""" | |||
| error_code = """ | |||
| def main() -> dict: | |||
| return { | |||
| "result": 1 / 0, | |||
| } | |||
| """ | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_http_node( | |||
| None, | |||
| None, | |||
| retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 | |||
| assert any(isinstance(e, GraphRunFailedEvent) for e in events) | |||
| assert len(events) == 8 | |||