|
|
|
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast |
|
|
|
from flask import Flask, current_app |
|
|
|
|
|
|
|
from configs import dify_config |
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder |
|
|
|
from core.variables import IntegerVariable |
|
|
|
from core.workflow.entities.node_entities import ( |
|
|
|
NodeRunMetadataKey, |
|
|
|
NodeRunResult, |
|
|
|
@@ -155,18 +155,19 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
iteration_node_data=self.node_data, |
|
|
|
index=0, |
|
|
|
pre_iteration_output=None, |
|
|
|
duration=None, |
|
|
|
) |
|
|
|
iter_run_map: dict[str, float] = {} |
|
|
|
outputs: list[Any] = [None] * len(iterator_list_value) |
|
|
|
try: |
|
|
|
if self.node_data.is_parallel: |
|
|
|
futures: list[Future] = [] |
|
|
|
q = Queue() |
|
|
|
q: Queue = Queue() |
|
|
|
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) |
|
|
|
for index, item in enumerate(iterator_list_value): |
|
|
|
future: Future = thread_pool.submit( |
|
|
|
self._run_single_iter_parallel, |
|
|
|
current_app._get_current_object(), |
|
|
|
current_app._get_current_object(), # type: ignore |
|
|
|
q, |
|
|
|
iterator_list_value, |
|
|
|
inputs, |
|
|
|
@@ -181,6 +182,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
future.add_done_callback(thread_pool.task_done_callback) |
|
|
|
futures.append(future) |
|
|
|
succeeded_count = 0 |
|
|
|
empty_count = 0 |
|
|
|
while True: |
|
|
|
try: |
|
|
|
event = q.get(timeout=1) |
|
|
|
@@ -208,17 +210,22 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
else: |
|
|
|
for _ in range(len(iterator_list_value)): |
|
|
|
yield from self._run_single_iter( |
|
|
|
iterator_list_value, |
|
|
|
variable_pool, |
|
|
|
inputs, |
|
|
|
outputs, |
|
|
|
start_at, |
|
|
|
graph_engine, |
|
|
|
iteration_graph, |
|
|
|
iter_run_map, |
|
|
|
iterator_list_value=iterator_list_value, |
|
|
|
variable_pool=variable_pool, |
|
|
|
inputs=inputs, |
|
|
|
outputs=outputs, |
|
|
|
start_at=start_at, |
|
|
|
graph_engine=graph_engine, |
|
|
|
iteration_graph=iteration_graph, |
|
|
|
iter_run_map=iter_run_map, |
|
|
|
) |
|
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: |
|
|
|
outputs = [output for output in outputs if output is not None] |
|
|
|
|
|
|
|
# Flatten the list of lists |
|
|
|
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): |
|
|
|
outputs = [item for sublist in outputs for item in sublist] |
|
|
|
|
|
|
|
yield IterationRunSucceededEvent( |
|
|
|
iteration_id=self.id, |
|
|
|
iteration_node_id=self.node_id, |
|
|
|
@@ -226,7 +233,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
iteration_node_data=self.node_data, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
outputs={"output": jsonable_encoder(outputs)}, |
|
|
|
outputs={"output": outputs}, |
|
|
|
steps=len(iterator_list_value), |
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, |
|
|
|
) |
|
|
|
@@ -234,7 +241,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
yield RunCompletedEvent( |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, |
|
|
|
outputs={"output": jsonable_encoder(outputs)}, |
|
|
|
outputs={"output": outputs}, |
|
|
|
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map}, |
|
|
|
) |
|
|
|
) |
|
|
|
@@ -248,7 +255,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
iteration_node_data=self.node_data, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
outputs={"output": jsonable_encoder(outputs)}, |
|
|
|
outputs={"output": outputs}, |
|
|
|
steps=len(iterator_list_value), |
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, |
|
|
|
error=str(e), |
|
|
|
@@ -280,7 +287,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
:param node_data: node data |
|
|
|
:return: |
|
|
|
""" |
|
|
|
variable_mapping = { |
|
|
|
variable_mapping: dict[str, Sequence[str]] = { |
|
|
|
f"{node_id}.input_selector": node_data.iterator_selector, |
|
|
|
} |
|
|
|
|
|
|
|
@@ -308,7 +315,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( |
|
|
|
graph_config=graph_config, config=sub_node_config |
|
|
|
) |
|
|
|
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) |
|
|
|
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) |
|
|
|
except NotImplementedError: |
|
|
|
sub_node_variable_mapping = {} |
|
|
|
|
|
|
|
@@ -329,8 +336,12 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
return variable_mapping |
|
|
|
|
|
|
|
def _handle_event_metadata( |
|
|
|
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str |
|
|
|
) -> NodeRunStartedEvent | BaseNodeEvent: |
|
|
|
self, |
|
|
|
*, |
|
|
|
event: BaseNodeEvent | InNodeEvent, |
|
|
|
iter_run_index: int, |
|
|
|
parallel_mode_run_id: str | None, |
|
|
|
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent: |
|
|
|
""" |
|
|
|
add iteration metadata to event. |
|
|
|
""" |
|
|
|
@@ -355,6 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
|
|
|
|
def _run_single_iter( |
|
|
|
self, |
|
|
|
*, |
|
|
|
iterator_list_value: list[str], |
|
|
|
variable_pool: VariablePool, |
|
|
|
inputs: dict[str, list], |
|
|
|
@@ -373,12 +385,12 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
try: |
|
|
|
rst = graph_engine.run() |
|
|
|
# get current iteration index |
|
|
|
current_index = variable_pool.get([self.node_id, "index"]).value |
|
|
|
index_variable = variable_pool.get([self.node_id, "index"]) |
|
|
|
if not isinstance(index_variable, IntegerVariable): |
|
|
|
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") |
|
|
|
current_index = index_variable.value |
|
|
|
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" |
|
|
|
next_index = int(current_index) + 1 |
|
|
|
|
|
|
|
if current_index is None: |
|
|
|
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") |
|
|
|
for event in rst: |
|
|
|
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: |
|
|
|
event.in_iteration_id = self.node_id |
|
|
|
@@ -391,7 +403,9 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
continue |
|
|
|
|
|
|
|
if isinstance(event, NodeRunSucceededEvent): |
|
|
|
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) |
|
|
|
yield self._handle_event_metadata( |
|
|
|
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id |
|
|
|
) |
|
|
|
elif isinstance(event, BaseGraphEvent): |
|
|
|
if isinstance(event, GraphRunFailedEvent): |
|
|
|
# iteration run failed |
|
|
|
@@ -404,7 +418,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
parallel_mode_run_id=parallel_mode_run_id, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
outputs={"output": jsonable_encoder(outputs)}, |
|
|
|
outputs={"output": outputs}, |
|
|
|
steps=len(iterator_list_value), |
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, |
|
|
|
error=event.error, |
|
|
|
@@ -417,7 +431,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
iteration_node_data=self.node_data, |
|
|
|
start_at=start_at, |
|
|
|
inputs=inputs, |
|
|
|
outputs={"output": jsonable_encoder(outputs)}, |
|
|
|
outputs={"output": outputs}, |
|
|
|
steps=len(iterator_list_value), |
|
|
|
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, |
|
|
|
error=event.error, |
|
|
|
@@ -429,9 +443,11 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
) |
|
|
|
) |
|
|
|
return |
|
|
|
else: |
|
|
|
event = cast(InNodeEvent, event) |
|
|
|
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) |
|
|
|
elif isinstance(event, InNodeEvent): |
|
|
|
# event = cast(InNodeEvent, event) |
|
|
|
metadata_event = self._handle_event_metadata( |
|
|
|
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id |
|
|
|
) |
|
|
|
if isinstance(event, NodeRunFailedEvent): |
|
|
|
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: |
|
|
|
yield NodeInIterationFailedEvent( |
|
|
|
@@ -513,7 +529,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
iteration_node_data=self.node_data, |
|
|
|
index=next_index, |
|
|
|
parallel_mode_run_id=parallel_mode_run_id, |
|
|
|
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, |
|
|
|
pre_iteration_output=current_iteration_output or None, |
|
|
|
duration=duration, |
|
|
|
) |
|
|
|
|
|
|
|
@@ -551,7 +567,7 @@ class IterationNode(BaseNode[IterationNodeData]): |
|
|
|
index: int, |
|
|
|
item: Any, |
|
|
|
iter_run_map: dict[str, float], |
|
|
|
) -> Generator[NodeEvent | InNodeEvent, None, None]: |
|
|
|
): |
|
|
|
""" |
|
|
|
run single iteration in parallel mode |
|
|
|
""" |