| @@ -1,9 +1,11 @@ | |||
| import contextvars | |||
| import logging | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from concurrent.futures import Future, ThreadPoolExecutor, as_completed | |||
| from datetime import UTC, datetime | |||
| from typing import TYPE_CHECKING, Any, NewType, cast | |||
| from flask import Flask, current_app | |||
| from typing_extensions import TypeIs | |||
| from core.variables import IntegerVariable, NoneSegment | |||
| @@ -35,6 +37,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig | |||
| from core.workflow.nodes.base.node import Node | |||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.flask_utils import preserve_flask_contexts | |||
| from .exc import ( | |||
| InvalidIteratorValueError, | |||
| @@ -239,6 +242,8 @@ class IterationNode(Node): | |||
| self._execute_single_iteration_parallel, | |||
| index=index, | |||
| item=item, | |||
| flask_app=current_app._get_current_object(), # type: ignore | |||
| context_vars=contextvars.copy_context(), | |||
| ) | |||
| future_to_index[future] = index | |||
| @@ -281,26 +286,29 @@ class IterationNode(Node): | |||
| self, | |||
| index: int, | |||
| item: object, | |||
| flask_app: Flask, | |||
| context_vars: contextvars.Context, | |||
| ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: | |||
| """Execute a single iteration in parallel mode and return results.""" | |||
| iter_start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| events: list[GraphNodeEventBase] = [] | |||
| outputs_temp: list[object] = [] | |||
| with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): | |||
| iter_start_at = datetime.now(UTC).replace(tzinfo=None) | |||
| events: list[GraphNodeEventBase] = [] | |||
| outputs_temp: list[object] = [] | |||
| graph_engine = self._create_graph_engine(index, item) | |||
| graph_engine = self._create_graph_engine(index, item) | |||
| # Collect events instead of yielding them directly | |||
| for event in self._run_single_iter( | |||
| variable_pool=graph_engine.graph_runtime_state.variable_pool, | |||
| outputs=outputs_temp, | |||
| graph_engine=graph_engine, | |||
| ): | |||
| events.append(event) | |||
| # Collect events instead of yielding them directly | |||
| for event in self._run_single_iter( | |||
| variable_pool=graph_engine.graph_runtime_state.variable_pool, | |||
| outputs=outputs_temp, | |||
| graph_engine=graph_engine, | |||
| ): | |||
| events.append(event) | |||
| # Get the output value from the temporary outputs list | |||
| output_value = outputs_temp[0] if outputs_temp else None | |||
| # Get the output value from the temporary outputs list | |||
| output_value = outputs_temp[0] if outputs_temp else None | |||
| return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens | |||
| return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens | |||
| def _handle_iteration_success( | |||
| self, | |||