Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local> Co-authored-by: Novice Lee <novicelee@NoviPro.local>tags/0.14.0
| @@ -19,6 +19,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueMessageReplaceEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| @@ -31,6 +32,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| @@ -317,7 +319,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| @@ -384,6 +386,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not workflow_run: | |||
| raise Exception("Workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise Exception("Graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) | |||
| elif isinstance(event, QueueWorkflowFailedEvent): | |||
| if not workflow_run: | |||
| @@ -401,6 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| error=event.error, | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| @@ -6,6 +6,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueMessageEndEvent, | |||
| QueueStopEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| WorkflowQueueMessage, | |||
| ) | |||
| @@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager): | |||
| | QueueErrorEvent | |||
| | QueueMessageEndEvent | |||
| | QueueWorkflowSucceededEvent | |||
| | QueueWorkflowFailedEvent, | |||
| | QueueWorkflowFailedEvent | |||
| | QueueWorkflowPartialSuccessEvent, | |||
| ): | |||
| self.stop_listen() | |||
| @@ -15,6 +15,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| @@ -26,6 +27,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| @@ -276,7 +278,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| @@ -345,22 +347,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not workflow_run: | |||
| raise Exception("Workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise Exception("Graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| @@ -368,6 +368,60 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): | |||
| if not workflow_run: | |||
| raise Exception("Workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise Exception("Graph runtime state not initialized.") | |||
| handle_args = { | |||
| "workflow_run": workflow_run, | |||
| "start_at": graph_runtime_state.start_at, | |||
| "total_tokens": graph_runtime_state.total_tokens, | |||
| "total_steps": graph_runtime_state.node_run_steps, | |||
| "status": WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| "conversation_id": None, | |||
| "trace_manager": trace_manager, | |||
| "exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, | |||
| } | |||
| workflow_run = self._handle_workflow_run_failed(**handle_args) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not workflow_run: | |||
| raise Exception("Workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise Exception("Graph runtime state not initialized.") | |||
| handle_args = { | |||
| "workflow_run": workflow_run, | |||
| "start_at": graph_runtime_state.start_at, | |||
| "total_tokens": graph_runtime_state.total_tokens, | |||
| "total_steps": graph_runtime_state.node_run_steps, | |||
| "status": WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| "conversation_id": None, | |||
| "trace_manager": trace_manager, | |||
| "exceptions_count": event.exceptions_count, | |||
| } | |||
| workflow_run = self._handle_workflow_run_partial_success(**handle_args) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| @@ -8,6 +8,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| @@ -18,6 +19,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueRetrieverResourcesEvent, | |||
| QueueTextChunkEvent, | |||
| QueueWorkflowFailedEvent, | |||
| QueueWorkflowPartialSuccessEvent, | |||
| QueueWorkflowStartedEvent, | |||
| QueueWorkflowSucceededEvent, | |||
| ) | |||
| @@ -25,6 +27,7 @@ from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| IterationRunFailedEvent, | |||
| @@ -32,6 +35,7 @@ from core.workflow.graph_engine.entities.event import ( | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| NodeInIterationFailedEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunStartedEvent, | |||
| @@ -176,8 +180,12 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| ) | |||
| elif isinstance(event, GraphRunSucceededEvent): | |||
| self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) | |||
| elif isinstance(event, GraphRunPartialSucceededEvent): | |||
| self._publish_event( | |||
| QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count) | |||
| ) | |||
| elif isinstance(event, GraphRunFailedEvent): | |||
| self._publish_event(QueueWorkflowFailedEvent(error=event.error)) | |||
| self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) | |||
| elif isinstance(event, NodeRunStartedEvent): | |||
| self._publish_event( | |||
| QueueNodeStartedEvent( | |||
| @@ -253,6 +261,36 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| in_iteration_id=event.in_iteration_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunExceptionEvent): | |||
| self._publish_event( | |||
| QueueNodeExceptionEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| error=event.route_node_state.node_run_result.error | |||
| if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error | |||
| else "Unknown error", | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| in_iteration_id=event.in_iteration_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeInIterationFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeInIterationFailedEvent( | |||
| @@ -25,12 +25,14 @@ class QueueEvent(StrEnum): | |||
| WORKFLOW_STARTED = "workflow_started" | |||
| WORKFLOW_SUCCEEDED = "workflow_succeeded" | |||
| WORKFLOW_FAILED = "workflow_failed" | |||
| WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded" | |||
| ITERATION_START = "iteration_start" | |||
| ITERATION_NEXT = "iteration_next" | |||
| ITERATION_COMPLETED = "iteration_completed" | |||
| NODE_STARTED = "node_started" | |||
| NODE_SUCCEEDED = "node_succeeded" | |||
| NODE_FAILED = "node_failed" | |||
| NODE_EXCEPTION = "node_exception" | |||
| RETRIEVER_RESOURCES = "retriever_resources" | |||
| ANNOTATION_REPLY = "annotation_reply" | |||
| AGENT_THOUGHT = "agent_thought" | |||
| @@ -237,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent): | |||
| event: QueueEvent = QueueEvent.WORKFLOW_FAILED | |||
| error: str | |||
| exceptions_count: int | |||
| class QueueWorkflowPartialSuccessEvent(AppQueueEvent): | |||
| """ | |||
| QueueWorkflowFailedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED | |||
| exceptions_count: int | |||
| outputs: Optional[dict[str, Any]] = None | |||
| class QueueNodeStartedEvent(AppQueueEvent): | |||
| @@ -331,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): | |||
| error: str | |||
| class QueueNodeExceptionEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeExceptionEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_EXCEPTION | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| inputs: Optional[dict[str, Any]] = None | |||
| process_data: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None | |||
| error: str | |||
| class QueueNodeFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeFailedEvent entity | |||
| @@ -213,6 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse): | |||
| created_by: Optional[dict] = None | |||
| created_at: int | |||
| finished_at: int | |||
| exceptions_count: Optional[int] = 0 | |||
| files: Optional[Sequence[Mapping[str, Any]]] = [] | |||
| event: StreamEvent = StreamEvent.WORKFLOW_FINISHED | |||
| @@ -12,6 +12,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationCompletedEvent, | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeExceptionEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| @@ -164,6 +165,55 @@ class WorkflowCycleManage: | |||
| return workflow_run | |||
| def _handle_workflow_run_partial_success( | |||
| self, | |||
| workflow_run: WorkflowRun, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| total_steps: int, | |||
| outputs: Mapping[str, Any] | None = None, | |||
| exceptions_count: int = 0, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| ) -> WorkflowRun: | |||
| """ | |||
| Workflow run success | |||
| :param workflow_run: workflow run | |||
| :param start_at: start time | |||
| :param total_tokens: total tokens | |||
| :param total_steps: total steps | |||
| :param outputs: outputs | |||
| :param conversation_id: conversation id | |||
| :return: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| outputs = WorkflowEntry.handle_special_values(outputs) | |||
| workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value | |||
| workflow_run.outputs = json.dumps(outputs or {}) | |||
| workflow_run.elapsed_time = time.perf_counter() - start_at | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| db.session.commit() | |||
| db.session.refresh(workflow_run) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| TraceTaskName.WORKFLOW_TRACE, | |||
| workflow_run=workflow_run, | |||
| conversation_id=conversation_id, | |||
| user_id=trace_manager.user_id, | |||
| ) | |||
| ) | |||
| db.session.close() | |||
| return workflow_run | |||
| def _handle_workflow_run_failed( | |||
| self, | |||
| workflow_run: WorkflowRun, | |||
| @@ -174,6 +224,7 @@ class WorkflowCycleManage: | |||
| error: str, | |||
| conversation_id: Optional[str] = None, | |||
| trace_manager: Optional[TraceQueueManager] = None, | |||
| exceptions_count: int = 0, | |||
| ) -> WorkflowRun: | |||
| """ | |||
| Workflow run failed | |||
| @@ -193,7 +244,7 @@ class WorkflowCycleManage: | |||
| workflow_run.total_tokens = total_tokens | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| db.session.commit() | |||
| running_workflow_node_executions = ( | |||
| @@ -318,7 +369,7 @@ class WorkflowCycleManage: | |||
| return workflow_node_execution | |||
| def _handle_workflow_node_execution_failed( | |||
| self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | |||
| self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent | |||
| ) -> WorkflowNodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| @@ -337,7 +388,11 @@ class WorkflowCycleManage: | |||
| ) | |||
| db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( | |||
| { | |||
| WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, | |||
| WorkflowNodeExecution.status: ( | |||
| WorkflowNodeExecutionStatus.FAILED.value | |||
| if not isinstance(event, QueueNodeExceptionEvent) | |||
| else WorkflowNodeExecutionStatus.EXCEPTION.value | |||
| ), | |||
| WorkflowNodeExecution.error: event.error, | |||
| WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, | |||
| WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, | |||
| @@ -351,8 +406,11 @@ class WorkflowCycleManage: | |||
| db.session.commit() | |||
| db.session.close() | |||
| process_data = WorkflowEntry.handle_special_values(event.process_data) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.status = ( | |||
| WorkflowNodeExecutionStatus.FAILED.value | |||
| if not isinstance(event, QueueNodeExceptionEvent) | |||
| else WorkflowNodeExecutionStatus.EXCEPTION.value | |||
| ) | |||
| workflow_node_execution.error = event.error | |||
| workflow_node_execution.inputs = json.dumps(inputs) if inputs else None | |||
| workflow_node_execution.process_data = json.dumps(process_data) if process_data else None | |||
| @@ -433,6 +491,7 @@ class WorkflowCycleManage: | |||
| created_at=int(workflow_run.created_at.timestamp()), | |||
| finished_at=int(workflow_run.finished_at.timestamp()), | |||
| files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), | |||
| exceptions_count=workflow_run.exceptions_count, | |||
| ), | |||
| ) | |||
| @@ -483,7 +542,10 @@ class WorkflowCycleManage: | |||
| def _workflow_node_finish_to_stream_response( | |||
| self, | |||
| event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, | |||
| event: QueueNodeSucceededEvent | |||
| | QueueNodeFailedEvent | |||
| | QueueNodeInIterationFailedEvent | |||
| | QueueNodeExceptionEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| @@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5 | |||
| STATUS_FORCELIST = [429, 500, 502, 503, 504] | |||
| class MaxRetriesExceededError(Exception): | |||
| """Raised when the maximum number of retries is exceeded.""" | |||
| pass | |||
| def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| if "allow_redirects" in kwargs: | |||
| allow_redirects = kwargs.pop("allow_redirects") | |||
| @@ -64,7 +70,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| if retries <= max_retries: | |||
| time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) | |||
| raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") | |||
| raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") | |||
| def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| @@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| IterationRunFailedEvent, | |||
| @@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback): | |||
| self.print_text("\n[GraphRunStartedEvent]", color="pink") | |||
| elif isinstance(event, GraphRunSucceededEvent): | |||
| self.print_text("\n[GraphRunSucceededEvent]", color="green") | |||
| elif isinstance(event, GraphRunPartialSucceededEvent): | |||
| self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink") | |||
| elif isinstance(event, GraphRunFailedEvent): | |||
| self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") | |||
| elif isinstance(event, NodeRunStartedEvent): | |||
| @@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum): | |||
| PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" | |||
| PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" | |||
| ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs | |||
| ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field | |||
| class NodeRunResult(BaseModel): | |||
| @@ -43,3 +44,4 @@ class NodeRunResult(BaseModel): | |||
| edge_source_handle: Optional[str] = None # source handle id of node with multiple branches | |||
| error: Optional[str] = None # error message if status is failed | |||
| error_type: Optional[str] = None # error type if status is failed | |||
| @@ -33,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent): | |||
| class GraphRunFailedEvent(BaseGraphEvent): | |||
| error: str = Field(..., description="failed reason") | |||
| exceptions_count: Optional[int] = Field(description="exception count", default=0) | |||
| class GraphRunPartialSucceededEvent(BaseGraphEvent): | |||
| exceptions_count: int = Field(..., description="exception count") | |||
| outputs: Optional[dict[str, Any]] = None | |||
| ########################################### | |||
| @@ -83,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| class NodeRunExceptionEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| class NodeInIterationFailedEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| @@ -64,13 +64,21 @@ class Graph(BaseModel): | |||
| edge_configs = graph_config.get("edges") | |||
| if edge_configs is None: | |||
| edge_configs = [] | |||
| # node configs | |||
| node_configs = graph_config.get("nodes") | |||
| if not node_configs: | |||
| raise ValueError("Graph must have at least one node") | |||
| edge_configs = cast(list, edge_configs) | |||
| node_configs = cast(list, node_configs) | |||
| # reorganize edges mapping | |||
| edge_mapping: dict[str, list[GraphEdge]] = {} | |||
| reverse_edge_mapping: dict[str, list[GraphEdge]] = {} | |||
| target_edge_ids = set() | |||
| fail_branch_source_node_id = [ | |||
| node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch" | |||
| ] | |||
| for edge_config in edge_configs: | |||
| source_node_id = edge_config.get("source") | |||
| if not source_node_id: | |||
| @@ -90,8 +98,16 @@ class Graph(BaseModel): | |||
| # parse run condition | |||
| run_condition = None | |||
| if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": | |||
| run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) | |||
| if edge_config.get("sourceHandle"): | |||
| if ( | |||
| edge_config.get("source") in fail_branch_source_node_id | |||
| and edge_config.get("sourceHandle") != "fail-branch" | |||
| ): | |||
| run_condition = RunCondition(type="branch_identify", branch_identify="success-branch") | |||
| elif edge_config.get("sourceHandle") != "source": | |||
| run_condition = RunCondition( | |||
| type="branch_identify", branch_identify=edge_config.get("sourceHandle") | |||
| ) | |||
| graph_edge = GraphEdge( | |||
| source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition | |||
| @@ -100,13 +116,6 @@ class Graph(BaseModel): | |||
| edge_mapping[source_node_id].append(graph_edge) | |||
| reverse_edge_mapping[target_node_id].append(graph_edge) | |||
| # node configs | |||
| node_configs = graph_config.get("nodes") | |||
| if not node_configs: | |||
| raise ValueError("Graph must have at least one node") | |||
| node_configs = cast(list, node_configs) | |||
| # fetch nodes that have no predecessor node | |||
| root_node_configs = [] | |||
| all_node_id_config_mapping: dict[str, dict] = {} | |||
| @@ -15,6 +15,7 @@ class RouteNodeState(BaseModel): | |||
| SUCCESS = "success" | |||
| FAILED = "failed" | |||
| PAUSED = "paused" | |||
| EXCEPTION = "exception" | |||
| id: str = Field(default_factory=lambda: str(uuid.uuid4())) | |||
| """node state id""" | |||
| @@ -51,7 +52,11 @@ class RouteNodeState(BaseModel): | |||
| :param run_result: run result | |||
| """ | |||
| if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: | |||
| if self.status in { | |||
| RouteNodeState.Status.SUCCESS, | |||
| RouteNodeState.Status.FAILED, | |||
| RouteNodeState.Status.EXCEPTION, | |||
| }: | |||
| raise Exception(f"Route state {self.id} already finished") | |||
| if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| @@ -59,6 +64,9 @@ class RouteNodeState(BaseModel): | |||
| elif run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| self.status = RouteNodeState.Status.FAILED | |||
| self.failed_reason = run_result.error | |||
| elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: | |||
| self.status = RouteNodeState.Status.EXCEPTION | |||
| self.failed_reason = run_result.error | |||
| else: | |||
| raise Exception(f"Invalid route status {run_result.status}") | |||
| @@ -5,21 +5,23 @@ import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from concurrent.futures import ThreadPoolExecutor, wait | |||
| from copy import copy, deepcopy | |||
| from typing import Any, Optional | |||
| from typing import Any, Optional, cast | |||
| from flask import Flask, current_app | |||
| from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||
| from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseIterationEvent, | |||
| GraphEngineEvent, | |||
| GraphRunFailedEvent, | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunStartedEvent, | |||
| GraphRunSucceededEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunStartedEvent, | |||
| @@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.base.entities import BaseNodeData | |||
| from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor | |||
| from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | |||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | |||
| from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from extensions.ext_database import db | |||
| @@ -128,6 +132,7 @@ class GraphEngine: | |||
| def run(self) -> Generator[GraphEngineEvent, None, None]: | |||
| # trigger graph run start event | |||
| yield GraphRunStartedEvent() | |||
| handle_exceptions = [] | |||
| try: | |||
| if self.init_params.workflow_type == WorkflowType.CHAT: | |||
| @@ -140,13 +145,17 @@ class GraphEngine: | |||
| ) | |||
| # run graph | |||
| generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) | |||
| generator = stream_processor.process( | |||
| self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) | |||
| ) | |||
| for item in generator: | |||
| try: | |||
| yield item | |||
| if isinstance(item, NodeRunFailedEvent): | |||
| yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") | |||
| yield GraphRunFailedEvent( | |||
| error=item.route_node_state.failed_reason or "Unknown error.", | |||
| exceptions_count=len(handle_exceptions), | |||
| ) | |||
| return | |||
| elif isinstance(item, NodeRunSucceededEvent): | |||
| if item.node_type == NodeType.END: | |||
| @@ -172,19 +181,24 @@ class GraphEngine: | |||
| ].strip() | |||
| except Exception as e: | |||
| logger.exception("Graph run failed") | |||
| yield GraphRunFailedEvent(error=str(e)) | |||
| yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) | |||
| return | |||
| # trigger graph run success event | |||
| yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) | |||
| # count exceptions to determine partial success | |||
| if len(handle_exceptions) > 0: | |||
| yield GraphRunPartialSucceededEvent( | |||
| exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs | |||
| ) | |||
| else: | |||
| # trigger graph run success event | |||
| yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) | |||
| self._release_thread() | |||
| except GraphRunFailedError as e: | |||
| yield GraphRunFailedEvent(error=e.error) | |||
| yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions)) | |||
| self._release_thread() | |||
| return | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when graph running") | |||
| yield GraphRunFailedEvent(error=str(e)) | |||
| yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) | |||
| self._release_thread() | |||
| raise e | |||
| @@ -198,6 +212,7 @@ class GraphEngine: | |||
| in_parallel_id: Optional[str] = None, | |||
| parent_parallel_id: Optional[str] = None, | |||
| parent_parallel_start_node_id: Optional[str] = None, | |||
| handle_exceptions: list[str] = [], | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| parallel_start_node_id = None | |||
| if in_parallel_id: | |||
| @@ -242,7 +257,7 @@ class GraphEngine: | |||
| previous_node_id=previous_node_id, | |||
| thread_pool_id=self.thread_pool_id, | |||
| ) | |||
| node_instance = cast(BaseNode[BaseNodeData], node_instance) | |||
| try: | |||
| # run node | |||
| generator = self._run_node( | |||
| @@ -252,6 +267,7 @@ class GraphEngine: | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| for item in generator: | |||
| @@ -301,7 +317,12 @@ class GraphEngine: | |||
| if len(edge_mappings) == 1: | |||
| edge = edge_mappings[0] | |||
| if ( | |||
| previous_route_node_state.status == RouteNodeState.Status.EXCEPTION | |||
| and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH | |||
| and edge.run_condition is None | |||
| ): | |||
| break | |||
| if edge.run_condition: | |||
| result = ConditionManager.get_condition_handler( | |||
| init_params=self.init_params, | |||
| @@ -334,7 +355,7 @@ class GraphEngine: | |||
| if len(sub_edge_mappings) == 0: | |||
| continue | |||
| edge = sub_edge_mappings[0] | |||
| edge = cast(GraphEdge, sub_edge_mappings[0]) | |||
| result = ConditionManager.get_condition_handler( | |||
| init_params=self.init_params, | |||
| @@ -355,6 +376,7 @@ class GraphEngine: | |||
| edge_mappings=sub_edge_mappings, | |||
| in_parallel_id=in_parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| for item in parallel_generator: | |||
| @@ -369,11 +391,18 @@ class GraphEngine: | |||
| break | |||
| next_node_id = final_node_id | |||
| elif ( | |||
| node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH | |||
| and node_instance.should_continue_on_error | |||
| and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION | |||
| ): | |||
| break | |||
| else: | |||
| parallel_generator = self._run_parallel_branches( | |||
| edge_mappings=edge_mappings, | |||
| in_parallel_id=in_parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| for item in parallel_generator: | |||
| @@ -395,6 +424,7 @@ class GraphEngine: | |||
| edge_mappings: list[GraphEdge], | |||
| in_parallel_id: Optional[str] = None, | |||
| parallel_start_node_id: Optional[str] = None, | |||
| handle_exceptions: list[str] = [], | |||
| ) -> Generator[GraphEngineEvent | str, None, None]: | |||
| # if nodes has no run conditions, parallel run all nodes | |||
| parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) | |||
| @@ -438,6 +468,7 @@ class GraphEngine: | |||
| "parallel_start_node_id": edge.target_node_id, | |||
| "parent_parallel_id": in_parallel_id, | |||
| "parent_parallel_start_node_id": parallel_start_node_id, | |||
| "handle_exceptions": handle_exceptions, | |||
| }, | |||
| ) | |||
| @@ -481,6 +512,7 @@ class GraphEngine: | |||
| parallel_start_node_id: str, | |||
| parent_parallel_id: Optional[str] = None, | |||
| parent_parallel_start_node_id: Optional[str] = None, | |||
| handle_exceptions: list[str] = [], | |||
| ) -> None: | |||
| """ | |||
| Run parallel nodes | |||
| @@ -502,6 +534,7 @@ class GraphEngine: | |||
| in_parallel_id=parallel_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| for item in generator: | |||
| @@ -548,6 +581,7 @@ class GraphEngine: | |||
| parallel_start_node_id: Optional[str] = None, | |||
| parent_parallel_id: Optional[str] = None, | |||
| parent_parallel_start_node_id: Optional[str] = None, | |||
| handle_exceptions: list[str] = [], | |||
| ) -> Generator[GraphEngineEvent, None, None]: | |||
| """ | |||
| Run node | |||
| @@ -587,19 +621,55 @@ class GraphEngine: | |||
| route_node_state.set_finished(run_result=run_result) | |||
| if run_result.status == WorkflowNodeExecutionStatus.FAILED: | |||
| yield NodeRunFailedEvent( | |||
| error=route_node_state.failed_reason or "Unknown error.", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| if node_instance.should_continue_on_error: | |||
| # if run failed, handle error | |||
| run_result = self._handle_continue_on_error( | |||
| node_instance, | |||
| item.run_result, | |||
| self.graph_runtime_state.variable_pool, | |||
| handle_exceptions=handle_exceptions, | |||
| ) | |||
| route_node_state.node_run_result = run_result | |||
| route_node_state.status = RouteNodeState.Status.EXCEPTION | |||
| if run_result.outputs: | |||
| for variable_key, variable_value in run_result.outputs.items(): | |||
| # append variables to variable pool recursively | |||
| self._append_variables_recursively( | |||
| node_id=node_instance.node_id, | |||
| variable_key_list=[variable_key], | |||
| variable_value=variable_value, | |||
| ) | |||
| yield NodeRunExceptionEvent( | |||
| error=run_result.error or "System Error", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| else: | |||
| yield NodeRunFailedEvent( | |||
| error=route_node_state.failed_reason or "Unknown error.", | |||
| id=node_instance.id, | |||
| node_id=node_instance.node_id, | |||
| node_type=node_instance.node_type, | |||
| node_data=node_instance.node_data, | |||
| route_node_state=route_node_state, | |||
| parallel_id=parallel_id, | |||
| parallel_start_node_id=parallel_start_node_id, | |||
| parent_parallel_id=parent_parallel_id, | |||
| parent_parallel_start_node_id=parent_parallel_start_node_id, | |||
| ) | |||
| elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| if node_instance.should_continue_on_error and self.graph.edge_mapping.get( | |||
| node_instance.node_id | |||
| ): | |||
| run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS | |||
| if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): | |||
| # plus state total_tokens | |||
| self.graph_runtime_state.total_tokens += int( | |||
| @@ -735,6 +805,56 @@ class GraphEngine: | |||
| new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) | |||
| return new_instance | |||
| def _handle_continue_on_error( | |||
| self, | |||
| node_instance: BaseNode[BaseNodeData], | |||
| error_result: NodeRunResult, | |||
| variable_pool: VariablePool, | |||
| handle_exceptions: list[str] = [], | |||
| ) -> NodeRunResult: | |||
| """ | |||
| handle continue on error when self._should_continue_on_error is True | |||
| :param error_result (NodeRunResult): error run result | |||
| :param variable_pool (VariablePool): variable pool | |||
| :return: excption run result | |||
| """ | |||
| # add error message and error type to variable pool | |||
| variable_pool.add([node_instance.node_id, "error_message"], error_result.error) | |||
| variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) | |||
| # add error message to handle_exceptions | |||
| handle_exceptions.append(error_result.error) | |||
| node_error_args = { | |||
| "status": WorkflowNodeExecutionStatus.EXCEPTION, | |||
| "error": error_result.error, | |||
| "inputs": error_result.inputs, | |||
| "metadata": { | |||
| NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, | |||
| }, | |||
| } | |||
| if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: | |||
| return NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| **node_instance.node_data.default_value_dict, | |||
| "error_message": error_result.error, | |||
| "error_type": error_result.error_type, | |||
| }, | |||
| ) | |||
| elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: | |||
| if self.graph.edge_mapping.get(node_instance.node_id): | |||
| node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED | |||
| return NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| "error_message": error_result.error, | |||
| "error_type": error_result.error_type, | |||
| }, | |||
| ) | |||
| return error_result | |||
| class GraphRunFailedError(Exception): | |||
| def __init__(self, error: str): | |||
| @@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import ( | |||
| TextGenerateRouteChunk, | |||
| VarGenerateRouteChunk, | |||
| ) | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.enums import ErrorStrategy, NodeType | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| @@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter: | |||
| for edge in reverse_edges: | |||
| source_node_id = edge.source_node_id | |||
| source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") | |||
| if source_node_type in { | |||
| NodeType.ANSWER, | |||
| NodeType.IF_ELSE, | |||
| NodeType.QUESTION_CLASSIFIER, | |||
| NodeType.ITERATION, | |||
| NodeType.VARIABLE_ASSIGNER, | |||
| }: | |||
| source_node_data = node_id_config_mapping[source_node_id].get("data", {}) | |||
| if ( | |||
| source_node_type | |||
| in { | |||
| NodeType.ANSWER, | |||
| NodeType.IF_ELSE, | |||
| NodeType.QUESTION_CLASSIFIER, | |||
| NodeType.ITERATION, | |||
| NodeType.VARIABLE_ASSIGNER, | |||
| } | |||
| or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH | |||
| ): | |||
| answer_dependencies[answer_node_id].append(source_node_id) | |||
| else: | |||
| cls._recursive_fetch_answer_dependencies( | |||
| @@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphEngineEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| @@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor): | |||
| for _ in stream_out_answer_node_ids: | |||
| yield event | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent): | |||
| yield event | |||
| if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: | |||
| # update self.route_position after all stream event finished | |||
| @@ -1,14 +1,124 @@ | |||
| import json | |||
| from abc import ABC | |||
| from typing import Optional | |||
| from enum import StrEnum | |||
| from typing import Any, Optional, Union | |||
| from pydantic import BaseModel | |||
| from pydantic import BaseModel, model_validator | |||
| from core.workflow.nodes.base.exc import DefaultValueTypeError | |||
| from core.workflow.nodes.enums import ErrorStrategy | |||
| class DefaultValueType(StrEnum): | |||
| STRING = "string" | |||
| NUMBER = "number" | |||
| OBJECT = "object" | |||
| ARRAY_NUMBER = "array[number]" | |||
| ARRAY_STRING = "array[string]" | |||
| ARRAY_OBJECT = "array[object]" | |||
| ARRAY_FILES = "array[file]" | |||
| NumberType = Union[int, float] | |||
| class DefaultValue(BaseModel): | |||
| value: Any | |||
| type: DefaultValueType | |||
| key: str | |||
| @staticmethod | |||
| def _parse_json(value: str) -> Any: | |||
| """Unified JSON parsing handler""" | |||
| try: | |||
| return json.loads(value) | |||
| except json.JSONDecodeError: | |||
| raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") | |||
| @staticmethod | |||
| def _validate_array(value: Any, element_type: DefaultValueType) -> bool: | |||
| """Unified array type validation""" | |||
| return isinstance(value, list) and all(isinstance(x, element_type) for x in value) | |||
| @staticmethod | |||
| def _convert_number(value: str) -> float: | |||
| """Unified number conversion handler""" | |||
| try: | |||
| return float(value) | |||
| except ValueError: | |||
| raise DefaultValueTypeError(f"Cannot convert to number: {value}") | |||
| @model_validator(mode="after") | |||
| def validate_value_type(self) -> "DefaultValue": | |||
| if self.type is None: | |||
| raise DefaultValueTypeError("type field is required") | |||
| # Type validation configuration | |||
| type_validators = { | |||
| DefaultValueType.STRING: { | |||
| "type": str, | |||
| "converter": lambda x: x, | |||
| }, | |||
| DefaultValueType.NUMBER: { | |||
| "type": NumberType, | |||
| "converter": self._convert_number, | |||
| }, | |||
| DefaultValueType.OBJECT: { | |||
| "type": dict, | |||
| "converter": self._parse_json, | |||
| }, | |||
| DefaultValueType.ARRAY_NUMBER: { | |||
| "type": list, | |||
| "element_type": NumberType, | |||
| "converter": self._parse_json, | |||
| }, | |||
| DefaultValueType.ARRAY_STRING: { | |||
| "type": list, | |||
| "element_type": str, | |||
| "converter": self._parse_json, | |||
| }, | |||
| DefaultValueType.ARRAY_OBJECT: { | |||
| "type": list, | |||
| "element_type": dict, | |||
| "converter": self._parse_json, | |||
| }, | |||
| } | |||
| validator = type_validators.get(self.type) | |||
| if not validator: | |||
| if self.type == DefaultValueType.ARRAY_FILES: | |||
| # Handle files type | |||
| return self | |||
| raise DefaultValueTypeError(f"Unsupported type: {self.type}") | |||
| # Handle string input cases | |||
| if isinstance(self.value, str) and self.type != DefaultValueType.STRING: | |||
| self.value = validator["converter"](self.value) | |||
| # Validate base type | |||
| if not isinstance(self.value, validator["type"]): | |||
| raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") | |||
| # Validate array element types | |||
| if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): | |||
| raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") | |||
| return self | |||
| class BaseNodeData(ABC, BaseModel): | |||
| title: str | |||
| desc: Optional[str] = None | |||
| error_strategy: Optional[ErrorStrategy] = None | |||
| default_value: Optional[list[DefaultValue]] = None | |||
| version: str = "1" | |||
| @property | |||
| def default_value_dict(self): | |||
| if self.default_value: | |||
| return {item.key: item.value for item in self.default_value} | |||
| return {} | |||
| class BaseIterationNodeData(BaseNodeData): | |||
| start_node_id: Optional[str] = None | |||
| @@ -0,0 +1,10 @@ | |||
| class BaseNodeError(Exception): | |||
| """Base class for node errors.""" | |||
| pass | |||
| class DefaultValueTypeError(BaseNodeError): | |||
| """Raised when the default value type is invalid.""" | |||
| pass | |||
| @@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| @@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]): | |||
| result = self._run() | |||
| except Exception as e: | |||
| logger.exception(f"Node {self.node_id} failed to run") | |||
| result = NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| ) | |||
| result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError") | |||
| if isinstance(result, NodeRunResult): | |||
| yield RunCompletedEvent(run_result=result) | |||
| @@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]): | |||
| :return: | |||
| """ | |||
| return self._node_type | |||
| @property | |||
| def should_continue_on_error(self) -> bool: | |||
| """judge if should continue on error | |||
| Returns: | |||
| bool: if should continue on error | |||
| """ | |||
| return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE | |||
| @@ -61,7 +61,9 @@ class CodeNode(BaseNode[CodeNodeData]): | |||
| # Transform result | |||
| result = self._transform_result(result, self.node_data.outputs) | |||
| except (CodeExecutionError, CodeNodeError) as e: | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ | |||
| ) | |||
| return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) | |||
| @@ -22,3 +22,16 @@ class NodeType(StrEnum): | |||
| VARIABLE_ASSIGNER = "assigner" | |||
| DOCUMENT_EXTRACTOR = "document-extractor" | |||
| LIST_OPERATOR = "list-operator" | |||
| class ErrorStrategy(StrEnum): | |||
| FAIL_BRANCH = "fail-branch" | |||
| DEFAULT_VALUE = "default-value" | |||
| class FailBranchSourceHandle(StrEnum): | |||
| FAILED = "fail-branch" | |||
| SUCCESS = "success-branch" | |||
| CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] | |||
| @@ -21,6 +21,7 @@ from .entities import ( | |||
| from .exc import ( | |||
| AuthorizationConfigError, | |||
| FileFetchError, | |||
| HttpRequestNodeError, | |||
| InvalidHttpMethodError, | |||
| ResponseSizeError, | |||
| ) | |||
| @@ -208,8 +209,10 @@ class Executor: | |||
| "follow_redirects": True, | |||
| } | |||
| # request_args = {k: v for k, v in request_args.items() if v is not None} | |||
| response = getattr(ssrf_proxy, self.method)(**request_args) | |||
| try: | |||
| response = getattr(ssrf_proxy, self.method)(**request_args) | |||
| except ssrf_proxy.MaxRetriesExceededError as e: | |||
| raise HttpRequestNodeError(str(e)) | |||
| return response | |||
| def invoke(self) -> Response: | |||
| @@ -65,6 +65,21 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| response = http_executor.invoke() | |||
| files = self.extract_files(url=http_executor.url, response=response) | |||
| if not response.response.is_success and self.should_continue_on_error: | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| outputs={ | |||
| "status_code": response.status_code, | |||
| "body": response.text if not files else "", | |||
| "headers": response.headers, | |||
| "files": files, | |||
| }, | |||
| process_data={ | |||
| "request": http_executor.to_log(), | |||
| }, | |||
| error=f"Request failed with status code {response.status_code}", | |||
| error_type="HTTPResponseCodeError", | |||
| ) | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| outputs={ | |||
| @@ -83,6 +98,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| process_data=process_data, | |||
| error_type=type(e).__name__, | |||
| ) | |||
| @staticmethod | |||
| @@ -193,6 +193,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| error=str(e), | |||
| inputs=node_inputs, | |||
| process_data=process_data, | |||
| error_type=type(e).__name__, | |||
| ) | |||
| ) | |||
| return | |||
| @@ -139,7 +139,7 @@ class QuestionClassifierNode(LLMNode): | |||
| "usage": jsonable_encoder(usage), | |||
| "finish_reason": finish_reason, | |||
| } | |||
| outputs = {"class_name": category_name} | |||
| outputs = {"class_name": category_name, "class_id": category_id} | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| @@ -56,6 +56,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| NodeRunMetadataKey.TOOL_INFO: tool_info, | |||
| }, | |||
| error=f"Failed to get tool runtime: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| # get parameters | |||
| @@ -89,6 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| NodeRunMetadataKey.TOOL_INFO: tool_info, | |||
| }, | |||
| error=f"Failed to invoke tool: {str(e)}", | |||
| error_type=type(e).__name__, | |||
| ) | |||
| # convert tool messages | |||
| @@ -14,6 +14,7 @@ workflow_run_for_log_fields = { | |||
| "total_steps": fields.Integer, | |||
| "created_at": TimestampField, | |||
| "finished_at": TimestampField, | |||
| "exceptions_count": fields.Integer, | |||
| } | |||
| workflow_run_for_list_fields = { | |||
| @@ -27,6 +28,7 @@ workflow_run_for_list_fields = { | |||
| "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), | |||
| "created_at": TimestampField, | |||
| "finished_at": TimestampField, | |||
| "exceptions_count": fields.Integer, | |||
| } | |||
| advanced_chat_workflow_run_for_list_fields = { | |||
| @@ -42,6 +44,7 @@ advanced_chat_workflow_run_for_list_fields = { | |||
| "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), | |||
| "created_at": TimestampField, | |||
| "finished_at": TimestampField, | |||
| "exceptions_count": fields.Integer, | |||
| } | |||
| advanced_chat_workflow_run_pagination_fields = { | |||
| @@ -73,6 +76,7 @@ workflow_run_detail_fields = { | |||
| "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), | |||
| "created_at": TimestampField, | |||
| "finished_at": TimestampField, | |||
| "exceptions_count": fields.Integer, | |||
| } | |||
| workflow_run_node_execution_fields = { | |||
| @@ -0,0 +1,33 @@ | |||
| """add exceptions_count field to WorkflowRun model | |||
| Revision ID: cf8f4fc45278 | |||
| Revises: 01d6889832f7 | |||
| Create Date: 2024-11-28 05:53:21.576178 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'cf8f4fc45278' | |||
| down_revision = '01d6889832f7' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||
| batch_op.drop_column('exceptions_count') | |||
| # ### end Alembic commands ### | |||
| @@ -325,6 +325,7 @@ class WorkflowRunStatus(StrEnum): | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| STOPPED = "stopped" | |||
| PARTIAL_SUCCESSED = "partial-succeeded" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "WorkflowRunStatus": | |||
| @@ -395,7 +396,7 @@ class WorkflowRun(db.Model): | |||
| version = db.Column(db.String(255), nullable=False) | |||
| graph = db.Column(db.Text) | |||
| inputs = db.Column(db.Text) | |||
| status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped | |||
| status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded | |||
| outputs: Mapped[str] = mapped_column(sa.Text, default="{}") | |||
| error = db.Column(db.Text) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||
| @@ -405,6 +406,7 @@ class WorkflowRun(db.Model): | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| finished_at = db.Column(db.DateTime) | |||
| exceptions_count = db.Column(db.Integer, server_default=db.text("0")) | |||
| @property | |||
| def created_by_account(self): | |||
| @@ -464,6 +466,7 @@ class WorkflowRun(db.Model): | |||
| "created_by": self.created_by, | |||
| "created_at": self.created_at, | |||
| "finished_at": self.finished_at, | |||
| "exceptions_count": self.exceptions_count, | |||
| } | |||
| @classmethod | |||
| @@ -489,6 +492,7 @@ class WorkflowRun(db.Model): | |||
| created_by=data.get("created_by"), | |||
| created_at=data.get("created_at"), | |||
| finished_at=data.get("finished_at"), | |||
| exceptions_count=data.get("exceptions_count"), | |||
| ) | |||
| @@ -522,6 +526,7 @@ class WorkflowNodeExecutionStatus(Enum): | |||
| RUNNING = "running" | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| EXCEPTION = "exception" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": | |||
| @@ -2,7 +2,7 @@ import json | |||
| import time | |||
| from collections.abc import Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import Optional | |||
| from typing import Optional, cast | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| @@ -11,6 +11,9 @@ from core.variables import Variable | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.errors import WorkflowNodeRunFailedError | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.base.entities import BaseNodeData | |||
| from core.workflow.nodes.base.node import BaseNode | |||
| from core.workflow.nodes.enums import ErrorStrategy | |||
| from core.workflow.nodes.event import RunCompletedEvent | |||
| from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| @@ -225,7 +228,7 @@ class WorkflowService: | |||
| user_inputs=user_inputs, | |||
| user_id=account.id, | |||
| ) | |||
| node_instance = cast(BaseNode[BaseNodeData], node_instance) | |||
| node_run_result: NodeRunResult | None = None | |||
| for event in generator: | |||
| if isinstance(event, RunCompletedEvent): | |||
| @@ -237,8 +240,35 @@ class WorkflowService: | |||
| if not node_run_result: | |||
| raise ValueError("Node run failed with no run result") | |||
| run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False | |||
| # single step debug mode error handling return | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: | |||
| node_error_args = { | |||
| "status": WorkflowNodeExecutionStatus.EXCEPTION, | |||
| "error": node_run_result.error, | |||
| "inputs": node_run_result.inputs, | |||
| "metadata": {"error_strategy": node_instance.node_data.error_strategy}, | |||
| } | |||
| if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: | |||
| node_run_result = NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| **node_instance.node_data.default_value_dict, | |||
| "error_message": node_run_result.error, | |||
| "error_type": node_run_result.error_type, | |||
| }, | |||
| ) | |||
| else: | |||
| node_run_result = NodeRunResult( | |||
| **node_error_args, | |||
| outputs={ | |||
| "error_message": node_run_result.error, | |||
| "error_type": node_run_result.error_type, | |||
| }, | |||
| ) | |||
| run_succeeded = node_run_result.status in ( | |||
| WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| WorkflowNodeExecutionStatus.EXCEPTION, | |||
| ) | |||
| error = node_run_result.error if not run_succeeded else None | |||
| except WorkflowNodeRunFailedError as e: | |||
| node_instance = e.node_instance | |||
| @@ -260,7 +290,6 @@ class WorkflowService: | |||
| workflow_node_execution.created_by = account.id | |||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| if run_succeeded and node_run_result: | |||
| # create workflow node execution | |||
| inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None | |||
| @@ -277,7 +306,11 @@ class WorkflowService: | |||
| workflow_node_execution.execution_metadata = ( | |||
| json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None | |||
| ) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value | |||
| elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value | |||
| workflow_node_execution.error = node_run_result.error | |||
| else: | |||
| # create workflow node execution | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| @@ -0,0 +1,502 @@ | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| GraphRunPartialSucceededEvent, | |||
| GraphRunSucceededEvent, | |||
| NodeRunExceptionEvent, | |||
| NodeRunStreamChunkEvent, | |||
| ) | |||
| from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine | |||
| from models.enums import UserFrom | |||
| from models.workflow import WorkflowType | |||
| class ContinueOnErrorTestHelper: | |||
| @staticmethod | |||
| def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a code node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "outputs": {"result": {"type": "number"}}, | |||
| "error_strategy": error_strategy, | |||
| "title": "code", | |||
| "variables": [], | |||
| "code_language": "python3", | |||
| "code": "\n".join([line[4:] for line in code.split("\n")]), | |||
| "type": "code", | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_http_node( | |||
| error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False | |||
| ): | |||
| """Helper method to create a http node configuration""" | |||
| authorization = ( | |||
| { | |||
| "type": "api-key", | |||
| "config": { | |||
| "type": "basic", | |||
| "api_key": "ak-xxx", | |||
| "header": "api-key", | |||
| }, | |||
| } | |||
| if authorization_success | |||
| else { | |||
| "type": "api-key", | |||
| # missing config field | |||
| } | |||
| ) | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "title": "http", | |||
| "desc": "", | |||
| "method": "get", | |||
| "url": "http://example.com", | |||
| "authorization": authorization, | |||
| "headers": "X-Header:123", | |||
| "params": "A:b", | |||
| "body": None, | |||
| "type": "http-request", | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a http node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "type": "http-request", | |||
| "title": "HTTP Request", | |||
| "desc": "", | |||
| "variables": [], | |||
| "method": "get", | |||
| "url": "https://api.github.com/issues", | |||
| "authorization": {"type": "no-auth", "config": None}, | |||
| "headers": "", | |||
| "params": "", | |||
| "body": {"type": "none", "data": []}, | |||
| "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a tool node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "title": "a", | |||
| "desc": "a", | |||
| "provider_id": "maths", | |||
| "provider_type": "builtin", | |||
| "provider_name": "maths", | |||
| "tool_name": "eval_expression", | |||
| "tool_label": "eval_expression", | |||
| "tool_configurations": {}, | |||
| "tool_parameters": { | |||
| "expression": { | |||
| "type": "variable", | |||
| "value": ["1", "123", "args1"], | |||
| } | |||
| }, | |||
| "type": "tool", | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): | |||
| """Helper method to create a llm node configuration""" | |||
| node = { | |||
| "id": "node", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "llm", | |||
| "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, | |||
| "prompt_template": [ | |||
| {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, | |||
| {"role": "user", "text": "{{#sys.query#}}"}, | |||
| ], | |||
| "memory": None, | |||
| "context": {"enabled": False}, | |||
| "vision": {"enabled": False}, | |||
| "error_strategy": error_strategy, | |||
| }, | |||
| } | |||
| if default_value: | |||
| node["data"]["default_value"] = default_value | |||
| return node | |||
| @staticmethod | |||
| def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): | |||
| """Helper method to create a graph engine instance for testing""" | |||
| graph = Graph.init(graph_config=graph_config) | |||
| variable_pool = { | |||
| "system_variables": { | |||
| SystemVariableKey.QUERY: "clear", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "aaa", | |||
| }, | |||
| "user_inputs": user_inputs or {"uid": "takato"}, | |||
| } | |||
| return GraphEngine( | |||
| tenant_id="111", | |||
| app_id="222", | |||
| workflow_type=WorkflowType.CHAT, | |||
| workflow_id="333", | |||
| graph_config=graph_config, | |||
| user_id="444", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| call_depth=0, | |||
| graph=graph, | |||
| variable_pool=variable_pool, | |||
| max_execution_steps=500, | |||
| max_execution_time=1200, | |||
| ) | |||
| DEFAULT_VALUE_EDGE = [ | |||
| { | |||
| "id": "start-source-node-target", | |||
| "source": "start", | |||
| "target": "node", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-source-answer-target", | |||
| "source": "node", | |||
| "target": "answer", | |||
| "sourceHandle": "source", | |||
| }, | |||
| ] | |||
| FAIL_BRANCH_EDGES = [ | |||
| { | |||
| "id": "start-source-node-target", | |||
| "source": "start", | |||
| "target": "node", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-true-success-target", | |||
| "source": "node", | |||
| "target": "success", | |||
| "sourceHandle": "source", | |||
| }, | |||
| { | |||
| "id": "node-false-error-target", | |||
| "source": "node", | |||
| "target": "error", | |||
| "sourceHandle": "fail-branch", | |||
| }, | |||
| ] | |||
| def test_code_default_value_continue_on_error(): | |||
| error_code = """ | |||
| def main() -> dict: | |||
| return { | |||
| "result": 1 / 0, | |||
| } | |||
| """ | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_code_node( | |||
| error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_code_fail_branch_continue_on_error(): | |||
| error_code = """ | |||
| def main() -> dict: | |||
| return { | |||
| "result": 1 / 0, | |||
| } | |||
| """ | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_code_node(error_code), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events | |||
| ) | |||
| def test_http_node_default_value_continue_on_error(): | |||
| """Test HTTP node with default value error strategy""" | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_http_node( | |||
| "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} | |||
| for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_http_node_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_tool_node_default_value_continue_on_error(): | |||
| """Test tool node with default value error strategy""" | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_tool_node( | |||
| "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_tool_node_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_tool_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_llm_node_default_value_continue_on_error(): | |||
| """Test LLM node with default value error strategy""" | |||
| graph_config = { | |||
| "edges": DEFAULT_VALUE_EDGE, | |||
| "nodes": [ | |||
| {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, | |||
| {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, | |||
| ContinueOnErrorTestHelper.get_llm_node( | |||
| "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] | |||
| ), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_llm_node_fail_branch_continue_on_error(): | |||
| """Test LLM node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_llm_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_status_code_error_http_node_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_error_status_code_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any( | |||
| isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events | |||
| ) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 | |||
| def test_variable_pool_error_type_variable(): | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES, | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, | |||
| "id": "success", | |||
| }, | |||
| { | |||
| "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, | |||
| "id": "error", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_error_status_code_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| list(graph_engine.run()) | |||
| error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) | |||
| error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) | |||
| assert error_message != None | |||
| assert error_type.value == "HTTPResponseCodeError" | |||
| def test_no_node_in_fail_branch_continue_on_error(): | |||
| """Test HTTP node with fail-branch error strategy""" | |||
| graph_config = { | |||
| "edges": FAIL_BRANCH_EDGES[:-1], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, | |||
| "id": "success", | |||
| }, | |||
| ContinueOnErrorTestHelper.get_http_node(), | |||
| ], | |||
| } | |||
| graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) | |||
| events = list(graph_engine.run()) | |||
| assert any(isinstance(e, NodeRunExceptionEvent) for e in events) | |||
| assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) | |||
| assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 | |||