|
|
|
@@ -1,10 +1,20 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
from collections.abc import Generator, Mapping, Sequence |
|
|
|
from datetime import UTC, datetime |
|
|
|
from typing import Any, cast |
|
|
|
from typing import TYPE_CHECKING, Any, Literal, cast |
|
|
|
|
|
|
|
from configs import dify_config |
|
|
|
from core.variables import IntegerSegment |
|
|
|
from core.variables import ( |
|
|
|
ArrayNumberSegment, |
|
|
|
ArrayObjectSegment, |
|
|
|
ArrayStringSegment, |
|
|
|
IntegerSegment, |
|
|
|
ObjectSegment, |
|
|
|
Segment, |
|
|
|
SegmentType, |
|
|
|
StringSegment, |
|
|
|
) |
|
|
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult |
|
|
|
from core.workflow.graph_engine.entities.event import ( |
|
|
|
BaseGraphEvent, |
|
|
|
@@ -29,6 +39,10 @@ from core.workflow.nodes.loop.entities import LoopNodeData |
|
|
|
from core.workflow.utils.condition.processor import ConditionProcessor |
|
|
|
from models.workflow import WorkflowNodeExecutionStatus |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from core.workflow.entities.variable_pool import VariablePool |
|
|
|
from core.workflow.graph_engine.graph_engine import GraphEngine |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
@@ -61,6 +75,28 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
variable_pool = self.graph_runtime_state.variable_pool |
|
|
|
variable_pool.add([self.node_id, "index"], 0) |
|
|
|
|
|
|
|
# Initialize loop variables |
|
|
|
loop_variable_selectors = {} |
|
|
|
if self.node_data.loop_variables: |
|
|
|
for loop_variable in self.node_data.loop_variables: |
|
|
|
value_processor = { |
|
|
|
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), |
|
|
|
"variable": lambda var=loop_variable: variable_pool.get(var.value), |
|
|
|
} |
|
|
|
|
|
|
|
if loop_variable.value_type not in value_processor: |
|
|
|
raise ValueError( |
|
|
|
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" |
|
|
|
) |
|
|
|
|
|
|
|
processed_segment = value_processor[loop_variable.value_type]() |
|
|
|
if not processed_segment: |
|
|
|
raise ValueError(f"Invalid value for loop variable {loop_variable.label}") |
|
|
|
variable_selector = [self.node_id, loop_variable.label] |
|
|
|
variable_pool.add(variable_selector, processed_segment.value) |
|
|
|
loop_variable_selectors[loop_variable.label] = variable_selector |
|
|
|
inputs[loop_variable.label] = processed_segment.value |
|
|
|
|
|
|
|
from core.workflow.graph_engine.graph_engine import GraphEngine |
|
|
|
|
|
|
|
graph_engine = GraphEngine( |
|
|
|
@@ -95,135 +131,51 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
predecessor_node_id=self.previous_node_id, |
|
|
|
) |
|
|
|
|
|
|
|
yield LoopRunNextEvent( |
|
|
|
loop_id=self.id, |
|
|
|
loop_node_id=self.node_id, |
|
|
|
loop_node_type=self.node_type, |
|
|
|
loop_node_data=self.node_data, |
|
|
|
index=0, |
|
|
|
pre_loop_output=None, |
|
|
|
) |
|
|
|
|
|
|
|
# yield LoopRunNextEvent( |
|
|
|
# loop_id=self.id, |
|
|
|
# loop_node_id=self.node_id, |
|
|
|
# loop_node_type=self.node_type, |
|
|
|
# loop_node_data=self.node_data, |
|
|
|
# index=0, |
|
|
|
# pre_loop_output=None, |
|
|
|
# ) |
|
|
|
loop_duration_map = {} |
|
|
|
single_loop_variable_map = {} # single loop variable output |
|
|
|
try: |
|
|
|
check_break_result = False |
|
|
|
for i in range(loop_count): |
|
|
|
# Run workflow |
|
|
|
rst = graph_engine.run() |
|
|
|
current_index_variable = variable_pool.get([self.node_id, "index"]) |
|
|
|
if not isinstance(current_index_variable, IntegerSegment): |
|
|
|
raise ValueError(f"loop {self.node_id} current index not found") |
|
|
|
current_index = current_index_variable.value |
|
|
|
|
|
|
|
check_break_result = False |
|
|
|
|
|
|
|
for event in rst: |
|
|
|
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id: |
|
|
|
event.in_loop_id = self.node_id |
|
|
|
|
|
|
|
if ( |
|
|
|
isinstance(event, BaseNodeEvent) |
|
|
|
and event.node_type == NodeType.LOOP_START |
|
|
|
and not isinstance(event, NodeRunStreamChunkEvent) |
|
|
|
): |
|
|
|
continue |
|
|
|
|
|
|
|
if isinstance(event, NodeRunSucceededEvent): |
|
|
|
yield self._handle_event_metadata(event=event, iter_run_index=current_index) |
|
|
|
|
|
|
|
# Check if all variables in break conditions exist |
|
|
|
exists_variable = False |
|
|
|
for condition in break_conditions: |
|
|
|
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector): |
|
|
|
exists_variable = False |
|
|
|
break |
|
|
|
else: |
|
|
|
exists_variable = True |
|
|
|
if exists_variable: |
|
|
|
input_conditions, group_result, check_break_result = condition_processor.process_conditions( |
|
|
|
variable_pool=self.graph_runtime_state.variable_pool, |
|
|
|
conditions=break_conditions, |
|
|
|
operator=logical_operator, |
|
|
|
) |
|
|
|
if check_break_result: |
|
|
|
break |
|
|
|
|
|
|
|
elif isinstance(event, BaseGraphEvent): |
|
|
|
if isinstance(event, GraphRunFailedEvent): |
|
|
|
# Loop run failed |
|
|
|
yield LoopRunFailedEvent( |
|
|
|
loop_id=self.id, |
|
|
|
loop_node_id=self.node_id, |
|
|
|
loop_node_type=self.node_type, |
|
|
|
loop_node_data=self.node_data, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
steps=i, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, |
|
|
|
"completed_reason": "error", |
|
|
|
}, |
|
|
|
error=event.error, |
|
|
|
) |
|
|
|
yield RunCompletedEvent( |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.FAILED, |
|
|
|
error=event.error, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens |
|
|
|
}, |
|
|
|
) |
|
|
|
) |
|
|
|
return |
|
|
|
elif isinstance(event, NodeRunFailedEvent): |
|
|
|
# Loop run failed |
|
|
|
yield event |
|
|
|
yield LoopRunFailedEvent( |
|
|
|
loop_id=self.id, |
|
|
|
loop_node_id=self.node_id, |
|
|
|
loop_node_type=self.node_type, |
|
|
|
loop_node_data=self.node_data, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
steps=i, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, |
|
|
|
"completed_reason": "error", |
|
|
|
}, |
|
|
|
error=event.error, |
|
|
|
) |
|
|
|
yield RunCompletedEvent( |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.FAILED, |
|
|
|
error=event.error, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens |
|
|
|
}, |
|
|
|
) |
|
|
|
) |
|
|
|
return |
|
|
|
loop_start_time = datetime.now(UTC).replace(tzinfo=None) |
|
|
|
# run single loop |
|
|
|
loop_result = yield from self._run_single_loop( |
|
|
|
graph_engine=graph_engine, |
|
|
|
loop_graph=loop_graph, |
|
|
|
variable_pool=variable_pool, |
|
|
|
loop_variable_selectors=loop_variable_selectors, |
|
|
|
break_conditions=break_conditions, |
|
|
|
logical_operator=logical_operator, |
|
|
|
condition_processor=condition_processor, |
|
|
|
current_index=i, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
) |
|
|
|
loop_end_time = datetime.now(UTC).replace(tzinfo=None) |
|
|
|
|
|
|
|
single_loop_variable = {} |
|
|
|
for key, selector in loop_variable_selectors.items(): |
|
|
|
item = variable_pool.get(selector) |
|
|
|
if item: |
|
|
|
single_loop_variable[key] = item.value |
|
|
|
else: |
|
|
|
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index) |
|
|
|
single_loop_variable[key] = None |
|
|
|
|
|
|
|
# Remove all nodes outputs from variable pool |
|
|
|
for node_id in loop_graph.node_ids: |
|
|
|
variable_pool.remove([node_id]) |
|
|
|
loop_duration_map[str(i)] = (loop_end_time - loop_start_time).total_seconds() |
|
|
|
single_loop_variable_map[str(i)] = single_loop_variable |
|
|
|
|
|
|
|
check_break_result = loop_result.get("check_break_result", False) |
|
|
|
|
|
|
|
if check_break_result: |
|
|
|
break |
|
|
|
|
|
|
|
# Move to next loop |
|
|
|
next_index = current_index + 1 |
|
|
|
variable_pool.add([self.node_id, "index"], next_index) |
|
|
|
|
|
|
|
yield LoopRunNextEvent( |
|
|
|
loop_id=self.id, |
|
|
|
loop_node_id=self.node_id, |
|
|
|
loop_node_type=self.node_type, |
|
|
|
loop_node_data=self.node_data, |
|
|
|
index=next_index, |
|
|
|
pre_loop_output=None, |
|
|
|
) |
|
|
|
|
|
|
|
# Loop completed successfully |
|
|
|
yield LoopRunSucceededEvent( |
|
|
|
loop_id=self.id, |
|
|
|
@@ -232,17 +184,26 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
loop_node_data=self.node_data, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
outputs=self.node_data.outputs, |
|
|
|
steps=loop_count, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, |
|
|
|
"completed_reason": "loop_break" if check_break_result else "loop_completed", |
|
|
|
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, |
|
|
|
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, |
|
|
|
}, |
|
|
|
) |
|
|
|
|
|
|
|
yield RunCompletedEvent( |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, |
|
|
|
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, |
|
|
|
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, |
|
|
|
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, |
|
|
|
}, |
|
|
|
outputs=self.node_data.outputs, |
|
|
|
inputs=inputs, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
@@ -260,6 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
metadata={ |
|
|
|
"total_tokens": graph_engine.graph_runtime_state.total_tokens, |
|
|
|
"completed_reason": "error", |
|
|
|
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, |
|
|
|
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, |
|
|
|
}, |
|
|
|
error=str(e), |
|
|
|
) |
|
|
|
@@ -268,7 +231,11 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.FAILED, |
|
|
|
error=str(e), |
|
|
|
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, |
|
|
|
metadata={ |
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, |
|
|
|
NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, |
|
|
|
NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, |
|
|
|
}, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
@@ -276,6 +243,159 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
# Clean up |
|
|
|
variable_pool.remove([self.node_id, "index"]) |
|
|
|
|
|
|
|
def _run_single_loop(
    self,
    *,
    graph_engine: "GraphEngine",
    loop_graph: Graph,
    variable_pool: "VariablePool",
    loop_variable_selectors: dict,
    break_conditions: list,
    logical_operator: Literal["and", "or"],
    condition_processor: ConditionProcessor,
    current_index: int,
    start_at: datetime,
    inputs: dict,
) -> Generator[NodeEvent | InNodeEvent, None, dict]:
    """Run one iteration of the loop body and forward its events.

    Starts ``graph_engine.run()``, re-emits the events it produces (tagged
    with this loop's node id), decides whether the loop should stop, then
    publishes the loop variables as ``self.node_data.outputs``.

    Args:
        graph_engine: Engine whose ``run()`` yields the events of the loop body.
        loop_graph: Graph of the loop body; its ``node_ids`` are removed from
            ``variable_pool`` once the iteration's events have been consumed.
        variable_pool: Pool holding the loop index and the loop variables.
        loop_variable_selectors: Maps a loop variable label to its selector
            in ``variable_pool``.
        break_conditions: Conditions evaluated after each successful node run.
        logical_operator: How ``break_conditions`` are combined.
        condition_processor: Evaluates ``break_conditions``.
        current_index: Iteration index supplied by the caller. It is
            immediately replaced by the value stored in ``variable_pool``
            under ``[self.node_id, "index"]``.
        start_at: Loop start time, forwarded into failure events.
        inputs: Loop inputs, forwarded into failure events.

    Yields:
        Events from the loop body, plus ``LoopRunFailedEvent`` /
        ``RunCompletedEvent`` on failure and ``LoopRunNextEvent`` when the
        loop advances to the next iteration.

    Returns:
        dict: ``{"check_break_result": bool}``. ``True`` when the caller
        should stop iterating (a LOOP_END node succeeded, the break
        conditions matched, or the run failed); ``False`` otherwise.

    Raises:
        ValueError: If the loop index in ``variable_pool`` is not an
            ``IntegerSegment``.
    """
    # Run workflow
    rst = graph_engine.run()
    current_index_variable = variable_pool.get([self.node_id, "index"])
    if not isinstance(current_index_variable, IntegerSegment):
        raise ValueError(f"loop {self.node_id} current index not found")
    # The pool value takes precedence over the ``current_index`` argument.
    current_index = current_index_variable.value

    check_break_result = False

    for event in rst:
        # Tag events that do not yet belong to a loop with this loop's node id.
        if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
            event.in_loop_id = self.node_id

        # Events of the LOOP_START node are not forwarded, except stream chunks.
        if (
            isinstance(event, BaseNodeEvent)
            and event.node_type == NodeType.LOOP_START
            and not isinstance(event, NodeRunStreamChunkEvent)
        ):
            continue

        # A successful LOOP_END node ends the loop: forward the event and stop
        # consuming further events of this iteration.
        if (
            isinstance(event, NodeRunSucceededEvent)
            and event.node_type == NodeType.LOOP_END
            and not isinstance(event, NodeRunStreamChunkEvent)
        ):
            check_break_result = True
            yield self._handle_event_metadata(event=event, iter_run_index=current_index)
            break

        if isinstance(event, NodeRunSucceededEvent):
            yield self._handle_event_metadata(event=event, iter_run_index=current_index)

            # Check if all variables in break conditions exist
            # With no break conditions, exists_variable stays False and nothing is evaluated.
            # NOTE(review): existence is tested by truthiness of the pool lookup; whether a
            # present-but-falsy segment can occur depends on Segment's bool semantics - confirm.
            # NOTE(review): reads self.graph_runtime_state.variable_pool rather than the
            # ``variable_pool`` argument; presumably the same object - confirm against caller.
            exists_variable = False
            for condition in break_conditions:
                if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
                    exists_variable = False
                    break
                else:
                    exists_variable = True
            if exists_variable:
                # Only the third element of the result is used below.
                input_conditions, group_result, check_break_result = condition_processor.process_conditions(
                    variable_pool=self.graph_runtime_state.variable_pool,
                    conditions=break_conditions,
                    operator=logical_operator,
                )
                if check_break_result:
                    break

        elif isinstance(event, BaseGraphEvent):
            # Graph-level events are not forwarded; only a failed run is acted upon.
            if isinstance(event, GraphRunFailedEvent):
                # Loop run failed
                yield LoopRunFailedEvent(
                    loop_id=self.id,
                    loop_node_id=self.node_id,
                    loop_node_type=self.node_type,
                    loop_node_data=self.node_data,
                    start_at=start_at,
                    inputs=inputs,
                    steps=current_index,
                    metadata={
                        NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
                        "completed_reason": "error",
                    },
                    error=event.error,
                )
                yield RunCompletedEvent(
                    run_result=NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        error=event.error,
                        metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
                    )
                )
                # NOTE(review): the returned dict is identical to the one used for a regular
                # break, so the caller cannot tell failure from break by the return value alone.
                return {"check_break_result": True}
        elif isinstance(event, NodeRunFailedEvent):
            # Loop run failed
            yield event
            yield LoopRunFailedEvent(
                loop_id=self.id,
                loop_node_id=self.node_id,
                loop_node_type=self.node_type,
                loop_node_data=self.node_data,
                start_at=start_at,
                inputs=inputs,
                steps=current_index,
                metadata={
                    NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
                    "completed_reason": "error",
                },
                error=event.error,
            )
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    error=event.error,
                    metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
                )
            )
            # NOTE(review): same return value as a regular break (see the note on the
            # GraphRunFailedEvent branch).
            return {"check_break_result": True}
        else:
            yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)

    # Remove all nodes outputs from variable pool
    for node_id in loop_graph.node_ids:
        variable_pool.remove([node_id])

    # Snapshot the loop variables (None when a selector no longer resolves) so they
    # can be published as the loop node's outputs.
    _outputs = {}
    for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
        _loop_variable_segment = variable_pool.get(loop_variable_selector)
        if _loop_variable_segment:
            _outputs[loop_variable_key] = _loop_variable_segment.value
        else:
            _outputs[loop_variable_key] = None

    # Number of rounds executed so far (the index is zero-based).
    _outputs["loop_round"] = current_index + 1
    self.node_data.outputs = _outputs

    # On break the index is left untouched and no LoopRunNextEvent is emitted.
    if check_break_result:
        return {"check_break_result": True}

    # Move to next loop
    next_index = current_index + 1
    variable_pool.add([self.node_id, "index"], next_index)

    yield LoopRunNextEvent(
        loop_id=self.id,
        loop_node_id=self.node_id,
        loop_node_type=self.node_type,
        loop_node_data=self.node_data,
        index=next_index,
        pre_loop_output=self.node_data.outputs,
    )

    return {"check_break_result": False}
