| import contextvars | |||||
| import logging | import logging | ||||
| from collections.abc import Generator, Mapping, Sequence | from collections.abc import Generator, Mapping, Sequence | ||||
| from concurrent.futures import Future, ThreadPoolExecutor, as_completed | from concurrent.futures import Future, ThreadPoolExecutor, as_completed | ||||
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from typing import TYPE_CHECKING, Any, NewType, cast | from typing import TYPE_CHECKING, Any, NewType, cast | ||||
| from flask import Flask, current_app | |||||
| from typing_extensions import TypeIs | from typing_extensions import TypeIs | ||||
| from core.variables import IntegerVariable, NoneSegment | from core.variables import IntegerVariable, NoneSegment | ||||
| from core.workflow.nodes.base.node import Node | from core.workflow.nodes.base.node import Node | ||||
| from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData | ||||
| from libs.datetime_utils import naive_utc_now | from libs.datetime_utils import naive_utc_now | ||||
| from libs.flask_utils import preserve_flask_contexts | |||||
| from .exc import ( | from .exc import ( | ||||
| InvalidIteratorValueError, | InvalidIteratorValueError, | ||||
| self._execute_single_iteration_parallel, | self._execute_single_iteration_parallel, | ||||
| index=index, | index=index, | ||||
| item=item, | item=item, | ||||
| flask_app=current_app._get_current_object(), # type: ignore | |||||
| context_vars=contextvars.copy_context(), | |||||
| ) | ) | ||||
| future_to_index[future] = index | future_to_index[future] = index | ||||
| self, | self, | ||||
| index: int, | index: int, | ||||
| item: object, | item: object, | ||||
| flask_app: Flask, | |||||
| context_vars: contextvars.Context, | |||||
| ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: | ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: | ||||
| """Execute a single iteration in parallel mode and return results.""" | """Execute a single iteration in parallel mode and return results.""" | ||||
| iter_start_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| events: list[GraphNodeEventBase] = [] | |||||
| outputs_temp: list[object] = [] | |||||
| with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): | |||||
| iter_start_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| events: list[GraphNodeEventBase] = [] | |||||
| outputs_temp: list[object] = [] | |||||
| graph_engine = self._create_graph_engine(index, item) | |||||
| graph_engine = self._create_graph_engine(index, item) | |||||
| # Collect events instead of yielding them directly | |||||
| for event in self._run_single_iter( | |||||
| variable_pool=graph_engine.graph_runtime_state.variable_pool, | |||||
| outputs=outputs_temp, | |||||
| graph_engine=graph_engine, | |||||
| ): | |||||
| events.append(event) | |||||
| # Collect events instead of yielding them directly | |||||
| for event in self._run_single_iter( | |||||
| variable_pool=graph_engine.graph_runtime_state.variable_pool, | |||||
| outputs=outputs_temp, | |||||
| graph_engine=graph_engine, | |||||
| ): | |||||
| events.append(event) | |||||
| # Get the output value from the temporary outputs list | |||||
| output_value = outputs_temp[0] if outputs_temp else None | |||||
| # Get the output value from the temporary outputs list | |||||
| output_value = outputs_temp[0] if outputs_temp else None | |||||
| return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens | |||||
| return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens | |||||
| def _handle_iteration_success( | def _handle_iteration_success( | ||||
| self, | self, |