Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/2.0.0-beta.1
| @@ -56,8 +56,7 @@ class EventHandlerRegistry: | |||
| event_collector: "EventCollector", | |||
| branch_handler: "BranchHandler", | |||
| edge_processor: "EdgeProcessor", | |||
| node_state_manager: "UnifiedStateManager", | |||
| execution_tracker: "UnifiedStateManager", | |||
| state_manager: "UnifiedStateManager", | |||
| error_handler: "ErrorHandler", | |||
| ) -> None: | |||
| """ | |||
| @@ -71,8 +70,7 @@ class EventHandlerRegistry: | |||
| event_collector: Event collector for collecting events | |||
| branch_handler: Branch handler for branch node processing | |||
| edge_processor: Edge processor for edge traversal | |||
| node_state_manager: Node state manager | |||
| execution_tracker: Execution tracker | |||
| state_manager: Unified state manager | |||
| error_handler: Error handler | |||
| """ | |||
| self._graph = graph | |||
| @@ -82,8 +80,7 @@ class EventHandlerRegistry: | |||
| self._event_collector = event_collector | |||
| self._branch_handler = branch_handler | |||
| self._edge_processor = edge_processor | |||
| self._node_state_manager = node_state_manager | |||
| self._execution_tracker = execution_tracker | |||
| self._state_manager = state_manager | |||
| self._error_handler = error_handler | |||
| def handle_event(self, event: GraphNodeEventBase) -> None: | |||
| @@ -199,11 +196,11 @@ class EventHandlerRegistry: | |||
| # Enqueue ready nodes | |||
| for node_id in ready_nodes: | |||
| self._node_state_manager.enqueue_node(node_id) | |||
| self._execution_tracker.add(node_id) | |||
| self._state_manager.enqueue_node(node_id) | |||
| self._state_manager.start_execution(node_id) | |||
| # Update execution tracking | |||
| self._execution_tracker.remove(event.node_id) | |||
| self._state_manager.finish_execution(event.node_id) | |||
| # Handle response node outputs | |||
| if node.execution_type == NodeExecutionType.RESPONSE: | |||
| @@ -232,7 +229,7 @@ class EventHandlerRegistry: | |||
| # Abort execution | |||
| self._graph_execution.fail(RuntimeError(event.error)) | |||
| self._event_collector.collect(event) | |||
| self._execution_tracker.remove(event.node_id) | |||
| self._state_manager.finish_execution(event.node_id) | |||
| def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: | |||
| """ | |||
| @@ -137,20 +137,18 @@ class GraphEngine: | |||
| self.node_readiness_checker = NodeReadinessChecker(self.graph) | |||
| self.edge_processor = EdgeProcessor( | |||
| graph=self.graph, | |||
| edge_state_manager=self.state_manager, | |||
| node_state_manager=self.state_manager, | |||
| state_manager=self.state_manager, | |||
| response_coordinator=self.response_coordinator, | |||
| ) | |||
| self.skip_propagator = SkipPropagator( | |||
| graph=self.graph, | |||
| edge_state_manager=self.state_manager, | |||
| node_state_manager=self.state_manager, | |||
| state_manager=self.state_manager, | |||
| ) | |||
| self.branch_handler = BranchHandler( | |||
| graph=self.graph, | |||
| edge_processor=self.edge_processor, | |||
| skip_propagator=self.skip_propagator, | |||
| edge_state_manager=self.state_manager, | |||
| state_manager=self.state_manager, | |||
| ) | |||
| # Event handler registry with all dependencies | |||
| @@ -162,8 +160,7 @@ class GraphEngine: | |||
| event_collector=self.event_collector, | |||
| branch_handler=self.branch_handler, | |||
| edge_processor=self.edge_processor, | |||
| node_state_manager=self.state_manager, | |||
| execution_tracker=self.state_manager, | |||
| state_manager=self.state_manager, | |||
| error_handler=self.error_handler, | |||
| ) | |||
| @@ -180,8 +177,7 @@ class GraphEngine: | |||
| # Orchestration | |||
| self.execution_coordinator = ExecutionCoordinator( | |||
| graph_execution=self.graph_execution, | |||
| node_state_manager=self.state_manager, | |||
| execution_tracker=self.state_manager, | |||
| state_manager=self.state_manager, | |||
| event_handler=self.event_handler_registry, | |||
| event_collector=self.event_collector, | |||
| command_processor=self.command_processor, | |||
| @@ -334,7 +330,7 @@ class GraphEngine: | |||
| # Enqueue root node | |||
| root_node = self.graph.root_node | |||
| self.state_manager.enqueue_node(root_node.id) | |||
| self.state_manager.add(root_node.id) | |||
| self.state_manager.start_execution(root_node.id) | |||
| # Start dispatcher | |||
| self.dispatcher.start() | |||
| @@ -27,7 +27,7 @@ class BranchHandler: | |||
| graph: Graph, | |||
| edge_processor: EdgeProcessor, | |||
| skip_propagator: SkipPropagator, | |||
| edge_state_manager: UnifiedStateManager, | |||
| state_manager: UnifiedStateManager, | |||
| ) -> None: | |||
| """ | |||
| Initialize the branch handler. | |||
| @@ -36,12 +36,12 @@ class BranchHandler: | |||
| graph: The workflow graph | |||
| edge_processor: Processor for edges | |||
| skip_propagator: Propagator for skip states | |||
| edge_state_manager: Manager for edge states | |||
| state_manager: Unified state manager | |||
| """ | |||
| self.graph = graph | |||
| self.edge_processor = edge_processor | |||
| self.skip_propagator = skip_propagator | |||
| self.edge_state_manager = edge_state_manager | |||
| self.state_manager = state_manager | |||
| def handle_branch_completion( | |||
| self, node_id: str, selected_handle: str | None | |||
| @@ -63,7 +63,7 @@ class BranchHandler: | |||
| raise ValueError(f"Branch node {node_id} completed without selecting a branch") | |||
| # Categorize edges into selected and unselected | |||
| _, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| _, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| # Skip all unselected paths | |||
| self.skip_propagator.skip_branch_paths(unselected_edges) | |||
| @@ -25,8 +25,7 @@ class EdgeProcessor: | |||
| def __init__( | |||
| self, | |||
| graph: Graph, | |||
| edge_state_manager: UnifiedStateManager, | |||
| node_state_manager: UnifiedStateManager, | |||
| state_manager: UnifiedStateManager, | |||
| response_coordinator: ResponseStreamCoordinator, | |||
| ) -> None: | |||
| """ | |||
| @@ -34,13 +33,11 @@ class EdgeProcessor: | |||
| Args: | |||
| graph: The workflow graph | |||
| edge_state_manager: Manager for edge states | |||
| node_state_manager: Manager for node states | |||
| state_manager: Unified state manager | |||
| response_coordinator: Response stream coordinator | |||
| """ | |||
| self.graph = graph | |||
| self.edge_state_manager = edge_state_manager | |||
| self.node_state_manager = node_state_manager | |||
| self.state_manager = state_manager | |||
| self.response_coordinator = response_coordinator | |||
| def process_node_success( | |||
| @@ -107,7 +104,7 @@ class EdgeProcessor: | |||
| all_streaming_events: list[NodeRunStreamChunkEvent] = [] | |||
| # Categorize edges | |||
| selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| # Process unselected edges first (mark as skipped) | |||
| for edge in unselected_edges: | |||
| @@ -132,14 +129,14 @@ class EdgeProcessor: | |||
| Tuple of (list containing downstream node ID if it's ready, list of streaming events) | |||
| """ | |||
| # Mark edge as taken | |||
| self.edge_state_manager.mark_edge_taken(edge.id) | |||
| self.state_manager.mark_edge_taken(edge.id) | |||
| # Notify response coordinator and get streaming events | |||
| streaming_events = self.response_coordinator.on_edge_taken(edge.id) | |||
| # Check if downstream node is ready | |||
| ready_nodes: list[str] = [] | |||
| if self.node_state_manager.is_node_ready(edge.head): | |||
| if self.state_manager.is_node_ready(edge.head): | |||
| ready_nodes.append(edge.head) | |||
| return ready_nodes, streaming_events | |||
| @@ -151,4 +148,4 @@ class EdgeProcessor: | |||
| Args: | |||
| edge: The edge to skip | |||
| """ | |||
| self.edge_state_manager.mark_edge_skipped(edge.id) | |||
| self.state_manager.mark_edge_skipped(edge.id) | |||
| @@ -22,20 +22,17 @@ class SkipPropagator: | |||
| def __init__( | |||
| self, | |||
| graph: Graph, | |||
| edge_state_manager: UnifiedStateManager, | |||
| node_state_manager: UnifiedStateManager, | |||
| state_manager: UnifiedStateManager, | |||
| ) -> None: | |||
| """ | |||
| Initialize the skip propagator. | |||
| Args: | |||
| graph: The workflow graph | |||
| edge_state_manager: Manager for edge states | |||
| node_state_manager: Manager for node states | |||
| state_manager: Unified state manager | |||
| """ | |||
| self.graph = graph | |||
| self.edge_state_manager = edge_state_manager | |||
| self.node_state_manager = node_state_manager | |||
| self.state_manager = state_manager | |||
| def propagate_skip_from_edge(self, edge_id: str) -> None: | |||
| """ | |||
| @@ -53,7 +50,7 @@ class SkipPropagator: | |||
| incoming_edges = self.graph.get_incoming_edges(downstream_node_id) | |||
| # Analyze edge states | |||
| edge_states = self.edge_state_manager.analyze_edge_states(incoming_edges) | |||
| edge_states = self.state_manager.analyze_edge_states(incoming_edges) | |||
| # Stop if there are unknown edges (not yet processed) | |||
| if edge_states["has_unknown"]: | |||
| @@ -62,7 +59,7 @@ class SkipPropagator: | |||
| # If any edge is taken, node may still execute | |||
| if edge_states["has_taken"]: | |||
| # Enqueue node | |||
| self.node_state_manager.enqueue_node(downstream_node_id) | |||
| self.state_manager.enqueue_node(downstream_node_id) | |||
| return | |||
| # All edges are skipped, propagate skip to this node | |||
| @@ -77,12 +74,12 @@ class SkipPropagator: | |||
| node_id: The ID of the node to skip | |||
| """ | |||
| # Mark node as skipped | |||
| self.node_state_manager.mark_node_skipped(node_id) | |||
| self.state_manager.mark_node_skipped(node_id) | |||
| # Mark all outgoing edges as skipped and propagate | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| for edge in outgoing_edges: | |||
| self.edge_state_manager.mark_edge_skipped(edge.id) | |||
| self.state_manager.mark_edge_skipped(edge.id) | |||
| # Recursively propagate skip | |||
| self.propagate_skip_from_edge(edge.id) | |||
| @@ -94,5 +91,5 @@ class SkipPropagator: | |||
| unselected_edges: List of edges not taken by the branch | |||
| """ | |||
| for edge in unselected_edges: | |||
| self.edge_state_manager.mark_edge_skipped(edge.id) | |||
| self.state_manager.mark_edge_skipped(edge.id) | |||
| self.propagate_skip_from_edge(edge.id) | |||
| @@ -26,8 +26,7 @@ class ExecutionCoordinator: | |||
| def __init__( | |||
| self, | |||
| graph_execution: GraphExecution, | |||
| node_state_manager: UnifiedStateManager, | |||
| execution_tracker: UnifiedStateManager, | |||
| state_manager: UnifiedStateManager, | |||
| event_handler: "EventHandlerRegistry", | |||
| event_collector: EventCollector, | |||
| command_processor: CommandProcessor, | |||
| @@ -38,16 +37,14 @@ class ExecutionCoordinator: | |||
| Args: | |||
| graph_execution: Graph execution aggregate | |||
| node_state_manager: Manager for node states | |||
| execution_tracker: Tracker for executing nodes | |||
| state_manager: Unified state manager | |||
| event_handler: Event handler registry for processing events | |||
| event_collector: Event collector for collecting events | |||
| command_processor: Processor for commands | |||
| worker_pool: Pool of workers | |||
| """ | |||
| self.graph_execution = graph_execution | |||
| self.node_state_manager = node_state_manager | |||
| self.execution_tracker = execution_tracker | |||
| self.state_manager = state_manager | |||
| self.event_handler = event_handler | |||
| self.event_collector = event_collector | |||
| self.command_processor = command_processor | |||
| @@ -59,8 +56,8 @@ class ExecutionCoordinator: | |||
| def check_scaling(self) -> None: | |||
| """Check and perform worker scaling if needed.""" | |||
| queue_depth = self.node_state_manager.ready_queue.qsize() | |||
| executing_count = self.execution_tracker.count() | |||
| queue_depth = self.state_manager.ready_queue.qsize() | |||
| executing_count = self.state_manager.get_executing_count() | |||
| self.worker_pool.check_scaling(queue_depth, executing_count) | |||
| def is_execution_complete(self) -> bool: | |||
| @@ -75,7 +72,7 @@ class ExecutionCoordinator: | |||
| return True | |||
| # Complete if no work remains | |||
| return self.node_state_manager.ready_queue.empty() and self.execution_tracker.is_empty() | |||
| return self.state_manager.is_execution_complete() | |||
| def mark_complete(self) -> None: | |||
| """Mark execution as complete.""" | |||
| @@ -302,42 +302,3 @@ class UnifiedStateManager: | |||
| "skipped_nodes": skipped_nodes, | |||
| "unknown_nodes": unknown_nodes, | |||
| } | |||
| # ============= Backward Compatibility Methods ============= | |||
| # These methods provide compatibility with existing code | |||
| @property | |||
| def execution_tracker(self) -> "UnifiedStateManager": | |||
| """Compatibility property for ExecutionTracker access.""" | |||
| return self | |||
| @property | |||
| def node_state_manager(self) -> "UnifiedStateManager": | |||
| """Compatibility property for NodeStateManager access.""" | |||
| return self | |||
| @property | |||
| def edge_state_manager(self) -> "UnifiedStateManager": | |||
| """Compatibility property for EdgeStateManager access.""" | |||
| return self | |||
| # ExecutionTracker compatibility methods | |||
| def add(self, node_id: str) -> None: | |||
| """Compatibility method for ExecutionTracker.add().""" | |||
| self.start_execution(node_id) | |||
| def remove(self, node_id: str) -> None: | |||
| """Compatibility method for ExecutionTracker.remove().""" | |||
| self.finish_execution(node_id) | |||
| def is_empty(self) -> bool: | |||
| """Compatibility method for ExecutionTracker.is_empty().""" | |||
| return len(self._executing_nodes) == 0 | |||
| def count(self) -> int: | |||
| """Compatibility method for ExecutionTracker.count().""" | |||
| return self.get_executing_count() | |||
| def clear(self) -> None: | |||
| """Compatibility method for ExecutionTracker.clear().""" | |||
| self.clear_executing() | |||
| @@ -330,31 +330,3 @@ class EnhancedWorkerPool: | |||
| "min_workers": self.min_workers, | |||
| "max_workers": self.max_workers, | |||
| } | |||
| # ============= Backward Compatibility ============= | |||
| def scale_up(self) -> None: | |||
| """Compatibility method for manual scale up.""" | |||
| with self._lock: | |||
| if self._running and len(self.workers) < self.max_workers: | |||
| self._add_worker() | |||
| def scale_down(self, worker_ids: list[int]) -> None: | |||
| """Compatibility method for manual scale down.""" | |||
| with self._lock: | |||
| if not self._running: | |||
| return | |||
| for worker_id in worker_ids: | |||
| if len(self.workers) > self.min_workers: | |||
| self._remove_worker(worker_id) | |||
| def check_scaling(self, queue_depth: int, executing_count: int) -> None: | |||
| """ | |||
| Compatibility method for checking scaling. | |||
| Args: | |||
| queue_depth: Current queue depth (ignored, we check directly) | |||
| executing_count: Number of executing nodes (ignored) | |||
| """ | |||
| self.check_and_scale() | |||