Signed-off-by: -LAN- <laipz8200@outlook.com>tags/2.0.0-beta.1
| event_collector: "EventCollector", | event_collector: "EventCollector", | ||||
| branch_handler: "BranchHandler", | branch_handler: "BranchHandler", | ||||
| edge_processor: "EdgeProcessor", | edge_processor: "EdgeProcessor", | ||||
| node_state_manager: "UnifiedStateManager", | |||||
| execution_tracker: "UnifiedStateManager", | |||||
| state_manager: "UnifiedStateManager", | |||||
| error_handler: "ErrorHandler", | error_handler: "ErrorHandler", | ||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| event_collector: Event collector for collecting events | event_collector: Event collector for collecting events | ||||
| branch_handler: Branch handler for branch node processing | branch_handler: Branch handler for branch node processing | ||||
| edge_processor: Edge processor for edge traversal | edge_processor: Edge processor for edge traversal | ||||
| node_state_manager: Node state manager | |||||
| execution_tracker: Execution tracker | |||||
| state_manager: Unified state manager | |||||
| error_handler: Error handler | error_handler: Error handler | ||||
| """ | """ | ||||
| self._graph = graph | self._graph = graph | ||||
| self._event_collector = event_collector | self._event_collector = event_collector | ||||
| self._branch_handler = branch_handler | self._branch_handler = branch_handler | ||||
| self._edge_processor = edge_processor | self._edge_processor = edge_processor | ||||
| self._node_state_manager = node_state_manager | |||||
| self._execution_tracker = execution_tracker | |||||
| self._state_manager = state_manager | |||||
| self._error_handler = error_handler | self._error_handler = error_handler | ||||
| def handle_event(self, event: GraphNodeEventBase) -> None: | def handle_event(self, event: GraphNodeEventBase) -> None: | ||||
| # Enqueue ready nodes | # Enqueue ready nodes | ||||
| for node_id in ready_nodes: | for node_id in ready_nodes: | ||||
| self._node_state_manager.enqueue_node(node_id) | |||||
| self._execution_tracker.add(node_id) | |||||
| self._state_manager.enqueue_node(node_id) | |||||
| self._state_manager.start_execution(node_id) | |||||
| # Update execution tracking | # Update execution tracking | ||||
| self._execution_tracker.remove(event.node_id) | |||||
| self._state_manager.finish_execution(event.node_id) | |||||
| # Handle response node outputs | # Handle response node outputs | ||||
| if node.execution_type == NodeExecutionType.RESPONSE: | if node.execution_type == NodeExecutionType.RESPONSE: | ||||
| # Abort execution | # Abort execution | ||||
| self._graph_execution.fail(RuntimeError(event.error)) | self._graph_execution.fail(RuntimeError(event.error)) | ||||
| self._event_collector.collect(event) | self._event_collector.collect(event) | ||||
| self._execution_tracker.remove(event.node_id) | |||||
| self._state_manager.finish_execution(event.node_id) | |||||
| def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: | def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: | ||||
| """ | """ | 
| self.node_readiness_checker = NodeReadinessChecker(self.graph) | self.node_readiness_checker = NodeReadinessChecker(self.graph) | ||||
| self.edge_processor = EdgeProcessor( | self.edge_processor = EdgeProcessor( | ||||
| graph=self.graph, | graph=self.graph, | ||||
| edge_state_manager=self.state_manager, | |||||
| node_state_manager=self.state_manager, | |||||
| state_manager=self.state_manager, | |||||
| response_coordinator=self.response_coordinator, | response_coordinator=self.response_coordinator, | ||||
| ) | ) | ||||
| self.skip_propagator = SkipPropagator( | self.skip_propagator = SkipPropagator( | ||||
| graph=self.graph, | graph=self.graph, | ||||
| edge_state_manager=self.state_manager, | |||||
| node_state_manager=self.state_manager, | |||||
| state_manager=self.state_manager, | |||||
| ) | ) | ||||
| self.branch_handler = BranchHandler( | self.branch_handler = BranchHandler( | ||||
| graph=self.graph, | graph=self.graph, | ||||
| edge_processor=self.edge_processor, | edge_processor=self.edge_processor, | ||||
| skip_propagator=self.skip_propagator, | skip_propagator=self.skip_propagator, | ||||
| edge_state_manager=self.state_manager, | |||||
| state_manager=self.state_manager, | |||||
| ) | ) | ||||
| # Event handler registry with all dependencies | # Event handler registry with all dependencies | ||||
| event_collector=self.event_collector, | event_collector=self.event_collector, | ||||
| branch_handler=self.branch_handler, | branch_handler=self.branch_handler, | ||||
| edge_processor=self.edge_processor, | edge_processor=self.edge_processor, | ||||
| node_state_manager=self.state_manager, | |||||
| execution_tracker=self.state_manager, | |||||
| state_manager=self.state_manager, | |||||
| error_handler=self.error_handler, | error_handler=self.error_handler, | ||||
| ) | ) | ||||
| # Orchestration | # Orchestration | ||||
| self.execution_coordinator = ExecutionCoordinator( | self.execution_coordinator = ExecutionCoordinator( | ||||
| graph_execution=self.graph_execution, | graph_execution=self.graph_execution, | ||||
| node_state_manager=self.state_manager, | |||||
| execution_tracker=self.state_manager, | |||||
| state_manager=self.state_manager, | |||||
| event_handler=self.event_handler_registry, | event_handler=self.event_handler_registry, | ||||
| event_collector=self.event_collector, | event_collector=self.event_collector, | ||||
| command_processor=self.command_processor, | command_processor=self.command_processor, | ||||
| # Enqueue root node | # Enqueue root node | ||||
| root_node = self.graph.root_node | root_node = self.graph.root_node | ||||
| self.state_manager.enqueue_node(root_node.id) | self.state_manager.enqueue_node(root_node.id) | ||||
| self.state_manager.add(root_node.id) | |||||
| self.state_manager.start_execution(root_node.id) | |||||
| # Start dispatcher | # Start dispatcher | ||||
| self.dispatcher.start() | self.dispatcher.start() | 
| graph: Graph, | graph: Graph, | ||||
| edge_processor: EdgeProcessor, | edge_processor: EdgeProcessor, | ||||
| skip_propagator: SkipPropagator, | skip_propagator: SkipPropagator, | ||||
| edge_state_manager: UnifiedStateManager, | |||||
| state_manager: UnifiedStateManager, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Initialize the branch handler. | Initialize the branch handler. | ||||
| graph: The workflow graph | graph: The workflow graph | ||||
| edge_processor: Processor for edges | edge_processor: Processor for edges | ||||
| skip_propagator: Propagator for skip states | skip_propagator: Propagator for skip states | ||||
| edge_state_manager: Manager for edge states | |||||
| state_manager: Unified state manager | |||||
| """ | """ | ||||
| self.graph = graph | self.graph = graph | ||||
| self.edge_processor = edge_processor | self.edge_processor = edge_processor | ||||
| self.skip_propagator = skip_propagator | self.skip_propagator = skip_propagator | ||||
| self.edge_state_manager = edge_state_manager | |||||
| self.state_manager = state_manager | |||||
| def handle_branch_completion( | def handle_branch_completion( | ||||
| self, node_id: str, selected_handle: str | None | self, node_id: str, selected_handle: str | None | ||||
| raise ValueError(f"Branch node {node_id} completed without selecting a branch") | raise ValueError(f"Branch node {node_id} completed without selecting a branch") | ||||
| # Categorize edges into selected and unselected | # Categorize edges into selected and unselected | ||||
| _, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||||
| _, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle) | |||||
| # Skip all unselected paths | # Skip all unselected paths | ||||
| self.skip_propagator.skip_branch_paths(unselected_edges) | self.skip_propagator.skip_branch_paths(unselected_edges) | 
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| graph: Graph, | graph: Graph, | ||||
| edge_state_manager: UnifiedStateManager, | |||||
| node_state_manager: UnifiedStateManager, | |||||
| state_manager: UnifiedStateManager, | |||||
| response_coordinator: ResponseStreamCoordinator, | response_coordinator: ResponseStreamCoordinator, | ||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Args: | Args: | ||||
| graph: The workflow graph | graph: The workflow graph | ||||
| edge_state_manager: Manager for edge states | |||||
| node_state_manager: Manager for node states | |||||
| state_manager: Unified state manager | |||||
| response_coordinator: Response stream coordinator | response_coordinator: Response stream coordinator | ||||
| """ | """ | ||||
| self.graph = graph | self.graph = graph | ||||
| self.edge_state_manager = edge_state_manager | |||||
| self.node_state_manager = node_state_manager | |||||
| self.state_manager = state_manager | |||||
| self.response_coordinator = response_coordinator | self.response_coordinator = response_coordinator | ||||
| def process_node_success( | def process_node_success( | ||||
| all_streaming_events: list[NodeRunStreamChunkEvent] = [] | all_streaming_events: list[NodeRunStreamChunkEvent] = [] | ||||
| # Categorize edges | # Categorize edges | ||||
| selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||||
| selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle) | |||||
| # Process unselected edges first (mark as skipped) | # Process unselected edges first (mark as skipped) | ||||
| for edge in unselected_edges: | for edge in unselected_edges: | ||||
| Tuple of (list containing downstream node ID if it's ready, list of streaming events) | Tuple of (list containing downstream node ID if it's ready, list of streaming events) | ||||
| """ | """ | ||||
| # Mark edge as taken | # Mark edge as taken | ||||
| self.edge_state_manager.mark_edge_taken(edge.id) | |||||
| self.state_manager.mark_edge_taken(edge.id) | |||||
| # Notify response coordinator and get streaming events | # Notify response coordinator and get streaming events | ||||
| streaming_events = self.response_coordinator.on_edge_taken(edge.id) | streaming_events = self.response_coordinator.on_edge_taken(edge.id) | ||||
| # Check if downstream node is ready | # Check if downstream node is ready | ||||
| ready_nodes: list[str] = [] | ready_nodes: list[str] = [] | ||||
| if self.node_state_manager.is_node_ready(edge.head): | |||||
| if self.state_manager.is_node_ready(edge.head): | |||||
| ready_nodes.append(edge.head) | ready_nodes.append(edge.head) | ||||
| return ready_nodes, streaming_events | return ready_nodes, streaming_events | ||||
| Args: | Args: | ||||
| edge: The edge to skip | edge: The edge to skip | ||||
| """ | """ | ||||
| self.edge_state_manager.mark_edge_skipped(edge.id) | |||||
| self.state_manager.mark_edge_skipped(edge.id) | 
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| graph: Graph, | graph: Graph, | ||||
| edge_state_manager: UnifiedStateManager, | |||||
| node_state_manager: UnifiedStateManager, | |||||
| state_manager: UnifiedStateManager, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Initialize the skip propagator. | Initialize the skip propagator. | ||||
| Args: | Args: | ||||
| graph: The workflow graph | graph: The workflow graph | ||||
| edge_state_manager: Manager for edge states | |||||
| node_state_manager: Manager for node states | |||||
| state_manager: Unified state manager | |||||
| """ | """ | ||||
| self.graph = graph | self.graph = graph | ||||
| self.edge_state_manager = edge_state_manager | |||||
| self.node_state_manager = node_state_manager | |||||
| self.state_manager = state_manager | |||||
| def propagate_skip_from_edge(self, edge_id: str) -> None: | def propagate_skip_from_edge(self, edge_id: str) -> None: | ||||
| """ | """ | ||||
| incoming_edges = self.graph.get_incoming_edges(downstream_node_id) | incoming_edges = self.graph.get_incoming_edges(downstream_node_id) | ||||
| # Analyze edge states | # Analyze edge states | ||||
| edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges) | |||||
| edge_states = self.state_manager.analyze_edge_states(incoming_edges) | |||||
| # Stop if there are unknown edges (not yet processed) | # Stop if there are unknown edges (not yet processed) | ||||
| if edge_states["has_unknown"]: | if edge_states["has_unknown"]: | ||||
| # If any edge is taken, node may still execute | # If any edge is taken, node may still execute | ||||
| if edge_states["has_taken"]: | if edge_states["has_taken"]: | ||||
| # Enqueue node | # Enqueue node | ||||
| self.node_state_manager.enqueue_node(downstream_node_id) | |||||
| self.state_manager.enqueue_node(downstream_node_id) | |||||
| return | return | ||||
| # All edges are skipped, propagate skip to this node | # All edges are skipped, propagate skip to this node | ||||
| node_id: The ID of the node to skip | node_id: The ID of the node to skip | ||||
| """ | """ | ||||
| # Mark node as skipped | # Mark node as skipped | ||||
| self.node_state_manager.mark_node_skipped(node_id) | |||||
| self.state_manager.mark_node_skipped(node_id) | |||||
| # Mark all outgoing edges as skipped and propagate | # Mark all outgoing edges as skipped and propagate | ||||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | outgoing_edges = self.graph.get_outgoing_edges(node_id) | ||||
| for edge in outgoing_edges: | for edge in outgoing_edges: | ||||
| self.edge_state_manager.mark_edge_skipped(edge.id) | |||||
| self.state_manager.mark_edge_skipped(edge.id) | |||||
| # Recursively propagate skip | # Recursively propagate skip | ||||
| self.propagate_skip_from_edge(edge.id) | self.propagate_skip_from_edge(edge.id) | ||||
| unselected_edges: List of edges not taken by the branch | unselected_edges: List of edges not taken by the branch | ||||
| """ | """ | ||||
| for edge in unselected_edges: | for edge in unselected_edges: | ||||
| self.edge_state_manager.mark_edge_skipped(edge.id) | |||||
| self.state_manager.mark_edge_skipped(edge.id) | |||||
| self.propagate_skip_from_edge(edge.id) | self.propagate_skip_from_edge(edge.id) | 
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| graph_execution: GraphExecution, | graph_execution: GraphExecution, | ||||
| node_state_manager: UnifiedStateManager, | |||||
| execution_tracker: UnifiedStateManager, | |||||
| state_manager: UnifiedStateManager, | |||||
| event_handler: "EventHandlerRegistry", | event_handler: "EventHandlerRegistry", | ||||
| event_collector: EventCollector, | event_collector: EventCollector, | ||||
| command_processor: CommandProcessor, | command_processor: CommandProcessor, | ||||
| Args: | Args: | ||||
| graph_execution: Graph execution aggregate | graph_execution: Graph execution aggregate | ||||
| node_state_manager: Manager for node states | |||||
| execution_tracker: Tracker for executing nodes | |||||
| state_manager: Unified state manager | |||||
| event_handler: Event handler registry for processing events | event_handler: Event handler registry for processing events | ||||
| event_collector: Event collector for collecting events | event_collector: Event collector for collecting events | ||||
| command_processor: Processor for commands | command_processor: Processor for commands | ||||
| worker_pool: Pool of workers | worker_pool: Pool of workers | ||||
| """ | """ | ||||
| self.graph_execution = graph_execution | self.graph_execution = graph_execution | ||||
| self.node_state_manager = node_state_manager | |||||
| self.execution_tracker = execution_tracker | |||||
| self.state_manager = state_manager | |||||
| self.event_handler = event_handler | self.event_handler = event_handler | ||||
| self.event_collector = event_collector | self.event_collector = event_collector | ||||
| self.command_processor = command_processor | self.command_processor = command_processor | ||||
| def check_scaling(self) -> None: | def check_scaling(self) -> None: | ||||
| """Check and perform worker scaling if needed.""" | """Check and perform worker scaling if needed.""" | ||||
| queue_depth = self.node_state_manager.ready_queue.qsize() | |||||
| executing_count = self.execution_tracker.count() | |||||
| queue_depth = self.state_manager.ready_queue.qsize() | |||||
| executing_count = self.state_manager.get_executing_count() | |||||
| self.worker_pool.check_scaling(queue_depth, executing_count) | self.worker_pool.check_scaling(queue_depth, executing_count) | ||||
| def is_execution_complete(self) -> bool: | def is_execution_complete(self) -> bool: | ||||
| return True | return True | ||||
| # Complete if no work remains | # Complete if no work remains | ||||
| return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty() | |||||
| return self.state_manager.is_execution_complete() | |||||
| def mark_complete(self) -> None: | def mark_complete(self) -> None: | ||||
| """Mark execution as complete.""" | """Mark execution as complete.""" | 
| "skipped_nodes": skipped_nodes, | "skipped_nodes": skipped_nodes, | ||||
| "unknown_nodes": unknown_nodes, | "unknown_nodes": unknown_nodes, | ||||
| } | } | ||||
| # ============= Backward Compatibility Methods ============= | |||||
| # These methods provide compatibility with existing code | |||||
| @property | |||||
| def execution_tracker(self) -> "UnifiedStateManager": | |||||
| """Compatibility property for ExecutionTracker access.""" | |||||
| return self | |||||
| @property | |||||
| def node_state_manager(self) -> "UnifiedStateManager": | |||||
| """Compatibility property for NodeStateManager access.""" | |||||
| return self | |||||
| @property | |||||
| def edge_state_manager(self) -> "UnifiedStateManager": | |||||
| """Compatibility property for EdgeStateManager access.""" | |||||
| return self | |||||
| # ExecutionTracker compatibility methods | |||||
| def add(self, node_id: str) -> None: | |||||
| """Compatibility method for ExecutionTracker.add().""" | |||||
| self.start_execution(node_id) | |||||
| def remove(self, node_id: str) -> None: | |||||
| """Compatibility method for ExecutionTracker.remove().""" | |||||
| self.finish_execution(node_id) | |||||
| def is_empty(self) -> bool: | |||||
| """Compatibility method for ExecutionTracker.is_empty().""" | |||||
| return len(self._executing_nodes) == 0 | |||||
| def count(self) -> int: | |||||
| """Compatibility method for ExecutionTracker.count().""" | |||||
| return self.get_executing_count() | |||||
| def clear(self) -> None: | |||||
| """Compatibility method for ExecutionTracker.clear().""" | |||||
| self.clear_executing() | 
| "min_workers": self.min_workers, | "min_workers": self.min_workers, | ||||
| "max_workers": self.max_workers, | "max_workers": self.max_workers, | ||||
| } | } | ||||
| # ============= Backward Compatibility ============= | |||||
| def scale_up(self) -> None: | |||||
| """Compatibility method for manual scale up.""" | |||||
| with self._lock: | |||||
| if self._running and len(self.workers) < self.max_workers: | |||||
| self._add_worker() | |||||
| def scale_down(self, worker_ids: list[int]) -> None: | |||||
| """Compatibility method for manual scale down.""" | |||||
| with self._lock: | |||||
| if not self._running: | |||||
| return | |||||
| for worker_id in worker_ids: | |||||
| if len(self.workers) > self.min_workers: | |||||
| self._remove_worker(worker_id) | |||||
| def check_scaling(self, queue_depth: int, executing_count: int) -> None: | |||||
| """ | |||||
| Compatibility method for checking scaling. | |||||
| Args: | |||||
| queue_depth: Current queue depth (ignored, we check directly) | |||||
| executing_count: Number of executing nodes (ignored) | |||||
| """ | |||||
| self.check_and_scale() |