|
|
|
|
|
|
|
def _handle_event_metadata( |
|
|
|
self, |
|
|
|
*, |
|
|
|
@@ -360,3 +480,25 @@ class LoopNode(BaseNode[LoopNodeData]): |
|
|
|
} |
|
|
|
|
|
|
|
return variable_mapping |
|
|
|
|
|
|
|
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
    """Build the segment that wraps a constant loop-variable value.

    Args:
        var_type: Declared type of the constant. Supported values are
            "string", "number", "object", "array[string]", "array[number]"
            and "array[object]".
        value: The constant itself. For array types a JSON-encoded
            ``str``/``bytes``/``bytearray`` is decoded with ``json.loads``;
            an already-decoded non-empty value is passed through unchanged;
            an empty or missing value becomes ``[]``.

    Returns:
        An instance of the segment class registered for ``var_type``.

    Raises:
        ValueError: If ``var_type`` is not supported, or if an encoded array
            value is not valid JSON (``json.JSONDecodeError`` is a subclass
            of ``ValueError``).
    """
    segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
        "string": (StringSegment, SegmentType.STRING),
        "number": (IntegerSegment, SegmentType.NUMBER),
        "object": (ObjectSegment, SegmentType.OBJECT),
        "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
        "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
        "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
    }

    # Reject unknown types before touching the value.
    segment_info = segment_mapping.get(var_type)
    if not segment_info:
        raise ValueError(f"Invalid variable type: {var_type}")
    segment_class, value_type = segment_info

    if var_type in ("array[string]", "array[number]", "array[object]"):
        if not value:
            value = []
        elif isinstance(value, (str, bytes, bytearray)):
            # Array constants normally arrive JSON-encoded.
            value = json.loads(value)
        # Otherwise the value is already decoded (e.g. a list); handing it to
        # json.loads would raise TypeError, so it is used as-is.

    return segment_class(value=value, value_type=value_type)