| @@ -20,6 +20,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationStartEvent, | |||
| QueueMessageReplaceEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeFailedEvent): | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| @@ -16,6 +16,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueNodeFailedEvent): | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| response = self._workflow_node_finish_to_stream_response( | |||
| @@ -9,6 +9,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import ( | |||
| IterationRunNextEvent, | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| NodeInIterationFailedEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunRetrieverResourceEvent, | |||
| NodeRunStartedEvent, | |||
| @@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| node_run_index=event.route_node_state.index, | |||
| predecessor_node_id=event.predecessor_node_id, | |||
| in_iteration_id=event.in_iteration_id, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunSucceededEvent): | |||
| @@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| error=event.route_node_state.node_run_result.error | |||
| if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error | |||
| else "Unknown error", | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| in_iteration_id=event.in_iteration_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeInIterationFailedEvent): | |||
| self._publish_event( | |||
| QueueNodeInIterationFailedEvent( | |||
| node_execution_id=event.id, | |||
| node_id=event.node_id, | |||
| node_type=event.node_type, | |||
| node_data=event.node_data, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| start_at=event.route_node_state.start_at, | |||
| inputs=event.route_node_state.node_run_result.inputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| process_data=event.route_node_state.node_run_result.process_data | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| outputs=event.route_node_state.node_run_result.outputs | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| execution_metadata=event.route_node_state.node_run_result.metadata | |||
| if event.route_node_state.node_run_result | |||
| else {}, | |||
| in_iteration_id=event.in_iteration_id, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| elif isinstance(event, NodeRunStreamChunkEvent): | |||
| self._publish_event( | |||
| QueueTextChunkEvent( | |||
| @@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| index=event.index, | |||
| node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, | |||
| output=event.pre_iteration_output, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| ) | |||
| ) | |||
| elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): | |||
| @@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent): | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """iteratoin run in parallel mode run id""" | |||
| node_run_index: int | |||
| output: Optional[Any] = None # output for the current iteration | |||
| @@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent): | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """iteratoin run in parallel mode run id""" | |||
| class QueueNodeSucceededEvent(AppQueueEvent): | |||
| @@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): | |||
| error: Optional[str] = None | |||
| class QueueNodeInIterationFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeInIterationFailedEvent entity | |||
| """ | |||
| event: QueueEvent = QueueEvent.NODE_FAILED | |||
| node_execution_id: str | |||
| node_id: str | |||
| node_type: NodeType | |||
| node_data: BaseNodeData | |||
| parallel_id: Optional[str] = None | |||
| """parallel id if node is in parallel""" | |||
| parallel_start_node_id: Optional[str] = None | |||
| """parallel start node id if node is in parallel""" | |||
| parent_parallel_id: Optional[str] = None | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| start_at: datetime | |||
| inputs: Optional[dict[str, Any]] = None | |||
| process_data: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None | |||
| error: str | |||
| class QueueNodeFailedEvent(AppQueueEvent): | |||
| """ | |||
| QueueNodeFailedEvent entity | |||
| @@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent): | |||
| inputs: Optional[dict[str, Any]] = None | |||
| process_data: Optional[dict[str, Any]] = None | |||
| outputs: Optional[dict[str, Any]] = None | |||
| execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None | |||
| error: str | |||
| @@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse): | |||
| parent_parallel_id: Optional[str] = None | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| iteration_id: Optional[str] = None | |||
| parallel_run_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.NODE_STARTED | |||
| workflow_run_id: str | |||
| @@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse): | |||
| extras: dict = {} | |||
| parallel_id: Optional[str] = None | |||
| parallel_start_node_id: Optional[str] = None | |||
| parallel_mode_run_id: Optional[str] = None | |||
| event: StreamEvent = StreamEvent.ITERATION_NEXT | |||
| workflow_run_id: str | |||
| @@ -12,6 +12,7 @@ from core.app.entities.queue_entities import ( | |||
| QueueIterationNextEvent, | |||
| QueueIterationStartEvent, | |||
| QueueNodeFailedEvent, | |||
| QueueNodeInIterationFailedEvent, | |||
| QueueNodeStartedEvent, | |||
| QueueNodeSucceededEvent, | |||
| QueueParallelBranchRunFailedEvent, | |||
| @@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.tool.entities import ToolNodeData | |||
| @@ -251,6 +253,12 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value | |||
| workflow_node_execution.created_by_role = workflow_run.created_by_role | |||
| workflow_node_execution.created_by = workflow_run.created_by | |||
| workflow_node_execution.execution_metadata = json.dumps( | |||
| { | |||
| NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, | |||
| NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, | |||
| } | |||
| ) | |||
| workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| session.add(workflow_node_execution) | |||
| @@ -305,7 +313,9 @@ class WorkflowCycleManage: | |||
| return workflow_node_execution | |||
| def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: | |||
| def _handle_workflow_node_execution_failed( | |||
| self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | |||
| ) -> WorkflowNodeExecution: | |||
| """ | |||
| Workflow node execution failed | |||
| :param event: queue node failed event | |||
| @@ -318,16 +328,19 @@ class WorkflowCycleManage: | |||
| outputs = WorkflowEntry.handle_special_values(event.outputs) | |||
| finished_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| elapsed_time = (finished_at - event.start_at).total_seconds() | |||
| execution_metadata = ( | |||
| json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None | |||
| ) | |||
| db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( | |||
| { | |||
| WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, | |||
| WorkflowNodeExecution.error: event.error, | |||
| WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, | |||
| WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, | |||
| WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, | |||
| WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, | |||
| WorkflowNodeExecution.finished_at: finished_at, | |||
| WorkflowNodeExecution.elapsed_time: elapsed_time, | |||
| WorkflowNodeExecution.execution_metadata: execution_metadata, | |||
| } | |||
| ) | |||
| @@ -342,6 +355,7 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.outputs = json.dumps(outputs) if outputs else None | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution.execution_metadata = execution_metadata | |||
| self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) | |||
| @@ -448,6 +462,7 @@ class WorkflowCycleManage: | |||
| parent_parallel_id=event.parent_parallel_id, | |||
| parent_parallel_start_node_id=event.parent_parallel_start_node_id, | |||
| iteration_id=event.in_iteration_id, | |||
| parallel_run_id=event.parallel_mode_run_id, | |||
| ), | |||
| ) | |||
| @@ -464,7 +479,7 @@ class WorkflowCycleManage: | |||
| def _workflow_node_finish_to_stream_response( | |||
| self, | |||
| event: QueueNodeSucceededEvent | QueueNodeFailedEvent, | |||
| event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, | |||
| task_id: str, | |||
| workflow_node_execution: WorkflowNodeExecution, | |||
| ) -> Optional[NodeFinishStreamResponse]: | |||
| @@ -608,6 +623,7 @@ class WorkflowCycleManage: | |||
| extras={}, | |||
| parallel_id=event.parallel_id, | |||
| parallel_start_node_id=event.parallel_start_node_id, | |||
| parallel_mode_run_id=event.parallel_mode_run_id, | |||
| ), | |||
| ) | |||
| @@ -633,7 +649,9 @@ class WorkflowCycleManage: | |||
| created_at=int(time.time()), | |||
| extras={}, | |||
| inputs=event.inputs or {}, | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED | |||
| if event.error is None | |||
| else WorkflowNodeExecutionStatus.FAILED, | |||
| error=None, | |||
| elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), | |||
| total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, | |||
| @@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum): | |||
| PARALLEL_START_NODE_ID = "parallel_start_node_id" | |||
| PARENT_PARALLEL_ID = "parent_parallel_id" | |||
| PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" | |||
| PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" | |||
| class NodeRunResult(BaseModel): | |||
| @@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent): | |||
| class NodeRunStartedEvent(BaseNodeEvent): | |||
| predecessor_node_id: Optional[str] = None | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """predecessor node id""" | |||
| @@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| class NodeInIterationFailedEvent(BaseNodeEvent): | |||
| error: str = Field(..., description="error") | |||
| ########################################### | |||
| # Parallel Branch Events | |||
| ########################################### | |||
| @@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent): | |||
| """parent parallel id if node is in parallel""" | |||
| parent_parallel_start_node_id: Optional[str] = None | |||
| """parent parallel start node id if node is in parallel""" | |||
| parallel_mode_run_id: Optional[str] = None | |||
| """iteratoin run in parallel mode run id""" | |||
| class IterationRunStartedEvent(BaseIterationEvent): | |||
| @@ -4,6 +4,7 @@ import time | |||
| import uuid | |||
| from collections.abc import Generator, Mapping | |||
| from concurrent.futures import ThreadPoolExecutor, wait | |||
| from copy import copy, deepcopy | |||
| from typing import Any, Optional | |||
| from flask import Flask, current_app | |||
| @@ -724,6 +725,16 @@ class GraphEngine: | |||
| """ | |||
| return time.perf_counter() - start_at > max_execution_time | |||
| def create_copy(self): | |||
| """ | |||
| create a graph engine copy | |||
| :return: with a new variable pool instance of graph engine | |||
| """ | |||
| new_instance = copy(self) | |||
| new_instance.graph_runtime_state = copy(self.graph_runtime_state) | |||
| new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) | |||
| return new_instance | |||
| class GraphRunFailedError(Exception): | |||
| def __init__(self, error: str): | |||
| @@ -1,3 +1,4 @@ | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from pydantic import Field | |||
| @@ -5,6 +6,12 @@ from pydantic import Field | |||
| from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData | |||
| class ErrorHandleMode(str, Enum): | |||
| TERMINATED = "terminated" | |||
| CONTINUE_ON_ERROR = "continue-on-error" | |||
| REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" | |||
| class IterationNodeData(BaseIterationNodeData): | |||
| """ | |||
| Iteration Node Data. | |||
| @@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData): | |||
| parent_loop_id: Optional[str] = None # redundant field, not used currently | |||
| iterator_selector: list[str] # variable selector | |||
| output_selector: list[str] # output selector | |||
| is_parallel: bool = False # open the parallel mode or not | |||
| parallel_nums: int = 10 # the numbers of parallel | |||
| error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error | |||
| class IterationStartNodeData(BaseNodeData): | |||
| @@ -1,12 +1,20 @@ | |||
| import logging | |||
| import uuid | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from concurrent.futures import Future, wait | |||
| from datetime import datetime, timezone | |||
| from typing import Any, cast | |||
| from queue import Empty, Queue | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| from flask import Flask, current_app | |||
| from configs import dify_config | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.variables import IntegerSegment | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult | |||
| from core.workflow.entities.node_entities import ( | |||
| NodeRunMetadataKey, | |||
| NodeRunResult, | |||
| ) | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine.entities.event import ( | |||
| BaseGraphEvent, | |||
| BaseNodeEvent, | |||
| @@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import ( | |||
| IterationRunNextEvent, | |||
| IterationRunStartedEvent, | |||
| IterationRunSucceededEvent, | |||
| NodeInIterationFailedEvent, | |||
| NodeRunFailedEvent, | |||
| NodeRunStartedEvent, | |||
| NodeRunStreamChunkEvent, | |||
| NodeRunSucceededEvent, | |||
| ) | |||
| @@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event import NodeEvent, RunCompletedEvent | |||
| from core.workflow.nodes.iteration.entities import IterationNodeData | |||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| if TYPE_CHECKING: | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine | |||
| logger = logging.getLogger(__name__) | |||
| @@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| _node_data_cls = IterationNodeData | |||
| _node_type = NodeType.ITERATION | |||
| @classmethod | |||
| def get_default_config(cls, filters: Optional[dict] = None) -> dict: | |||
| return { | |||
| "type": "iteration", | |||
| "config": { | |||
| "is_parallel": False, | |||
| "parallel_nums": 10, | |||
| "error_handle_mode": ErrorHandleMode.TERMINATED.value, | |||
| }, | |||
| } | |||
| def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: | |||
| """ | |||
| Run the node. | |||
| @@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| variable_pool.add([self.node_id, "item"], iterator_list_value[0]) | |||
| # init graph engine | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine | |||
| from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool | |||
| graph_engine = GraphEngine( | |||
| tenant_id=self.tenant_id, | |||
| @@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| index=0, | |||
| pre_iteration_output=None, | |||
| ) | |||
| outputs: list[Any] = [] | |||
| try: | |||
| for _ in range(len(iterator_list_value)): | |||
| # run workflow | |||
| rst = graph_engine.run() | |||
| for event in rst: | |||
| if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: | |||
| event.in_iteration_id = self.node_id | |||
| if ( | |||
| isinstance(event, BaseNodeEvent) | |||
| and event.node_type == NodeType.ITERATION_START | |||
| and not isinstance(event, NodeRunStreamChunkEvent) | |||
| ): | |||
| continue | |||
| if isinstance(event, NodeRunSucceededEvent): | |||
| if event.route_node_state.node_run_result: | |||
| metadata = event.route_node_state.node_run_result.metadata | |||
| if not metadata: | |||
| metadata = {} | |||
| if NodeRunMetadataKey.ITERATION_ID not in metadata: | |||
| metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id | |||
| index_variable = variable_pool.get([self.node_id, "index"]) | |||
| if not isinstance(index_variable, IntegerSegment): | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=f"Invalid index variable type: {type(index_variable)}", | |||
| ) | |||
| ) | |||
| return | |||
| metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value | |||
| event.route_node_state.node_run_result.metadata = metadata | |||
| yield event | |||
| elif isinstance(event, BaseGraphEvent): | |||
| if isinstance(event, GraphRunFailedEvent): | |||
| # iteration run failed | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": jsonable_encoder(outputs)}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| ) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| return | |||
| else: | |||
| event = cast(InNodeEvent, event) | |||
| if self.node_data.is_parallel: | |||
| futures: list[Future] = [] | |||
| q = Queue() | |||
| thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) | |||
| for index, item in enumerate(iterator_list_value): | |||
| future: Future = thread_pool.submit( | |||
| self._run_single_iter_parallel, | |||
| current_app._get_current_object(), | |||
| q, | |||
| iterator_list_value, | |||
| inputs, | |||
| outputs, | |||
| start_at, | |||
| graph_engine, | |||
| iteration_graph, | |||
| index, | |||
| item, | |||
| ) | |||
| future.add_done_callback(thread_pool.task_done_callback) | |||
| futures.append(future) | |||
| succeeded_count = 0 | |||
| while True: | |||
| try: | |||
| event = q.get(timeout=1) | |||
| if event is None: | |||
| break | |||
| if isinstance(event, IterationRunNextEvent): | |||
| succeeded_count += 1 | |||
| if succeeded_count == len(futures): | |||
| q.put(None) | |||
| yield event | |||
| if isinstance(event, RunCompletedEvent): | |||
| q.put(None) | |||
| for f in futures: | |||
| if not f.done(): | |||
| f.cancel() | |||
| yield event | |||
| if isinstance(event, IterationRunFailedEvent): | |||
| q.put(None) | |||
| yield event | |||
| except Empty: | |||
| continue | |||
| # append to iteration output variable list | |||
| current_iteration_output_variable = variable_pool.get(self.node_data.output_selector) | |||
| if current_iteration_output_variable is None: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=f"Iteration output variable {self.node_data.output_selector} not found", | |||
| ) | |||
| # wait all threads | |||
| wait(futures) | |||
| else: | |||
| for _ in range(len(iterator_list_value)): | |||
| yield from self._run_single_iter( | |||
| iterator_list_value, | |||
| variable_pool, | |||
| inputs, | |||
| outputs, | |||
| start_at, | |||
| graph_engine, | |||
| iteration_graph, | |||
| ) | |||
| return | |||
| current_iteration_output = current_iteration_output_variable.to_object() | |||
| outputs.append(current_iteration_output) | |||
| # remove all nodes outputs from variable pool | |||
| for node_id in iteration_graph.node_ids: | |||
| variable_pool.remove([node_id]) | |||
| # move to next iteration | |||
| current_index_variable = variable_pool.get([self.node_id, "index"]) | |||
| if not isinstance(current_index_variable, IntegerSegment): | |||
| raise ValueError(f"iteration {self.node_id} current index not found") | |||
| next_index = current_index_variable.value + 1 | |||
| variable_pool.add([self.node_id, "index"], next_index) | |||
| if next_index < len(iterator_list_value): | |||
| variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) | |||
| yield IterationRunNextEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| index=next_index, | |||
| pre_iteration_output=jsonable_encoder(current_iteration_output), | |||
| ) | |||
| yield IterationRunSucceededEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| @@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]): | |||
| } | |||
| return variable_mapping | |||
| def _handle_event_metadata( | |||
| self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str | |||
| ) -> NodeRunStartedEvent | BaseNodeEvent: | |||
| """ | |||
| add iteration metadata to event. | |||
| """ | |||
| if not isinstance(event, BaseNodeEvent): | |||
| return event | |||
| if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): | |||
| event.parallel_mode_run_id = parallel_mode_run_id | |||
| return event | |||
| if event.route_node_state.node_run_result: | |||
| metadata = event.route_node_state.node_run_result.metadata | |||
| if not metadata: | |||
| metadata = {} | |||
| if NodeRunMetadataKey.ITERATION_ID not in metadata: | |||
| metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id | |||
| if self.node_data.is_parallel: | |||
| metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id | |||
| else: | |||
| metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index | |||
| event.route_node_state.node_run_result.metadata = metadata | |||
| return event | |||
| def _run_single_iter( | |||
| self, | |||
| iterator_list_value: list[str], | |||
| variable_pool: VariablePool, | |||
| inputs: dict[str, list], | |||
| outputs: list, | |||
| start_at: datetime, | |||
| graph_engine: "GraphEngine", | |||
| iteration_graph: Graph, | |||
| parallel_mode_run_id: Optional[str] = None, | |||
| ) -> Generator[NodeEvent | InNodeEvent, None, None]: | |||
| """ | |||
| run single iteration | |||
| """ | |||
| try: | |||
| rst = graph_engine.run() | |||
| # get current iteration index | |||
| current_index = variable_pool.get([self.node_id, "index"]).value | |||
| next_index = int(current_index) + 1 | |||
| if current_index is None: | |||
| raise ValueError(f"iteration {self.node_id} current index not found") | |||
| for event in rst: | |||
| if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: | |||
| event.in_iteration_id = self.node_id | |||
| if ( | |||
| isinstance(event, BaseNodeEvent) | |||
| and event.node_type == NodeType.ITERATION_START | |||
| and not isinstance(event, NodeRunStreamChunkEvent) | |||
| ): | |||
| continue | |||
| if isinstance(event, NodeRunSucceededEvent): | |||
| yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) | |||
| elif isinstance(event, BaseGraphEvent): | |||
| if isinstance(event, GraphRunFailedEvent): | |||
| # iteration run failed | |||
| if self.node_data.is_parallel: | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| parallel_mode_run_id=parallel_mode_run_id, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": jsonable_encoder(outputs)}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| ) | |||
| else: | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": jsonable_encoder(outputs)}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| ) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=event.error, | |||
| ) | |||
| ) | |||
| return | |||
| else: | |||
| event = cast(InNodeEvent, event) | |||
| metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) | |||
| if isinstance(event, NodeRunFailedEvent): | |||
| if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: | |||
| yield NodeInIterationFailedEvent( | |||
| **metadata_event.model_dump(), | |||
| ) | |||
| outputs.insert(current_index, None) | |||
| variable_pool.add([self.node_id, "index"], next_index) | |||
| if next_index < len(iterator_list_value): | |||
| variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) | |||
| yield IterationRunNextEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| index=next_index, | |||
| parallel_mode_run_id=parallel_mode_run_id, | |||
| pre_iteration_output=None, | |||
| ) | |||
| return | |||
| elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: | |||
| yield NodeInIterationFailedEvent( | |||
| **metadata_event.model_dump(), | |||
| ) | |||
| variable_pool.add([self.node_id, "index"], next_index) | |||
| if next_index < len(iterator_list_value): | |||
| variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) | |||
| yield IterationRunNextEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| index=next_index, | |||
| parallel_mode_run_id=parallel_mode_run_id, | |||
| pre_iteration_output=None, | |||
| ) | |||
| return | |||
| elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": None}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=event.error, | |||
| ) | |||
| yield metadata_event | |||
| current_iteration_output = variable_pool.get(self.node_data.output_selector).value | |||
| outputs.insert(current_index, current_iteration_output) | |||
| # remove all nodes outputs from variable pool | |||
| for node_id in iteration_graph.node_ids: | |||
| variable_pool.remove([node_id]) | |||
| # move to next iteration | |||
| variable_pool.add([self.node_id, "index"], next_index) | |||
| if next_index < len(iterator_list_value): | |||
| variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) | |||
| yield IterationRunNextEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| index=next_index, | |||
| parallel_mode_run_id=parallel_mode_run_id, | |||
| pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, | |||
| ) | |||
| except Exception as e: | |||
| logger.exception(f"Iteration run failed:{str(e)}") | |||
| yield IterationRunFailedEvent( | |||
| iteration_id=self.id, | |||
| iteration_node_id=self.node_id, | |||
| iteration_node_type=self.node_type, | |||
| iteration_node_data=self.node_data, | |||
| start_at=start_at, | |||
| inputs=inputs, | |||
| outputs={"output": None}, | |||
| steps=len(iterator_list_value), | |||
| metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, | |||
| error=str(e), | |||
| ) | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| error=str(e), | |||
| ) | |||
| ) | |||
| def _run_single_iter_parallel( | |||
| self, | |||
| flask_app: Flask, | |||
| q: Queue, | |||
| iterator_list_value: list[str], | |||
| inputs: dict[str, list], | |||
| outputs: list, | |||
| start_at: datetime, | |||
| graph_engine: "GraphEngine", | |||
| iteration_graph: Graph, | |||
| index: int, | |||
| item: Any, | |||
| ) -> Generator[NodeEvent | InNodeEvent, None, None]: | |||
| """ | |||
| run single iteration in parallel mode | |||
| """ | |||
| with flask_app.app_context(): | |||
| parallel_mode_run_id = uuid.uuid4().hex | |||
| graph_engine_copy = graph_engine.create_copy() | |||
| variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool | |||
| variable_pool_copy.add([self.node_id, "index"], index) | |||
| variable_pool_copy.add([self.node_id, "item"], item) | |||
| for event in self._run_single_iter( | |||
| iterator_list_value=iterator_list_value, | |||
| variable_pool=variable_pool_copy, | |||
| inputs=inputs, | |||
| outputs=outputs, | |||
| start_at=start_at, | |||
| graph_engine=graph_engine_copy, | |||
| iteration_graph=iteration_graph, | |||
| parallel_mode_run_id=parallel_mode_run_id, | |||
| ): | |||
| q.put(event) | |||
| @@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.graph import Graph | |||
| from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.nodes.event import RunCompletedEvent | |||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode | |||
| from core.workflow.nodes.iteration.iteration_node import IterationNode | |||
| from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode | |||
| from models.enums import UserFrom | |||
| @@ -185,8 +186,6 @@ def test_run(): | |||
| outputs={"output": "dify 123"}, | |||
| ) | |||
| # print("") | |||
| with patch.object(TemplateTransformNode, "_run", new=tt_generator): | |||
| # execute node | |||
| result = iteration_node._run() | |||
| @@ -404,18 +403,458 @@ def test_run_parallel(): | |||
| outputs={"output": "dify 123"}, | |||
| ) | |||
| # print("") | |||
| with patch.object(TemplateTransformNode, "_run", new=tt_generator): | |||
| # execute node | |||
| result = iteration_node._run() | |||
| count = 0 | |||
| for item in result: | |||
| # print(type(item), item) | |||
| count += 1 | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} | |||
| assert count == 32 | |||
| def test_iteration_run_in_parallel_mode(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-pe-target", | |||
| "source": "start", | |||
| "target": "pe", | |||
| }, | |||
| { | |||
| "id": "iteration-1-source-answer-3-target", | |||
| "source": "iteration-1", | |||
| "target": "answer-3", | |||
| }, | |||
| { | |||
| "id": "iteration-start-source-tt-target", | |||
| "source": "iteration-start", | |||
| "target": "tt", | |||
| }, | |||
| { | |||
| "id": "iteration-start-source-tt-2-target", | |||
| "source": "iteration-start", | |||
| "target": "tt-2", | |||
| }, | |||
| { | |||
| "id": "tt-source-if-else-target", | |||
| "source": "tt", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "tt-2-source-if-else-target", | |||
| "source": "tt-2", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "if-else-true-answer-2-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "true", | |||
| "target": "answer-2", | |||
| }, | |||
| { | |||
| "id": "if-else-false-answer-4-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "false", | |||
| "target": "answer-4", | |||
| }, | |||
| { | |||
| "id": "pe-source-iteration-1-target", | |||
| "source": "pe", | |||
| "target": "iteration-1", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "answer": "{{#tt.output#}}", | |||
| "iteration_id": "iteration-1", | |||
| "title": "answer 2", | |||
| "type": "answer", | |||
| }, | |||
| "id": "answer-2", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "title": "iteration-start", | |||
| "type": "iteration-start", | |||
| }, | |||
| "id": "iteration-start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 123", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }} 321", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], | |||
| }, | |||
| "id": "tt-2", | |||
| }, | |||
| { | |||
| "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, | |||
| "id": "answer-3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "is", | |||
| "id": "1721916275284", | |||
| "value": "hi", | |||
| "variable_selector": ["sys", "query"], | |||
| } | |||
| ], | |||
| "iteration_id": "iteration-1", | |||
| "logical_operator": "and", | |||
| "title": "if", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else", | |||
| }, | |||
| { | |||
| "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, | |||
| "id": "answer-4", | |||
| }, | |||
| { | |||
| "data": { | |||
| "instruction": "test1", | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "parameters": [ | |||
| {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} | |||
| ], | |||
| "query": ["sys", "query"], | |||
| "reasoning_mode": "prompt", | |||
| "title": "pe", | |||
| "type": "parameter-extractor", | |||
| }, | |||
| "id": "pe", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.CHAT, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "dify", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "1", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) | |||
| parallel_iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), | |||
| config={ | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "迭代", | |||
| "type": "iteration", | |||
| "is_parallel": True, | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| ) | |||
| sequential_iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), | |||
| config={ | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "迭代", | |||
| "type": "iteration", | |||
| "is_parallel": True, | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| ) | |||
| def tt_generator(self): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs={"iterator_selector": "dify"}, | |||
| outputs={"output": "dify 123"}, | |||
| ) | |||
| with patch.object(TemplateTransformNode, "_run", new=tt_generator): | |||
| # execute node | |||
| parallel_result = parallel_iteration_node._run() | |||
| sequential_result = sequential_iteration_node._run() | |||
| assert parallel_iteration_node.node_data.parallel_nums == 10 | |||
| assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED | |||
| count = 0 | |||
| parallel_arr = [] | |||
| sequential_arr = [] | |||
| for item in parallel_result: | |||
| count += 1 | |||
| parallel_arr.append(item) | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} | |||
| assert count == 32 | |||
| for item in sequential_result: | |||
| sequential_arr.append(item) | |||
| count += 1 | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} | |||
| assert count == 64 | |||
| def test_iteration_run_error_handle(): | |||
| graph_config = { | |||
| "edges": [ | |||
| { | |||
| "id": "start-source-pe-target", | |||
| "source": "start", | |||
| "target": "pe", | |||
| }, | |||
| { | |||
| "id": "iteration-1-source-answer-3-target", | |||
| "source": "iteration-1", | |||
| "target": "answer-3", | |||
| }, | |||
| { | |||
| "id": "tt-source-if-else-target", | |||
| "source": "iteration-start", | |||
| "target": "if-else", | |||
| }, | |||
| { | |||
| "id": "if-else-true-answer-2-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "true", | |||
| "target": "tt", | |||
| }, | |||
| { | |||
| "id": "if-else-false-answer-4-target", | |||
| "source": "if-else", | |||
| "sourceHandle": "false", | |||
| "target": "tt2", | |||
| }, | |||
| { | |||
| "id": "pe-source-iteration-1-target", | |||
| "source": "pe", | |||
| "target": "iteration-1", | |||
| }, | |||
| ], | |||
| "nodes": [ | |||
| {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, | |||
| { | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt2", "output"], | |||
| "output_type": "array[string]", | |||
| "start_node_id": "if-else", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1.split(arg2) }}", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [ | |||
| {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, | |||
| {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, | |||
| ], | |||
| }, | |||
| "id": "tt", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "template": "{{ arg1 }}", | |||
| "title": "template transform", | |||
| "type": "template-transform", | |||
| "variables": [ | |||
| {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, | |||
| ], | |||
| }, | |||
| "id": "tt2", | |||
| }, | |||
| { | |||
| "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, | |||
| "id": "answer-3", | |||
| }, | |||
| { | |||
| "data": { | |||
| "iteration_id": "iteration-1", | |||
| "title": "iteration-start", | |||
| "type": "iteration-start", | |||
| }, | |||
| "id": "iteration-start", | |||
| }, | |||
| { | |||
| "data": { | |||
| "conditions": [ | |||
| { | |||
| "comparison_operator": "is", | |||
| "id": "1721916275284", | |||
| "value": "1", | |||
| "variable_selector": ["iteration-1", "item"], | |||
| } | |||
| ], | |||
| "iteration_id": "iteration-1", | |||
| "logical_operator": "and", | |||
| "title": "if", | |||
| "type": "if-else", | |||
| }, | |||
| "id": "if-else", | |||
| }, | |||
| { | |||
| "data": { | |||
| "instruction": "test1", | |||
| "model": { | |||
| "completion_params": {"temperature": 0.7}, | |||
| "mode": "chat", | |||
| "name": "gpt-4o", | |||
| "provider": "openai", | |||
| }, | |||
| "parameters": [ | |||
| {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} | |||
| ], | |||
| "query": ["sys", "query"], | |||
| "reasoning_mode": "prompt", | |||
| "title": "pe", | |||
| "type": "parameter-extractor", | |||
| }, | |||
| "id": "pe", | |||
| }, | |||
| ], | |||
| } | |||
| graph = Graph.init(graph_config=graph_config) | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| workflow_type=WorkflowType.CHAT, | |||
| workflow_id="1", | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| ) | |||
| # construct variable pool | |||
| pool = VariablePool( | |||
| system_variables={ | |||
| SystemVariableKey.QUERY: "dify", | |||
| SystemVariableKey.FILES: [], | |||
| SystemVariableKey.CONVERSATION_ID: "abababa", | |||
| SystemVariableKey.USER_ID: "1", | |||
| }, | |||
| user_inputs={}, | |||
| environment_variables=[], | |||
| ) | |||
| pool.add(["pe", "list_output"], ["1", "1"]) | |||
| iteration_node = IterationNode( | |||
| id=str(uuid.uuid4()), | |||
| graph_init_params=init_params, | |||
| graph=graph, | |||
| graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), | |||
| config={ | |||
| "data": { | |||
| "iterator_selector": ["pe", "list_output"], | |||
| "output_selector": ["tt", "output"], | |||
| "output_type": "array[string]", | |||
| "startNodeType": "template-transform", | |||
| "start_node_id": "iteration-start", | |||
| "title": "iteration", | |||
| "type": "iteration", | |||
| "is_parallel": True, | |||
| "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, | |||
| }, | |||
| "id": "iteration-1", | |||
| }, | |||
| ) | |||
| # execute continue on error node | |||
| result = iteration_node._run() | |||
| result_arr = [] | |||
| count = 0 | |||
| for item in result: | |||
| result_arr.append(item) | |||
| count += 1 | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.outputs == {"output": [None, None]} | |||
| assert count == 14 | |||
| # execute remove abnormal output | |||
| iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT | |||
| result = iteration_node._run() | |||
| count = 0 | |||
| for item in result: | |||
| count += 1 | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.outputs == {"output": []} | |||
| assert count == 14 | |||
| @@ -125,7 +125,7 @@ const Select: FC<ISelectProps> = ({ | |||
| </Combobox.Button> | |||
| </div> | |||
| {filteredItems.length > 0 && ( | |||
| {(filteredItems.length > 0 && open) && ( | |||
| <Combobox.Options className={`absolute z-10 mt-1 px-1 max-h-60 w-full overflow-auto rounded-md bg-white py-1 text-base shadow-lg border-gray-200 border-[0.5px] focus:outline-none sm:text-sm ${overlayClassName}`}> | |||
| {filteredItems.map((item: Item) => ( | |||
| <Combobox.Option | |||
| @@ -340,7 +340,9 @@ export const NODES_INITIAL_DATA = { | |||
| ...ListFilterDefault.defaultValue, | |||
| }, | |||
| } | |||
| export const MAX_ITERATION_PARALLEL_NUM = 10 | |||
| export const MIN_ITERATION_PARALLEL_NUM = 1 | |||
| export const DEFAULT_ITER_TIMES = 1 | |||
| export const NODE_WIDTH = 240 | |||
| export const X_OFFSET = 60 | |||
| export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET | |||
| @@ -644,6 +644,11 @@ export const useNodesInteractions = () => { | |||
| newNode.data.isInIteration = true | |||
| newNode.data.iteration_id = prevNode.parentId | |||
| newNode.zIndex = ITERATION_CHILDREN_Z_INDEX | |||
| if (newNode.data.type === BlockEnum.Answer || newNode.data.type === BlockEnum.Tool || newNode.data.type === BlockEnum.Assigner) { | |||
| const parentIterNodeIndex = nodes.findIndex(node => node.id === prevNode.parentId) | |||
| const iterNodeData: IterationNodeType = nodes[parentIterNodeIndex].data | |||
| iterNodeData._isShowTips = true | |||
| } | |||
| } | |||
| const newEdge: Edge = { | |||
| @@ -14,6 +14,7 @@ import { | |||
| NodeRunningStatus, | |||
| WorkflowRunningStatus, | |||
| } from '../types' | |||
| import { DEFAULT_ITER_TIMES } from '../constants' | |||
| import { useWorkflowUpdate } from './use-workflow-interactions' | |||
| import { useStore as useAppStore } from '@/app/components/app/store' | |||
| import type { IOtherOptions } from '@/service/base' | |||
| @@ -170,11 +171,13 @@ export const useWorkflowRun = () => { | |||
| const { | |||
| workflowRunningData, | |||
| setWorkflowRunningData, | |||
| setIterParallelLogMap, | |||
| } = workflowStore.getState() | |||
| const { | |||
| edges, | |||
| setEdges, | |||
| } = store.getState() | |||
| setIterParallelLogMap(new Map()) | |||
| setWorkflowRunningData(produce(workflowRunningData!, (draft) => { | |||
| draft.task_id = task_id | |||
| draft.result = { | |||
| @@ -244,6 +247,8 @@ export const useWorkflowRun = () => { | |||
| const { | |||
| workflowRunningData, | |||
| setWorkflowRunningData, | |||
| iterParallelLogMap, | |||
| setIterParallelLogMap, | |||
| } = workflowStore.getState() | |||
| const { | |||
| getNodes, | |||
| @@ -259,10 +264,21 @@ export const useWorkflowRun = () => { | |||
| const tracing = draft.tracing! | |||
| const iterations = tracing.find(trace => trace.node_id === node?.parentId) | |||
| const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] | |||
| currIteration?.push({ | |||
| ...data, | |||
| status: NodeRunningStatus.Running, | |||
| } as any) | |||
| if (!data.parallel_run_id) { | |||
| currIteration?.push({ | |||
| ...data, | |||
| status: NodeRunningStatus.Running, | |||
| } as any) | |||
| } | |||
| else { | |||
| if (!iterParallelLogMap.has(data.parallel_run_id)) | |||
| iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) | |||
| else | |||
| iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) | |||
| setIterParallelLogMap(iterParallelLogMap) | |||
| if (iterations) | |||
| iterations.details = Array.from(iterParallelLogMap.values()) | |||
| } | |||
| })) | |||
| } | |||
| else { | |||
| @@ -309,6 +325,8 @@ export const useWorkflowRun = () => { | |||
| const { | |||
| workflowRunningData, | |||
| setWorkflowRunningData, | |||
| iterParallelLogMap, | |||
| setIterParallelLogMap, | |||
| } = workflowStore.getState() | |||
| const { | |||
| getNodes, | |||
| @@ -317,21 +335,21 @@ export const useWorkflowRun = () => { | |||
| const nodes = getNodes() | |||
| const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId | |||
| if (nodeParentId) { | |||
| setWorkflowRunningData(produce(workflowRunningData!, (draft) => { | |||
| const tracing = draft.tracing! | |||
| const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node | |||
| if (iterations && iterations.details) { | |||
| const iterationIndex = data.execution_metadata?.iteration_index || 0 | |||
| if (!iterations.details[iterationIndex]) | |||
| iterations.details[iterationIndex] = [] | |||
| const currIteration = iterations.details[iterationIndex] | |||
| const nodeIndex = currIteration.findIndex(node => | |||
| node.node_id === data.node_id && ( | |||
| node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), | |||
| ) | |||
| if (data.status === NodeRunningStatus.Succeeded) { | |||
| if (!data.execution_metadata.parallel_mode_run_id) { | |||
| setWorkflowRunningData(produce(workflowRunningData!, (draft) => { | |||
| const tracing = draft.tracing! | |||
| const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node | |||
| if (iterations && iterations.details) { | |||
| const iterationIndex = data.execution_metadata?.iteration_index || 0 | |||
| if (!iterations.details[iterationIndex]) | |||
| iterations.details[iterationIndex] = [] | |||
| const currIteration = iterations.details[iterationIndex] | |||
| const nodeIndex = currIteration.findIndex(node => | |||
| node.node_id === data.node_id && ( | |||
| node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), | |||
| ) | |||
| if (nodeIndex !== -1) { | |||
| currIteration[nodeIndex] = { | |||
| ...currIteration[nodeIndex], | |||
| @@ -344,8 +362,40 @@ export const useWorkflowRun = () => { | |||
| } as any) | |||
| } | |||
| } | |||
| } | |||
| })) | |||
| })) | |||
| } | |||
| else { | |||
| // open parallel mode | |||
| setWorkflowRunningData(produce(workflowRunningData!, (draft) => { | |||
| const tracing = draft.tracing! | |||
| const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node | |||
| if (iterations && iterations.details) { | |||
| const iterRunID = data.execution_metadata?.parallel_mode_run_id | |||
| const currIteration = iterParallelLogMap.get(iterRunID) | |||
| const nodeIndex = currIteration?.findIndex(node => | |||
| node.node_id === data.node_id && ( | |||
| node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id), | |||
| ) | |||
| if (currIteration) { | |||
| if (nodeIndex !== undefined && nodeIndex !== -1) { | |||
| currIteration[nodeIndex] = { | |||
| ...currIteration[nodeIndex], | |||
| ...data, | |||
| } as any | |||
| } | |||
| else { | |||
| currIteration.push({ | |||
| ...data, | |||
| } as any) | |||
| } | |||
| } | |||
| setIterParallelLogMap(iterParallelLogMap) | |||
| iterations.details = Array.from(iterParallelLogMap.values()) | |||
| } | |||
| })) | |||
| } | |||
| } | |||
| else { | |||
| setWorkflowRunningData(produce(workflowRunningData!, (draft) => { | |||
| @@ -379,6 +429,7 @@ export const useWorkflowRun = () => { | |||
| const { | |||
| workflowRunningData, | |||
| setWorkflowRunningData, | |||
| setIterTimes, | |||
| } = workflowStore.getState() | |||
| const { | |||
| getNodes, | |||
| @@ -388,6 +439,7 @@ export const useWorkflowRun = () => { | |||
| transform, | |||
| } = store.getState() | |||
| const nodes = getNodes() | |||
| setIterTimes(DEFAULT_ITER_TIMES) | |||
| setWorkflowRunningData(produce(workflowRunningData!, (draft) => { | |||
| draft.tracing!.push({ | |||
| ...data, | |||
| @@ -431,6 +483,8 @@ export const useWorkflowRun = () => { | |||
| const { | |||
| workflowRunningData, | |||
| setWorkflowRunningData, | |||
| iterTimes, | |||
| setIterTimes, | |||
| } = workflowStore.getState() | |||
| const { data } = params | |||
| @@ -445,13 +499,14 @@ export const useWorkflowRun = () => { | |||
| if (iteration.details!.length >= iteration.metadata.iterator_length!) | |||
| return | |||
| } | |||
| iteration?.details!.push([]) | |||
| if (!data.parallel_mode_run_id) | |||
| iteration?.details!.push([]) | |||
| })) | |||
| const nodes = getNodes() | |||
| const newNodes = produce(nodes, (draft) => { | |||
| const currentNode = draft.find(node => node.id === data.node_id)! | |||
| currentNode.data._iterationIndex = data.index > 0 ? data.index : 1 | |||
| currentNode.data._iterationIndex = iterTimes | |||
| setIterTimes(iterTimes + 1) | |||
| }) | |||
| setNodes(newNodes) | |||
| @@ -464,6 +519,7 @@ export const useWorkflowRun = () => { | |||
| const { | |||
| workflowRunningData, | |||
| setWorkflowRunningData, | |||
| setIterTimes, | |||
| } = workflowStore.getState() | |||
| const { | |||
| getNodes, | |||
| @@ -480,7 +536,7 @@ export const useWorkflowRun = () => { | |||
| }) | |||
| } | |||
| })) | |||
| setIterTimes(DEFAULT_ITER_TIMES) | |||
| const newNodes = produce(nodes, (draft) => { | |||
| const currentNode = draft.find(node => node.id === data.node_id)! | |||
| @@ -12,15 +12,15 @@ import Tooltip from '@/app/components/base/tooltip' | |||
| type Props = { | |||
| className?: string | |||
| title: JSX.Element | string | DefaultTFuncReturn | |||
| tooltip?: React.ReactNode | |||
| isSubTitle?: boolean | |||
| tooltip?: string | |||
| supportFold?: boolean | |||
| children?: JSX.Element | string | null | |||
| operations?: JSX.Element | |||
| inline?: boolean | |||
| } | |||
| const Filed: FC<Props> = ({ | |||
| const Field: FC<Props> = ({ | |||
| className, | |||
| title, | |||
| isSubTitle, | |||
| @@ -60,4 +60,4 @@ const Filed: FC<Props> = ({ | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(Filed) | |||
| export default React.memo(Field) | |||
| @@ -25,6 +25,7 @@ import { | |||
| useToolIcon, | |||
| } from '../../hooks' | |||
| import { useNodeIterationInteractions } from '../iteration/use-interactions' | |||
| import type { IterationNodeType } from '../iteration/types' | |||
| import { | |||
| NodeSourceHandle, | |||
| NodeTargetHandle, | |||
| @@ -34,6 +35,7 @@ import NodeControl from './components/node-control' | |||
| import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' | |||
| import cn from '@/utils/classnames' | |||
| import BlockIcon from '@/app/components/workflow/block-icon' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type BaseNodeProps = { | |||
| children: ReactElement | |||
| @@ -166,9 +168,27 @@ const BaseNode: FC<BaseNodeProps> = ({ | |||
| /> | |||
| <div | |||
| title={data.title} | |||
| className='grow mr-1 system-sm-semibold-uppercase text-text-primary truncate' | |||
| className='grow mr-1 system-sm-semibold-uppercase text-text-primary truncate flex items-center' | |||
| > | |||
| {data.title} | |||
| <div> | |||
| {data.title} | |||
| </div> | |||
| { | |||
| data.type === BlockEnum.Iteration && (data as IterationNodeType).is_parallel && ( | |||
| <Tooltip popupContent={ | |||
| <div className='w-[180px]'> | |||
| <div className='font-extrabold'> | |||
| {t('workflow.nodes.iteration.parallelModeEnableTitle')} | |||
| </div> | |||
| {t('workflow.nodes.iteration.parallelModeEnableDesc')} | |||
| </div>} | |||
| > | |||
| <div className='flex justify-center items-center px-[5px] py-[3px] ml-1 border-[1px] border-text-warning rounded-[5px] text-text-warning system-2xs-medium-uppercase '> | |||
| {t('workflow.nodes.iteration.parallelModeUpper')} | |||
| </div> | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| </div> | |||
| { | |||
| data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && ( | |||
| @@ -1,7 +1,10 @@ | |||
| import { BlockEnum } from '../../types' | |||
| import { BlockEnum, ErrorHandleMode } from '../../types' | |||
| import type { NodeDefault } from '../../types' | |||
| import type { IterationNodeType } from './types' | |||
| import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' | |||
| import { | |||
| ALL_CHAT_AVAILABLE_BLOCKS, | |||
| ALL_COMPLETION_AVAILABLE_BLOCKS, | |||
| } from '@/app/components/workflow/constants' | |||
| const i18nPrefix = 'workflow' | |||
| const nodeDefault: NodeDefault<IterationNodeType> = { | |||
| @@ -10,25 +13,45 @@ const nodeDefault: NodeDefault<IterationNodeType> = { | |||
| iterator_selector: [], | |||
| output_selector: [], | |||
| _children: [], | |||
| _isShowTips: false, | |||
| is_parallel: false, | |||
| parallel_nums: 10, | |||
| error_handle_mode: ErrorHandleMode.Terminated, | |||
| }, | |||
| getAvailablePrevNodes(isChatMode: boolean) { | |||
| const nodes = isChatMode | |||
| ? ALL_CHAT_AVAILABLE_BLOCKS | |||
| : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) | |||
| : ALL_COMPLETION_AVAILABLE_BLOCKS.filter( | |||
| type => type !== BlockEnum.End, | |||
| ) | |||
| return nodes | |||
| }, | |||
| getAvailableNextNodes(isChatMode: boolean) { | |||
| const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS | |||
| const nodes = isChatMode | |||
| ? ALL_CHAT_AVAILABLE_BLOCKS | |||
| : ALL_COMPLETION_AVAILABLE_BLOCKS | |||
| return nodes | |||
| }, | |||
| checkValid(payload: IterationNodeType, t: any) { | |||
| let errorMessages = '' | |||
| if (!errorMessages && (!payload.iterator_selector || payload.iterator_selector.length === 0)) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.input`) }) | |||
| if ( | |||
| !errorMessages | |||
| && (!payload.iterator_selector || payload.iterator_selector.length === 0) | |||
| ) { | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { | |||
| field: t(`${i18nPrefix}.nodes.iteration.input`), | |||
| }) | |||
| } | |||
| if (!errorMessages && (!payload.output_selector || payload.output_selector.length === 0)) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.output`) }) | |||
| if ( | |||
| !errorMessages | |||
| && (!payload.output_selector || payload.output_selector.length === 0) | |||
| ) { | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { | |||
| field: t(`${i18nPrefix}.nodes.iteration.output`), | |||
| }) | |||
| } | |||
| return { | |||
| isValid: !errorMessages, | |||
| @@ -8,12 +8,16 @@ import { | |||
| useNodesInitialized, | |||
| useViewport, | |||
| } from 'reactflow' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { IterationStartNodeDumb } from '../iteration-start' | |||
| import { useNodeIterationInteractions } from './use-interactions' | |||
| import type { IterationNodeType } from './types' | |||
| import AddBlock from './add-block' | |||
| import cn from '@/utils/classnames' | |||
| import type { NodeProps } from '@/app/components/workflow/types' | |||
| import Toast from '@/app/components/base/toast' | |||
| const i18nPrefix = 'workflow.nodes.iteration' | |||
| const Node: FC<NodeProps<IterationNodeType>> = ({ | |||
| id, | |||
| @@ -22,11 +26,20 @@ const Node: FC<NodeProps<IterationNodeType>> = ({ | |||
| const { zoom } = useViewport() | |||
| const nodesInitialized = useNodesInitialized() | |||
| const { handleNodeIterationRerender } = useNodeIterationInteractions() | |||
| const { t } = useTranslation() | |||
| useEffect(() => { | |||
| if (nodesInitialized) | |||
| handleNodeIterationRerender(id) | |||
| }, [nodesInitialized, id, handleNodeIterationRerender]) | |||
| if (data.is_parallel && data._isShowTips) { | |||
| Toast.notify({ | |||
| type: 'warning', | |||
| message: t(`${i18nPrefix}.answerNodeWarningDesc`), | |||
| duration: 5000, | |||
| }) | |||
| data._isShowTips = false | |||
| } | |||
| }, [nodesInitialized, id, handleNodeIterationRerender, data, t]) | |||
| return ( | |||
| <div className={cn( | |||
| @@ -8,11 +8,17 @@ import VarReferencePicker from '../_base/components/variable/var-reference-picke | |||
| import Split from '../_base/components/split' | |||
| import ResultPanel from '../../run/result-panel' | |||
| import IterationResultPanel from '../../run/iteration-result-panel' | |||
| import { MAX_ITERATION_PARALLEL_NUM, MIN_ITERATION_PARALLEL_NUM } from '../../constants' | |||
| import type { IterationNodeType } from './types' | |||
| import useConfig from './use-config' | |||
| import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types' | |||
| import { ErrorHandleMode, InputVarType, type NodePanelProps } from '@/app/components/workflow/types' | |||
| import Field from '@/app/components/workflow/nodes/_base/components/field' | |||
| import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' | |||
| import Switch from '@/app/components/base/switch' | |||
| import Select from '@/app/components/base/select' | |||
| import Slider from '@/app/components/base/slider' | |||
| import Input from '@/app/components/base/input' | |||
| import Divider from '@/app/components/base/divider' | |||
| const i18nPrefix = 'workflow.nodes.iteration' | |||
| @@ -21,7 +27,20 @@ const Panel: FC<NodePanelProps<IterationNodeType>> = ({ | |||
| data, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const responseMethod = [ | |||
| { | |||
| value: ErrorHandleMode.Terminated, | |||
| name: t(`${i18nPrefix}.ErrorMethod.operationTerminated`), | |||
| }, | |||
| { | |||
| value: ErrorHandleMode.ContinueOnError, | |||
| name: t(`${i18nPrefix}.ErrorMethod.continueOnError`), | |||
| }, | |||
| { | |||
| value: ErrorHandleMode.RemoveAbnormalOutput, | |||
| name: t(`${i18nPrefix}.ErrorMethod.removeAbnormalOutput`), | |||
| }, | |||
| ] | |||
| const { | |||
| readOnly, | |||
| inputs, | |||
| @@ -47,6 +66,9 @@ const Panel: FC<NodePanelProps<IterationNodeType>> = ({ | |||
| setIterator, | |||
| iteratorInputKey, | |||
| iterationRunResult, | |||
| changeParallel, | |||
| changeErrorResponseMode, | |||
| changeParallelNums, | |||
| } = useConfig(id, data) | |||
| return ( | |||
| @@ -87,6 +109,39 @@ const Panel: FC<NodePanelProps<IterationNodeType>> = ({ | |||
| /> | |||
| </Field> | |||
| </div> | |||
| <div className='px-4 pb-2'> | |||
| <Field title={t(`${i18nPrefix}.parallelMode`)} tooltip={<div className='w-[230px]'>{t(`${i18nPrefix}.parallelPanelDesc`)}</div>} inline> | |||
| <Switch defaultValue={inputs.is_parallel} onChange={changeParallel} /> | |||
| </Field> | |||
| </div> | |||
| { | |||
| inputs.is_parallel && (<div className='px-4 pb-2'> | |||
| <Field title={t(`${i18nPrefix}.MaxParallelismTitle`)} isSubTitle tooltip={<div className='w-[230px]'>{t(`${i18nPrefix}.MaxParallelismDesc`)}</div>}> | |||
| <div className='flex row'> | |||
| <Input type='number' wrapperClassName='w-18 mr-4 ' max={MAX_ITERATION_PARALLEL_NUM} min={MIN_ITERATION_PARALLEL_NUM} value={inputs.parallel_nums} onChange={(e) => { changeParallelNums(Number(e.target.value)) }} /> | |||
| <Slider | |||
| value={inputs.parallel_nums} | |||
| onChange={changeParallelNums} | |||
| max={MAX_ITERATION_PARALLEL_NUM} | |||
| min={MIN_ITERATION_PARALLEL_NUM} | |||
| className=' flex-shrink-0 flex-1 mt-4' | |||
| /> | |||
| </div> | |||
| </Field> | |||
| </div>) | |||
| } | |||
| <div className='px-4 py-2'> | |||
| <Divider className='h-[1px]'/> | |||
| </div> | |||
| <div className='px-4 py-2'> | |||
| <Field title={t(`${i18nPrefix}.errorResponseMethod`)} > | |||
| <Select items={responseMethod} defaultValue={inputs.error_handle_mode} onSelect={changeErrorResponseMode} allowSearch={false}> | |||
| </Select> | |||
| </Field> | |||
| </div> | |||
| {isShowSingleRun && ( | |||
| <BeforeRunForm | |||
| nodeName={inputs.title} | |||
| @@ -1,6 +1,7 @@ | |||
| import type { | |||
| BlockEnum, | |||
| CommonNodeType, | |||
| ErrorHandleMode, | |||
| ValueSelector, | |||
| VarType, | |||
| } from '@/app/components/workflow/types' | |||
| @@ -12,4 +13,8 @@ export type IterationNodeType = CommonNodeType & { | |||
| iterator_selector: ValueSelector | |||
| output_selector: ValueSelector | |||
| output_type: VarType // output type. | |||
| is_parallel: boolean // open the parallel mode or not | |||
| parallel_nums: number // the numbers of parallel | |||
| error_handle_mode: ErrorHandleMode // how to handle error in the iteration | |||
| _isShowTips: boolean // when answer node in parallel mode iteration show tips | |||
| } | |||
| @@ -8,12 +8,13 @@ import { | |||
| useWorkflow, | |||
| } from '../../hooks' | |||
| import { VarType } from '../../types' | |||
| import type { ValueSelector, Var } from '../../types' | |||
| import type { ErrorHandleMode, ValueSelector, Var } from '../../types' | |||
| import useNodeCrud from '../_base/hooks/use-node-crud' | |||
| import { getNodeInfoById, getNodeUsedVarPassToServerKey, getNodeUsedVars, isSystemVar, toNodeOutputVars } from '../_base/components/variable/utils' | |||
| import useOneStepRun from '../_base/hooks/use-one-step-run' | |||
| import type { IterationNodeType } from './types' | |||
| import type { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types' | |||
| import type { Item } from '@/app/components/base/select' | |||
| const DELIMITER = '@@@@@' | |||
| const useConfig = (id: string, payload: IterationNodeType) => { | |||
| @@ -184,6 +185,25 @@ const useConfig = (id: string, payload: IterationNodeType) => { | |||
| }) | |||
| }, [iteratorInputKey, runInputData, setRunInputData]) | |||
| const changeParallel = useCallback((value: boolean) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.is_parallel = value | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs]) | |||
| const changeErrorResponseMode = useCallback((item: Item) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.error_handle_mode = item.value as ErrorHandleMode | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs]) | |||
| const changeParallelNums = useCallback((num: number) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.parallel_nums = num | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs]) | |||
| return { | |||
| readOnly, | |||
| inputs, | |||
| @@ -210,6 +230,9 @@ const useConfig = (id: string, payload: IterationNodeType) => { | |||
| setIterator, | |||
| iteratorInputKey, | |||
| iterationRunResult, | |||
| changeParallel, | |||
| changeErrorResponseMode, | |||
| changeParallelNums, | |||
| } | |||
| } | |||
| @@ -9,6 +9,8 @@ import { produce, setAutoFreeze } from 'immer' | |||
| import { uniqBy } from 'lodash-es' | |||
| import { useWorkflowRun } from '../../hooks' | |||
| import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' | |||
| import { useWorkflowStore } from '../../store' | |||
| import { DEFAULT_ITER_TIMES } from '../../constants' | |||
| import type { | |||
| ChatItem, | |||
| Inputs, | |||
| @@ -43,6 +45,7 @@ export const useChat = ( | |||
| const { notify } = useToastContext() | |||
| const { handleRun } = useWorkflowRun() | |||
| const hasStopResponded = useRef(false) | |||
| const workflowStore = useWorkflowStore() | |||
| const conversationId = useRef('') | |||
| const taskIdRef = useRef('') | |||
| const [chatList, setChatList] = useState<ChatItem[]>(prevChatList || []) | |||
| @@ -52,6 +55,9 @@ export const useChat = ( | |||
| const [suggestedQuestions, setSuggestQuestions] = useState<string[]>([]) | |||
| const suggestedQuestionsAbortControllerRef = useRef<AbortController | null>(null) | |||
| const { | |||
| setIterTimes, | |||
| } = workflowStore.getState() | |||
| useEffect(() => { | |||
| setAutoFreeze(false) | |||
| return () => { | |||
| @@ -102,15 +108,16 @@ export const useChat = ( | |||
| handleResponding(false) | |||
| if (stopChat && taskIdRef.current) | |||
| stopChat(taskIdRef.current) | |||
| setIterTimes(DEFAULT_ITER_TIMES) | |||
| if (suggestedQuestionsAbortControllerRef.current) | |||
| suggestedQuestionsAbortControllerRef.current.abort() | |||
| }, [handleResponding, stopChat]) | |||
| }, [handleResponding, setIterTimes, stopChat]) | |||
| const handleRestart = useCallback(() => { | |||
| conversationId.current = '' | |||
| taskIdRef.current = '' | |||
| handleStop() | |||
| setIterTimes(DEFAULT_ITER_TIMES) | |||
| const newChatList = config?.opening_statement | |||
| ? [{ | |||
| id: `${Date.now()}`, | |||
| @@ -126,6 +133,7 @@ export const useChat = ( | |||
| config, | |||
| handleStop, | |||
| handleUpdateChatList, | |||
| setIterTimes, | |||
| ]) | |||
| const updateCurrentQA = useCallback(({ | |||
| @@ -60,36 +60,67 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe | |||
| }, [notify, getResultCallback]) | |||
| const formatNodeList = useCallback((list: NodeTracing[]) => { | |||
| const allItems = list.reverse() | |||
| const allItems = [...list].reverse() | |||
| const result: NodeTracing[] = [] | |||
| allItems.forEach((item) => { | |||
| const { node_type, execution_metadata } = item | |||
| if (node_type !== BlockEnum.Iteration) { | |||
| const isInIteration = !!execution_metadata?.iteration_id | |||
| if (isInIteration) { | |||
| const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) | |||
| const iterationDetails = iterationNode?.details | |||
| const currentIterationIndex = execution_metadata?.iteration_index ?? 0 | |||
| if (Array.isArray(iterationDetails)) { | |||
| if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) | |||
| iterationDetails[currentIterationIndex] = [item] | |||
| else | |||
| iterationDetails[currentIterationIndex].push(item) | |||
| } | |||
| return | |||
| } | |||
| // not in iteration | |||
| result.push(item) | |||
| const groupMap = new Map<string, NodeTracing[]>() | |||
| return | |||
| } | |||
| const processIterationNode = (item: NodeTracing) => { | |||
| result.push({ | |||
| ...item, | |||
| details: [], | |||
| }) | |||
| } | |||
| const updateParallelModeGroup = (runId: string, item: NodeTracing, iterationNode: NodeTracing) => { | |||
| if (!groupMap.has(runId)) | |||
| groupMap.set(runId, [item]) | |||
| else | |||
| groupMap.get(runId)!.push(item) | |||
| if (item.status === 'failed') { | |||
| iterationNode.status = 'failed' | |||
| iterationNode.error = item.error | |||
| } | |||
| iterationNode.details = Array.from(groupMap.values()) | |||
| } | |||
| const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { | |||
| const { details } = iterationNode | |||
| if (details) { | |||
| if (!details[index]) | |||
| details[index] = [item] | |||
| else | |||
| details[index].push(item) | |||
| } | |||
| if (item.status === 'failed') { | |||
| iterationNode.status = 'failed' | |||
| iterationNode.error = item.error | |||
| } | |||
| } | |||
| const processNonIterationNode = (item: NodeTracing) => { | |||
| const { execution_metadata } = item | |||
| if (!execution_metadata?.iteration_id) { | |||
| result.push(item) | |||
| return | |||
| } | |||
| const iterationNode = result.find(node => node.node_id === execution_metadata.iteration_id) | |||
| if (!iterationNode || !Array.isArray(iterationNode.details)) | |||
| return | |||
| const { parallel_mode_run_id, iteration_index = 0 } = execution_metadata | |||
| if (parallel_mode_run_id) | |||
| updateParallelModeGroup(parallel_mode_run_id, item, iterationNode) | |||
| else | |||
| updateSequentialModeGroup(iteration_index, item, iterationNode) | |||
| } | |||
| allItems.forEach((item) => { | |||
| item.node_type === BlockEnum.Iteration | |||
| ? processIterationNode(item) | |||
| : processNonIterationNode(item) | |||
| }) | |||
| return result | |||
| }, []) | |||
| @@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiArrowRightSLine, | |||
| RiCloseLine, | |||
| RiErrorWarningLine, | |||
| } from '@remixicon/react' | |||
| import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' | |||
| import TracingPanel from './tracing-panel' | |||
| @@ -27,7 +28,7 @@ const IterationResultPanel: FC<Props> = ({ | |||
| noWrap, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const [expandedIterations, setExpandedIterations] = useState<Record<number, boolean>>([]) | |||
| const [expandedIterations, setExpandedIterations] = useState<Record<number, boolean>>({}) | |||
| const toggleIteration = useCallback((index: number) => { | |||
| setExpandedIterations(prev => ({ | |||
| @@ -71,10 +72,19 @@ const IterationResultPanel: FC<Props> = ({ | |||
| <span className='system-sm-semibold-uppercase text-text-primary flex-grow'> | |||
| {t(`${i18nPrefix}.iteration`)} {index + 1} | |||
| </span> | |||
| <RiArrowRightSLine className={cn( | |||
| 'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0', | |||
| expandedIterations[index] && 'transform rotate-90', | |||
| )} /> | |||
| { | |||
| iteration.some(item => item.status === 'failed') | |||
| ? ( | |||
| <RiErrorWarningLine className='w-4 h-4 text-text-destructive' /> | |||
| ) | |||
| : (< RiArrowRightSLine className={ | |||
| cn( | |||
| 'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0', | |||
| expandedIterations[index] && 'transform rotate-90', | |||
| )} /> | |||
| ) | |||
| } | |||
| </div> | |||
| </div> | |||
| {expandedIterations[index] && <div | |||
| @@ -72,7 +72,16 @@ const NodePanel: FC<Props> = ({ | |||
| return iteration_length | |||
| } | |||
| const getErrorCount = (details: NodeTracing[][] | undefined) => { | |||
| if (!details || details.length === 0) | |||
| return 0 | |||
| return details.reduce((acc, iteration) => { | |||
| if (iteration.some(item => item.status === 'failed')) | |||
| acc++ | |||
| return acc | |||
| }, 0) | |||
| } | |||
| useEffect(() => { | |||
| setCollapseState(!nodeInfo.expand) | |||
| }, [nodeInfo.expand, setCollapseState]) | |||
| @@ -136,7 +145,12 @@ const NodePanel: FC<Props> = ({ | |||
| onClick={handleOnShowIterationDetail} | |||
| > | |||
| <Iteration className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' /> | |||
| <div className='flex-1 text-left system-sm-medium text-components-button-tertiary-text'>{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}</div> | |||
| <div className='flex-1 text-left system-sm-medium text-components-button-tertiary-text'>{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}{getErrorCount(nodeInfo.details) > 0 && ( | |||
| <> | |||
| {t('workflow.nodes.iteration.comma')} | |||
| {t('workflow.nodes.iteration.error', { count: getErrorCount(nodeInfo.details) })} | |||
| </> | |||
| )}</div> | |||
| {justShowIterationNavArrow | |||
| ? ( | |||
| <RiArrowRightSLine className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' /> | |||
| @@ -21,6 +21,7 @@ import type { | |||
| WorkflowRunningData, | |||
| } from './types' | |||
| import { WorkflowContext } from './context' | |||
| import type { NodeTracing } from '@/types/workflow' | |||
| // #TODO chatVar# | |||
| // const MOCK_DATA = [ | |||
| @@ -166,6 +167,10 @@ type Shape = { | |||
| setShowImportDSLModal: (showImportDSLModal: boolean) => void | |||
| showTips: string | |||
| setShowTips: (showTips: string) => void | |||
| iterTimes: number | |||
| setIterTimes: (iterTimes: number) => void | |||
| iterParallelLogMap: Map<string, NodeTracing[]> | |||
| setIterParallelLogMap: (iterParallelLogMap: Map<string, NodeTracing[]>) => void | |||
| } | |||
| export const createWorkflowStore = () => { | |||
| @@ -281,6 +286,11 @@ export const createWorkflowStore = () => { | |||
| setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), | |||
| showTips: '', | |||
| setShowTips: showTips => set(() => ({ showTips })), | |||
| iterTimes: 1, | |||
| setIterTimes: iterTimes => set(() => ({ iterTimes })), | |||
| iterParallelLogMap: new Map<string, NodeTracing[]>(), | |||
| setIterParallelLogMap: iterParallelLogMap => set(() => ({ iterParallelLogMap })), | |||
| })) | |||
| } | |||
| @@ -36,7 +36,11 @@ export enum ControlMode { | |||
| Pointer = 'pointer', | |||
| Hand = 'hand', | |||
| } | |||
| export enum ErrorHandleMode { | |||
| Terminated = 'terminated', | |||
| ContinueOnError = 'continue-on-error', | |||
| RemoveAbnormalOutput = 'remove-abnormal-output', | |||
| } | |||
| export type Branch = { | |||
| id: string | |||
| name: string | |||
| @@ -19,7 +19,7 @@ import type { | |||
| ToolWithProvider, | |||
| ValueSelector, | |||
| } from './types' | |||
| import { BlockEnum } from './types' | |||
| import { BlockEnum, ErrorHandleMode } from './types' | |||
| import { | |||
| CUSTOM_NODE, | |||
| ITERATION_CHILDREN_Z_INDEX, | |||
| @@ -267,8 +267,13 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { | |||
| }) | |||
| } | |||
| if (node.data.type === BlockEnum.Iteration) | |||
| node.data._children = iterationNodeMap[node.id] || [] | |||
| if (node.data.type === BlockEnum.Iteration) { | |||
| const iterationNodeData = node.data as IterationNodeType | |||
| iterationNodeData._children = iterationNodeMap[node.id] || [] | |||
| iterationNodeData.is_parallel = iterationNodeData.is_parallel || false | |||
| iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10 | |||
| iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated | |||
| } | |||
| return node | |||
| }) | |||
| @@ -556,6 +556,23 @@ const translation = { | |||
| iteration_one: '{{count}} Iteration', | |||
| iteration_other: '{{count}} Iterations', | |||
| currentIteration: 'Current Iteration', | |||
| comma: ', ', | |||
| error_one: '{{count}} Error', | |||
| error_other: '{{count}} Errors', | |||
| parallelMode: 'Parallel Mode', | |||
| parallelModeUpper: 'PARALLEL MODE', | |||
| parallelModeEnableTitle: 'Parallel Mode Enabled', | |||
| parallelModeEnableDesc: 'In parallel mode, tasks within iterations support parallel execution. You can configure this in the properties panel on the right.', | |||
| parallelPanelDesc: 'In parallel mode, tasks in the iteration support parallel execution.', | |||
| MaxParallelismTitle: 'Maximum parallelism', | |||
| MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.', | |||
| errorResponseMethod: 'Error response method', | |||
| ErrorMethod: { | |||
| operationTerminated: 'terminated', | |||
| continueOnError: 'continue-on-error', | |||
| removeAbnormalOutput: 'remove-abnormal-output', | |||
| }, | |||
| answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.', | |||
| }, | |||
| note: { | |||
| addNote: 'Add Note', | |||
| @@ -556,6 +556,23 @@ const translation = { | |||
| iteration_one: '{{count}}个迭代', | |||
| iteration_other: '{{count}}个迭代', | |||
| currentIteration: '当前迭代', | |||
| comma: ',', | |||
| error_one: '{{count}}个失败', | |||
| error_other: '{{count}}个失败', | |||
| parallelMode: '并行模式', | |||
| parallelModeUpper: '并行模式', | |||
| parallelModeEnableTitle: '并行模式启用', | |||
| parallelModeEnableDesc: '启用并行模式时迭代内的任务支持并行执行。你可以在右侧的属性面板中进行配置。', | |||
| parallelPanelDesc: '在并行模式下,迭代中的任务支持并行执行。', | |||
| MaxParallelismTitle: '最大并行度', | |||
| MaxParallelismDesc: '最大并行度用于控制单次迭代中同时执行的任务数量。', | |||
| errorResponseMethod: '错误响应方法', | |||
| ErrorMethod: { | |||
| operationTerminated: '错误时终止', | |||
| continueOnError: '忽略错误并继续', | |||
| removeAbnormalOutput: '移除错误输出', | |||
| }, | |||
| answerNodeWarningDesc: '并行模式警告:在迭代中,回答节点、会话变量赋值和工具持久读/写操作可能会导致异常。', | |||
| }, | |||
| note: { | |||
| addNote: '添加注释', | |||
| @@ -19,6 +19,7 @@ export type NodeTracing = { | |||
| process_data: any | |||
| outputs?: any | |||
| status: string | |||
| parallel_run_id?: string | |||
| error?: string | |||
| elapsed_time: number | |||
| execution_metadata: { | |||
| @@ -31,6 +32,7 @@ export type NodeTracing = { | |||
| parallel_start_node_id?: string | |||
| parent_parallel_id?: string | |||
| parent_parallel_start_node_id?: string | |||
| parallel_mode_run_id?: string | |||
| } | |||
| metadata: { | |||
| iterator_length: number | |||
| @@ -121,6 +123,7 @@ export type NodeStartedResponse = { | |||
| id: string | |||
| node_id: string | |||
| iteration_id?: string | |||
| parallel_run_id?: string | |||
| node_type: string | |||
| index: number | |||
| predecessor_node_id?: string | |||
| @@ -166,6 +169,7 @@ export type NodeFinishedResponse = { | |||
| parallel_start_node_id?: string | |||
| iteration_index?: number | |||
| iteration_id?: string | |||
| parallel_mode_run_id: string | |||
| } | |||
| created_at: number | |||
| files?: FileResponse[] | |||
| @@ -200,6 +204,7 @@ export type IterationNextResponse = { | |||
| output: any | |||
| extras?: any | |||
| created_at: number | |||
| parallel_mode_run_id: string | |||
| execution_metadata: { | |||
| parallel_id?: string | |||
| } | |||