Signed-off-by: -LAN- <laipz8200@outlook.com>tags/2.0.0-beta.1
| @@ -39,8 +39,8 @@ class CommandProcessor: | |||
| command_channel: Channel for receiving commands | |||
| graph_execution: Graph execution aggregate | |||
| """ | |||
| self.command_channel = command_channel | |||
| self.graph_execution = graph_execution | |||
| self._command_channel = command_channel | |||
| self._graph_execution = graph_execution | |||
| self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {} | |||
| def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None: | |||
| @@ -56,7 +56,7 @@ class CommandProcessor: | |||
| def process_commands(self) -> None: | |||
| """Check for and process any pending commands.""" | |||
| try: | |||
| commands = self.command_channel.fetch_commands() | |||
| commands = self._command_channel.fetch_commands() | |||
| for command in commands: | |||
| self._handle_command(command) | |||
| except Exception as e: | |||
| @@ -72,7 +72,7 @@ class CommandProcessor: | |||
| handler = self._handlers.get(type(command)) | |||
| if handler: | |||
| try: | |||
| handler.handle(command, self.graph_execution) | |||
| handler.handle(command, self._graph_execution) | |||
| except Exception: | |||
| logger.exception("Error handling command %s", command.__class__.__name__) | |||
| else: | |||
| @@ -32,6 +32,8 @@ class AbortStrategy: | |||
| Returns: | |||
| None - signals abortion | |||
| """ | |||
| _ = graph | |||
| _ = retry_count | |||
| logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) | |||
| # Return None to signal that execution should stop | |||
| @@ -31,6 +31,7 @@ class DefaultValueStrategy: | |||
| Returns: | |||
| NodeRunExceptionEvent with default values | |||
| """ | |||
| _ = retry_count | |||
| node = graph.nodes[event.node_id] | |||
| outputs = { | |||
| @@ -31,6 +31,8 @@ class FailBranchStrategy: | |||
| Returns: | |||
| NodeRunExceptionEvent to continue via fail branch | |||
| """ | |||
| _ = graph | |||
| _ = retry_count | |||
| outputs = { | |||
| "error_message": event.node_run_result.error, | |||
| "error_type": event.node_run_result.error_type, | |||
| @@ -23,7 +23,7 @@ class ReadWriteLock: | |||
| def acquire_read(self) -> None: | |||
| """Acquire a read lock.""" | |||
| self._read_ready.acquire() | |||
| _ = self._read_ready.acquire() | |||
| try: | |||
| self._readers += 1 | |||
| finally: | |||
| @@ -31,7 +31,7 @@ class ReadWriteLock: | |||
| def release_read(self) -> None: | |||
| """Release a read lock.""" | |||
| self._read_ready.acquire() | |||
| _ = self._read_ready.acquire() | |||
| try: | |||
| self._readers -= 1 | |||
| if self._readers == 0: | |||
| @@ -41,9 +41,9 @@ class ReadWriteLock: | |||
| def acquire_write(self) -> None: | |||
| """Acquire a write lock.""" | |||
| self._read_ready.acquire() | |||
| _ = self._read_ready.acquire() | |||
| while self._readers > 0: | |||
| self._read_ready.wait() | |||
| _ = self._read_ready.wait() | |||
| def release_write(self) -> None: | |||
| """Release a write lock.""" | |||
| @@ -28,7 +28,7 @@ class EventEmitter: | |||
| Args: | |||
| event_collector: The collector to emit events from | |||
| """ | |||
| self.event_collector = event_collector | |||
| self._event_collector = event_collector | |||
| self._execution_complete = threading.Event() | |||
| def mark_complete(self) -> None: | |||
| @@ -44,9 +44,9 @@ class EventEmitter: | |||
| """ | |||
| yielded_count = 0 | |||
| while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count(): | |||
| while not self._execution_complete.is_set() or yielded_count < self._event_collector.event_count(): | |||
| # Get new events since last yield | |||
| new_events = self.event_collector.get_new_events(yielded_count) | |||
| new_events = self._event_collector.get_new_events(yielded_count) | |||
| # Yield any new events | |||
| for event in new_events: | |||
| @@ -75,7 +75,7 @@ class GraphEngine: | |||
| """Initialize the graph engine with separated concerns.""" | |||
| # Create domain models | |||
| self.execution_context = ExecutionContext( | |||
| self._execution_context = ExecutionContext( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| workflow_id=workflow_id, | |||
| @@ -87,13 +87,13 @@ class GraphEngine: | |||
| max_execution_time=max_execution_time, | |||
| ) | |||
| self.graph_execution = GraphExecution(workflow_id=workflow_id) | |||
| self._graph_execution = GraphExecution(workflow_id=workflow_id) | |||
| # Store core dependencies | |||
| self.graph = graph | |||
| self.graph_config = graph_config | |||
| self.graph_runtime_state = graph_runtime_state | |||
| self.command_channel = command_channel | |||
| self._graph = graph | |||
| self._graph_config = graph_config | |||
| self._graph_runtime_state = graph_runtime_state | |||
| self._command_channel = command_channel | |||
| # Store worker management parameters | |||
| self._min_workers = min_workers | |||
| @@ -102,8 +102,8 @@ class GraphEngine: | |||
| self._scale_down_idle_time = scale_down_idle_time | |||
| # Initialize queues | |||
| self.ready_queue: queue.Queue[str] = queue.Queue() | |||
| self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() | |||
| self._ready_queue: queue.Queue[str] = queue.Queue() | |||
| self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() | |||
| # Initialize subsystems | |||
| self._initialize_subsystems() | |||
| @@ -118,55 +118,55 @@ class GraphEngine: | |||
| """Initialize all subsystems with proper dependency injection.""" | |||
| # Unified state management - single instance handles all state operations | |||
| self.state_manager = UnifiedStateManager(self.graph, self.ready_queue) | |||
| self._state_manager = UnifiedStateManager(self._graph, self._ready_queue) | |||
| # Response coordination | |||
| self.response_coordinator = ResponseStreamCoordinator( | |||
| variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph | |||
| self._response_coordinator = ResponseStreamCoordinator( | |||
| variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph | |||
| ) | |||
| # Event management | |||
| self.event_collector = EventCollector() | |||
| self.event_emitter = EventEmitter(self.event_collector) | |||
| self._event_collector = EventCollector() | |||
| self._event_emitter = EventEmitter(self._event_collector) | |||
| # Error handling | |||
| self.error_handler = ErrorHandler(self.graph, self.graph_execution) | |||
| self._error_handler = ErrorHandler(self._graph, self._graph_execution) | |||
| # Graph traversal | |||
| self.node_readiness_checker = NodeReadinessChecker(self.graph) | |||
| self.edge_processor = EdgeProcessor( | |||
| graph=self.graph, | |||
| state_manager=self.state_manager, | |||
| response_coordinator=self.response_coordinator, | |||
| self._node_readiness_checker = NodeReadinessChecker(self._graph) | |||
| self._edge_processor = EdgeProcessor( | |||
| graph=self._graph, | |||
| state_manager=self._state_manager, | |||
| response_coordinator=self._response_coordinator, | |||
| ) | |||
| self.skip_propagator = SkipPropagator( | |||
| graph=self.graph, | |||
| state_manager=self.state_manager, | |||
| self._skip_propagator = SkipPropagator( | |||
| graph=self._graph, | |||
| state_manager=self._state_manager, | |||
| ) | |||
| self.branch_handler = BranchHandler( | |||
| graph=self.graph, | |||
| edge_processor=self.edge_processor, | |||
| skip_propagator=self.skip_propagator, | |||
| state_manager=self.state_manager, | |||
| self._branch_handler = BranchHandler( | |||
| graph=self._graph, | |||
| edge_processor=self._edge_processor, | |||
| skip_propagator=self._skip_propagator, | |||
| state_manager=self._state_manager, | |||
| ) | |||
| # Event handler registry with all dependencies | |||
| self.event_handler_registry = EventHandlerRegistry( | |||
| graph=self.graph, | |||
| graph_runtime_state=self.graph_runtime_state, | |||
| graph_execution=self.graph_execution, | |||
| response_coordinator=self.response_coordinator, | |||
| event_collector=self.event_collector, | |||
| branch_handler=self.branch_handler, | |||
| edge_processor=self.edge_processor, | |||
| state_manager=self.state_manager, | |||
| error_handler=self.error_handler, | |||
| self._event_handler_registry = EventHandlerRegistry( | |||
| graph=self._graph, | |||
| graph_runtime_state=self._graph_runtime_state, | |||
| graph_execution=self._graph_execution, | |||
| response_coordinator=self._response_coordinator, | |||
| event_collector=self._event_collector, | |||
| branch_handler=self._branch_handler, | |||
| edge_processor=self._edge_processor, | |||
| state_manager=self._state_manager, | |||
| error_handler=self._error_handler, | |||
| ) | |||
| # Command processing | |||
| self.command_processor = CommandProcessor( | |||
| command_channel=self.command_channel, | |||
| graph_execution=self.graph_execution, | |||
| self._command_processor = CommandProcessor( | |||
| command_channel=self._command_channel, | |||
| graph_execution=self._graph_execution, | |||
| ) | |||
| self._setup_command_handlers() | |||
| @@ -174,29 +174,29 @@ class GraphEngine: | |||
| self._setup_worker_management() | |||
| # Orchestration | |||
| self.execution_coordinator = ExecutionCoordinator( | |||
| graph_execution=self.graph_execution, | |||
| state_manager=self.state_manager, | |||
| event_handler=self.event_handler_registry, | |||
| event_collector=self.event_collector, | |||
| command_processor=self.command_processor, | |||
| self._execution_coordinator = ExecutionCoordinator( | |||
| graph_execution=self._graph_execution, | |||
| state_manager=self._state_manager, | |||
| event_handler=self._event_handler_registry, | |||
| event_collector=self._event_collector, | |||
| command_processor=self._command_processor, | |||
| worker_pool=self._worker_pool, | |||
| ) | |||
| self.dispatcher = Dispatcher( | |||
| event_queue=self.event_queue, | |||
| event_handler=self.event_handler_registry, | |||
| event_collector=self.event_collector, | |||
| execution_coordinator=self.execution_coordinator, | |||
| max_execution_time=self.execution_context.max_execution_time, | |||
| event_emitter=self.event_emitter, | |||
| self._dispatcher = Dispatcher( | |||
| event_queue=self._event_queue, | |||
| event_handler=self._event_handler_registry, | |||
| event_collector=self._event_collector, | |||
| execution_coordinator=self._execution_coordinator, | |||
| max_execution_time=self._execution_context.max_execution_time, | |||
| event_emitter=self._event_emitter, | |||
| ) | |||
| def _setup_command_handlers(self) -> None: | |||
| """Configure command handlers.""" | |||
| # Create handler instance that follows the protocol | |||
| abort_handler = AbortCommandHandler() | |||
| self.command_processor.register_handler( | |||
| self._command_processor.register_handler( | |||
| AbortCommand, | |||
| abort_handler, | |||
| ) | |||
| @@ -216,9 +216,9 @@ class GraphEngine: | |||
| # Create simple worker pool | |||
| self._worker_pool = SimpleWorkerPool( | |||
| ready_queue=self.ready_queue, | |||
| event_queue=self.event_queue, | |||
| graph=self.graph, | |||
| ready_queue=self._ready_queue, | |||
| event_queue=self._event_queue, | |||
| graph=self._graph, | |||
| flask_app=flask_app, | |||
| context_vars=context_vars, | |||
| min_workers=self._min_workers, | |||
| @@ -229,8 +229,8 @@ class GraphEngine: | |||
| def _validate_graph_state_consistency(self) -> None: | |||
| """Validate that all nodes share the same GraphRuntimeState.""" | |||
| expected_state_id = id(self.graph_runtime_state) | |||
| for node in self.graph.nodes.values(): | |||
| expected_state_id = id(self._graph_runtime_state) | |||
| for node in self._graph.nodes.values(): | |||
| if id(node.graph_runtime_state) != expected_state_id: | |||
| raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") | |||
| @@ -251,7 +251,7 @@ class GraphEngine: | |||
| self._initialize_layers() | |||
| # Start execution | |||
| self.graph_execution.start() | |||
| self._graph_execution.start() | |||
| start_event = GraphRunStartedEvent() | |||
| yield start_event | |||
| @@ -259,23 +259,23 @@ class GraphEngine: | |||
| self._start_execution() | |||
| # Yield events as they occur | |||
| yield from self.event_emitter.emit_events() | |||
| yield from self._event_emitter.emit_events() | |||
| # Handle completion | |||
| if self.graph_execution.aborted: | |||
| if self._graph_execution.aborted: | |||
| abort_reason = "Workflow execution aborted by user command" | |||
| if self.graph_execution.error: | |||
| abort_reason = str(self.graph_execution.error) | |||
| if self._graph_execution.error: | |||
| abort_reason = str(self._graph_execution.error) | |||
| yield GraphRunAbortedEvent( | |||
| reason=abort_reason, | |||
| outputs=self.graph_runtime_state.outputs, | |||
| outputs=self._graph_runtime_state.outputs, | |||
| ) | |||
| elif self.graph_execution.has_error: | |||
| if self.graph_execution.error: | |||
| raise self.graph_execution.error | |||
| elif self._graph_execution.has_error: | |||
| if self._graph_execution.error: | |||
| raise self._graph_execution.error | |||
| else: | |||
| yield GraphRunSucceededEvent( | |||
| outputs=self.graph_runtime_state.outputs, | |||
| outputs=self._graph_runtime_state.outputs, | |||
| ) | |||
| except Exception as e: | |||
| @@ -287,10 +287,10 @@ class GraphEngine: | |||
| def _initialize_layers(self) -> None: | |||
| """Initialize layers with context.""" | |||
| self.event_collector.set_layers(self._layers) | |||
| self._event_collector.set_layers(self._layers) | |||
| for layer in self._layers: | |||
| try: | |||
| layer.initialize(self.graph_runtime_state, self.command_channel) | |||
| layer.initialize(self._graph_runtime_state, self._command_channel) | |||
| except Exception as e: | |||
| logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) | |||
| @@ -305,21 +305,21 @@ class GraphEngine: | |||
| self._worker_pool.start() | |||
| # Register response nodes | |||
| for node in self.graph.nodes.values(): | |||
| for node in self._graph.nodes.values(): | |||
| if node.execution_type == NodeExecutionType.RESPONSE: | |||
| self.response_coordinator.register(node.id) | |||
| self._response_coordinator.register(node.id) | |||
| # Enqueue root node | |||
| root_node = self.graph.root_node | |||
| self.state_manager.enqueue_node(root_node.id) | |||
| self.state_manager.start_execution(root_node.id) | |||
| root_node = self._graph.root_node | |||
| self._state_manager.enqueue_node(root_node.id) | |||
| self._state_manager.start_execution(root_node.id) | |||
| # Start dispatcher | |||
| self.dispatcher.start() | |||
| self._dispatcher.start() | |||
| def _stop_execution(self) -> None: | |||
| """Stop execution subsystems.""" | |||
| self.dispatcher.stop() | |||
| self._dispatcher.stop() | |||
| self._worker_pool.stop() | |||
| # Don't mark complete here as the dispatcher already does it | |||
| @@ -328,6 +328,17 @@ class GraphEngine: | |||
| for layer in self._layers: | |||
| try: | |||
| layer.on_graph_end(self.graph_execution.error) | |||
| layer.on_graph_end(self._graph_execution.error) | |||
| except Exception as e: | |||
| logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) | |||
| # Public property accessors for attributes that need external access | |||
| @property | |||
| def graph_runtime_state(self) -> GraphRuntimeState: | |||
| """Get the graph runtime state.""" | |||
| return self._graph_runtime_state | |||
| @property | |||
| def graph(self) -> Graph: | |||
| """Get the graph.""" | |||
| return self._graph | |||
| @@ -38,10 +38,10 @@ class BranchHandler: | |||
| skip_propagator: Propagator for skip states | |||
| state_manager: Unified state manager | |||
| """ | |||
| self.graph = graph | |||
| self.edge_processor = edge_processor | |||
| self.skip_propagator = skip_propagator | |||
| self.state_manager = state_manager | |||
| self._graph = graph | |||
| self._edge_processor = edge_processor | |||
| self._skip_propagator = skip_propagator | |||
| self._state_manager = state_manager | |||
| def handle_branch_completion( | |||
| self, node_id: str, selected_handle: str | None | |||
| @@ -63,13 +63,13 @@ class BranchHandler: | |||
| raise ValueError(f"Branch node {node_id} completed without selecting a branch") | |||
| # Categorize edges into selected and unselected | |||
| _, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| # Skip all unselected paths | |||
| self.skip_propagator.skip_branch_paths(unselected_edges) | |||
| self._skip_propagator.skip_branch_paths(unselected_edges) | |||
| # Process selected edges and get ready nodes and streaming events | |||
| return self.edge_processor.process_node_success(node_id, selected_handle) | |||
| return self._edge_processor.process_node_success(node_id, selected_handle) | |||
| def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool: | |||
| """ | |||
| @@ -82,6 +82,6 @@ class BranchHandler: | |||
| Returns: | |||
| True if the selection is valid | |||
| """ | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| outgoing_edges = self._graph.get_outgoing_edges(node_id) | |||
| valid_handles = {edge.source_handle for edge in outgoing_edges} | |||
| return selected_handle in valid_handles | |||
| @@ -36,9 +36,9 @@ class EdgeProcessor: | |||
| state_manager: Unified state manager | |||
| response_coordinator: Response stream coordinator | |||
| """ | |||
| self.graph = graph | |||
| self.state_manager = state_manager | |||
| self.response_coordinator = response_coordinator | |||
| self._graph = graph | |||
| self._state_manager = state_manager | |||
| self._response_coordinator = response_coordinator | |||
| def process_node_success( | |||
| self, node_id: str, selected_handle: str | None = None | |||
| @@ -53,7 +53,7 @@ class EdgeProcessor: | |||
| Returns: | |||
| Tuple of (list of downstream node IDs that are now ready, list of streaming events) | |||
| """ | |||
| node = self.graph.nodes[node_id] | |||
| node = self._graph.nodes[node_id] | |||
| if node.execution_type == NodeExecutionType.BRANCH: | |||
| return self._process_branch_node_edges(node_id, selected_handle) | |||
| @@ -72,7 +72,7 @@ class EdgeProcessor: | |||
| """ | |||
| ready_nodes: list[str] = [] | |||
| all_streaming_events: list[NodeRunStreamChunkEvent] = [] | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| outgoing_edges = self._graph.get_outgoing_edges(node_id) | |||
| for edge in outgoing_edges: | |||
| nodes, events = self._process_taken_edge(edge) | |||
| @@ -104,7 +104,7 @@ class EdgeProcessor: | |||
| all_streaming_events: list[NodeRunStreamChunkEvent] = [] | |||
| # Categorize edges | |||
| selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) | |||
| # Process unselected edges first (mark as skipped) | |||
| for edge in unselected_edges: | |||
| @@ -129,14 +129,14 @@ class EdgeProcessor: | |||
| Tuple of (list containing downstream node ID if it's ready, list of streaming events) | |||
| """ | |||
| # Mark edge as taken | |||
| self.state_manager.mark_edge_taken(edge.id) | |||
| self._state_manager.mark_edge_taken(edge.id) | |||
| # Notify response coordinator and get streaming events | |||
| streaming_events = self.response_coordinator.on_edge_taken(edge.id) | |||
| streaming_events = self._response_coordinator.on_edge_taken(edge.id) | |||
| # Check if downstream node is ready | |||
| ready_nodes: list[str] = [] | |||
| if self.state_manager.is_node_ready(edge.head): | |||
| if self._state_manager.is_node_ready(edge.head): | |||
| ready_nodes.append(edge.head) | |||
| return ready_nodes, streaming_events | |||
| @@ -148,4 +148,4 @@ class EdgeProcessor: | |||
| Args: | |||
| edge: The edge to skip | |||
| """ | |||
| self.state_manager.mark_edge_skipped(edge.id) | |||
| self._state_manager.mark_edge_skipped(edge.id) | |||
| @@ -24,7 +24,7 @@ class NodeReadinessChecker: | |||
| Args: | |||
| graph: The workflow graph | |||
| """ | |||
| self.graph = graph | |||
| self._graph = graph | |||
| def is_node_ready(self, node_id: str) -> bool: | |||
| """ | |||
| @@ -40,7 +40,7 @@ class NodeReadinessChecker: | |||
| Returns: | |||
| True if the node is ready for execution | |||
| """ | |||
| incoming_edges = self.graph.get_incoming_edges(node_id) | |||
| incoming_edges = self._graph.get_incoming_edges(node_id) | |||
| # No dependencies means always ready | |||
| if not incoming_edges: | |||
| @@ -75,7 +75,7 @@ class NodeReadinessChecker: | |||
| List of node IDs that are now ready | |||
| """ | |||
| ready_nodes: list[str] = [] | |||
| outgoing_edges = self.graph.get_outgoing_edges(from_node_id) | |||
| outgoing_edges = self._graph.get_outgoing_edges(from_node_id) | |||
| for edge in outgoing_edges: | |||
| if edge.state == NodeState.TAKEN: | |||
| @@ -31,8 +31,8 @@ class SkipPropagator: | |||
| graph: The workflow graph | |||
| state_manager: Unified state manager | |||
| """ | |||
| self.graph = graph | |||
| self.state_manager = state_manager | |||
| self._graph = graph | |||
| self._state_manager = state_manager | |||
| def propagate_skip_from_edge(self, edge_id: str) -> None: | |||
| """ | |||
| @@ -46,11 +46,11 @@ class SkipPropagator: | |||
| Args: | |||
| edge_id: The ID of the skipped edge to start from | |||
| """ | |||
| downstream_node_id = self.graph.edges[edge_id].head | |||
| incoming_edges = self.graph.get_incoming_edges(downstream_node_id) | |||
| downstream_node_id = self._graph.edges[edge_id].head | |||
| incoming_edges = self._graph.get_incoming_edges(downstream_node_id) | |||
| # Analyze edge states | |||
| edge_states = self.state_manager.analyze_edge_states(incoming_edges) | |||
| edge_states = self._state_manager.analyze_edge_states(incoming_edges) | |||
| # Stop if there are unknown edges (not yet processed) | |||
| if edge_states["has_unknown"]: | |||
| @@ -59,7 +59,7 @@ class SkipPropagator: | |||
| # If any edge is taken, node may still execute | |||
| if edge_states["has_taken"]: | |||
| # Enqueue node | |||
| self.state_manager.enqueue_node(downstream_node_id) | |||
| self._state_manager.enqueue_node(downstream_node_id) | |||
| return | |||
| # All edges are skipped, propagate skip to this node | |||
| @@ -74,12 +74,12 @@ class SkipPropagator: | |||
| node_id: The ID of the node to skip | |||
| """ | |||
| # Mark node as skipped | |||
| self.state_manager.mark_node_skipped(node_id) | |||
| self._state_manager.mark_node_skipped(node_id) | |||
| # Mark all outgoing edges as skipped and propagate | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| outgoing_edges = self._graph.get_outgoing_edges(node_id) | |||
| for edge in outgoing_edges: | |||
| self.state_manager.mark_edge_skipped(edge.id) | |||
| self._state_manager.mark_edge_skipped(edge.id) | |||
| # Recursively propagate skip | |||
| self.propagate_skip_from_edge(edge.id) | |||
| @@ -91,5 +91,5 @@ class SkipPropagator: | |||
| unselected_edges: List of edges not taken by the branch | |||
| """ | |||
| for edge in unselected_edges: | |||
| self.state_manager.mark_edge_skipped(edge.id) | |||
| self._state_manager.mark_edge_skipped(edge.id) | |||
| self.propagate_skip_from_edge(edge.id) | |||
| @@ -48,12 +48,12 @@ class Dispatcher: | |||
| max_execution_time: Maximum execution time in seconds | |||
| event_emitter: Optional event emitter to signal completion | |||
| """ | |||
| self.event_queue = event_queue | |||
| self.event_handler = event_handler | |||
| self.event_collector = event_collector | |||
| self.execution_coordinator = execution_coordinator | |||
| self.max_execution_time = max_execution_time | |||
| self.event_emitter = event_emitter | |||
| self._event_queue = event_queue | |||
| self._event_handler = event_handler | |||
| self._event_collector = event_collector | |||
| self._execution_coordinator = execution_coordinator | |||
| self._max_execution_time = max_execution_time | |||
| self._event_emitter = event_emitter | |||
| self._thread: threading.Thread | None = None | |||
| self._stop_event = threading.Event() | |||
| @@ -80,28 +80,28 @@ class Dispatcher: | |||
| try: | |||
| while not self._stop_event.is_set(): | |||
| # Check for commands | |||
| self.execution_coordinator.check_commands() | |||
| self._execution_coordinator.check_commands() | |||
| # Check for scaling | |||
| self.execution_coordinator.check_scaling() | |||
| self._execution_coordinator.check_scaling() | |||
| # Process events | |||
| try: | |||
| event = self.event_queue.get(timeout=0.1) | |||
| event = self._event_queue.get(timeout=0.1) | |||
| # Route to the event handler | |||
| self.event_handler.handle_event(event) | |||
| self.event_queue.task_done() | |||
| self._event_handler.handle_event(event) | |||
| self._event_queue.task_done() | |||
| except queue.Empty: | |||
| # Check if execution is complete | |||
| if self.execution_coordinator.is_execution_complete(): | |||
| if self._execution_coordinator.is_execution_complete(): | |||
| break | |||
| except Exception as e: | |||
| logger.exception("Dispatcher error") | |||
| self.execution_coordinator.mark_failed(e) | |||
| self._execution_coordinator.mark_failed(e) | |||
| finally: | |||
| self.execution_coordinator.mark_complete() | |||
| self._execution_coordinator.mark_complete() | |||
| # Signal the event emitter that execution is complete | |||
| if self.event_emitter: | |||
| self.event_emitter.mark_complete() | |||
| if self._event_emitter: | |||
| self._event_emitter.mark_complete() | |||
| @@ -43,20 +43,20 @@ class ExecutionCoordinator: | |||
| command_processor: Processor for commands | |||
| worker_pool: Pool of workers | |||
| """ | |||
| self.graph_execution = graph_execution | |||
| self.state_manager = state_manager | |||
| self.event_handler = event_handler | |||
| self.event_collector = event_collector | |||
| self.command_processor = command_processor | |||
| self.worker_pool = worker_pool | |||
| self._graph_execution = graph_execution | |||
| self._state_manager = state_manager | |||
| self._event_handler = event_handler | |||
| self._event_collector = event_collector | |||
| self._command_processor = command_processor | |||
| self._worker_pool = worker_pool | |||
| def check_commands(self) -> None: | |||
| """Process any pending commands.""" | |||
| self.command_processor.process_commands() | |||
| self._command_processor.process_commands() | |||
| def check_scaling(self) -> None: | |||
| """Check and perform worker scaling if needed.""" | |||
| self.worker_pool.check_and_scale() | |||
| self._worker_pool.check_and_scale() | |||
| def is_execution_complete(self) -> bool: | |||
| """ | |||
| @@ -66,16 +66,16 @@ class ExecutionCoordinator: | |||
| True if execution is complete | |||
| """ | |||
| # Check if aborted or failed | |||
| if self.graph_execution.aborted or self.graph_execution.has_error: | |||
| if self._graph_execution.aborted or self._graph_execution.has_error: | |||
| return True | |||
| # Complete if no work remains | |||
| return self.state_manager.is_execution_complete() | |||
| return self._state_manager.is_execution_complete() | |||
| def mark_complete(self) -> None: | |||
| """Mark execution as complete.""" | |||
| if not self.graph_execution.completed: | |||
| self.graph_execution.complete() | |||
| if not self._graph_execution.completed: | |||
| self._graph_execution.complete() | |||
| def mark_failed(self, error: Exception) -> None: | |||
| """ | |||
| @@ -84,4 +84,4 @@ class ExecutionCoordinator: | |||
| Args: | |||
| error: The error that caused failure | |||
| """ | |||
| self.graph_execution.fail(error) | |||
| self._graph_execution.fail(error) | |||
| @@ -44,11 +44,11 @@ class ResponseStreamCoordinator: | |||
| variable_pool: VariablePool instance for accessing node variables | |||
| graph: Graph instance for looking up node information | |||
| """ | |||
| self.variable_pool = variable_pool | |||
| self.graph = graph | |||
| self.active_session: ResponseSession | None = None | |||
| self.waiting_sessions: deque[ResponseSession] = deque() | |||
| self.lock = RLock() | |||
| self._variable_pool = variable_pool | |||
| self._graph = graph | |||
| self._active_session: ResponseSession | None = None | |||
| self._waiting_sessions: deque[ResponseSession] = deque() | |||
| self._lock = RLock() | |||
| # Internal stream management (replacing OutputRegistry) | |||
| self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} | |||
| @@ -68,7 +68,7 @@ class ResponseStreamCoordinator: | |||
| self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session | |||
| def register(self, response_node_id: NodeID) -> None: | |||
| with self.lock: | |||
| with self._lock: | |||
| self._response_nodes.add(response_node_id) | |||
| # Build and save paths map for this response node | |||
| @@ -76,7 +76,7 @@ class ResponseStreamCoordinator: | |||
| self._paths_maps[response_node_id] = paths_map | |||
| # Create and store response session for this node | |||
| response_node = self.graph.nodes[response_node_id] | |||
| response_node = self._graph.nodes[response_node_id] | |||
| session = ResponseSession.from_node(response_node) | |||
| self._response_sessions[response_node_id] = session | |||
| @@ -87,7 +87,7 @@ class ResponseStreamCoordinator: | |||
| node_id: The ID of the node | |||
| execution_id: The execution ID from NodeRunStartedEvent | |||
| """ | |||
| with self.lock: | |||
| with self._lock: | |||
| self._node_execution_ids[node_id] = execution_id | |||
| def _get_or_create_execution_id(self, node_id: NodeID) -> str: | |||
| @@ -99,7 +99,7 @@ class ResponseStreamCoordinator: | |||
| Returns: | |||
| The execution ID for the node | |||
| """ | |||
| with self.lock: | |||
| with self._lock: | |||
| if node_id not in self._node_execution_ids: | |||
| self._node_execution_ids[node_id] = str(uuid4()) | |||
| return self._node_execution_ids[node_id] | |||
| @@ -116,14 +116,14 @@ class ResponseStreamCoordinator: | |||
| List of Path objects, where each path contains branch edge IDs | |||
| """ | |||
| # Get root node ID | |||
| root_node_id = self.graph.root_node.id | |||
| root_node_id = self._graph.root_node.id | |||
| # If root is the response node, return empty path | |||
| if root_node_id == response_node_id: | |||
| return [Path()] | |||
| # Extract variable selectors from the response node's template | |||
| response_node = self.graph.nodes[response_node_id] | |||
| response_node = self._graph.nodes[response_node_id] | |||
| response_session = ResponseSession.from_node(response_node) | |||
| template = response_session.template | |||
| @@ -149,7 +149,7 @@ class ResponseStreamCoordinator: | |||
| visited.add(current_node_id) | |||
| # Explore outgoing edges | |||
| outgoing_edges = self.graph.get_outgoing_edges(current_node_id) | |||
| outgoing_edges = self._graph.get_outgoing_edges(current_node_id) | |||
| for edge in outgoing_edges: | |||
| edge_id = edge.id | |||
| next_node_id = edge.head | |||
| @@ -168,8 +168,8 @@ class ResponseStreamCoordinator: | |||
| for path in all_complete_paths: | |||
| blocking_edges: list[str] = [] | |||
| for edge_id in path: | |||
| edge = self.graph.edges[edge_id] | |||
| source_node = self.graph.nodes[edge.tail] | |||
| edge = self._graph.edges[edge_id] | |||
| source_node = self._graph.nodes[edge.tail] | |||
| # Check if node is a branch/container (original behavior) | |||
| if source_node.execution_type in { | |||
| @@ -199,7 +199,7 @@ class ResponseStreamCoordinator: | |||
| """ | |||
| events: list[NodeRunStreamChunkEvent] = [] | |||
| with self.lock: | |||
| with self._lock: | |||
| # Check each response node in order | |||
| for response_node_id in self._response_nodes: | |||
| if response_node_id not in self._paths_maps: | |||
| @@ -245,21 +245,21 @@ class ResponseStreamCoordinator: | |||
| # Remove from map to ensure it won't be activated again | |||
| del self._response_sessions[node_id] | |||
| if self.active_session is None: | |||
| self.active_session = session | |||
| if self._active_session is None: | |||
| self._active_session = session | |||
| # Try to flush immediately | |||
| events.extend(self.try_flush()) | |||
| else: | |||
| # Queue the session if another is active | |||
| self.waiting_sessions.append(session) | |||
| self._waiting_sessions.append(session) | |||
| return events | |||
| def intercept_event( | |||
| self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent | |||
| ) -> Sequence[NodeRunStreamChunkEvent]: | |||
| with self.lock: | |||
| with self._lock: | |||
| if isinstance(event, NodeRunStreamChunkEvent): | |||
| self._append_stream_chunk(event.selector, event) | |||
| if event.is_final: | |||
| @@ -269,9 +269,8 @@ class ResponseStreamCoordinator: | |||
| # Skip cause we share the same variable pool. | |||
| # | |||
| # for variable_name, variable_value in event.node_run_result.outputs.items(): | |||
| # self.variable_pool.add((event.node_id, variable_name), variable_value) | |||
| # self._variable_pool.add((event.node_id, variable_name), variable_value) | |||
| return self.try_flush() | |||
| return [] | |||
| def _create_stream_chunk_event( | |||
| self, | |||
| @@ -287,9 +286,9 @@ class ResponseStreamCoordinator: | |||
| active response node's information since these are not actual node IDs. | |||
| """ | |||
| # Check if this is a special selector that doesn't correspond to a node | |||
| if selector and selector[0] not in self.graph.nodes and self.active_session: | |||
| if selector and selector[0] not in self._graph.nodes and self._active_session: | |||
| # Use the active response node for special selectors | |||
| response_node = self.graph.nodes[self.active_session.node_id] | |||
| response_node = self._graph.nodes[self._active_session.node_id] | |||
| return NodeRunStreamChunkEvent( | |||
| id=execution_id, | |||
| node_id=response_node.id, | |||
| @@ -300,7 +299,7 @@ class ResponseStreamCoordinator: | |||
| ) | |||
| # Standard case: selector refers to an actual node | |||
| node = self.graph.nodes[node_id] | |||
| node = self._graph.nodes[node_id] | |||
| return NodeRunStreamChunkEvent( | |||
| id=execution_id, | |||
| node_id=node.id, | |||
| @@ -323,9 +322,9 @@ class ResponseStreamCoordinator: | |||
| # Determine which node to attribute the output to | |||
| # For special selectors (sys, env, conversation), use the active response node | |||
| # For regular selectors, use the source node | |||
| if self.active_session and source_selector_prefix not in self.graph.nodes: | |||
| if self._active_session and source_selector_prefix not in self._graph.nodes: | |||
| # Special selector - use active response node | |||
| output_node_id = self.active_session.node_id | |||
| output_node_id = self._active_session.node_id | |||
| else: | |||
| # Regular node selector | |||
| output_node_id = source_selector_prefix | |||
| @@ -336,8 +335,8 @@ class ResponseStreamCoordinator: | |||
| if event := self._pop_stream_chunk(segment.selector): | |||
| # For special selectors, we need to update the event to use | |||
| # the active response node's information | |||
| if self.active_session and source_selector_prefix not in self.graph.nodes: | |||
| response_node = self.graph.nodes[self.active_session.node_id] | |||
| if self._active_session and source_selector_prefix not in self._graph.nodes: | |||
| response_node = self._graph.nodes[self._active_session.node_id] | |||
| # Create a new event with the response node's information | |||
| # but keep the original selector | |||
| updated_event = NodeRunStreamChunkEvent( | |||
| @@ -359,10 +358,10 @@ class ResponseStreamCoordinator: | |||
| if stream_closed: | |||
| is_complete = True | |||
| elif value := self.variable_pool.get(segment.selector): | |||
| elif value := self._variable_pool.get(segment.selector): | |||
| # Process scalar value | |||
| is_last_segment = bool( | |||
| self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1 | |||
| self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 | |||
| ) | |||
| events.append( | |||
| self._create_stream_chunk_event( | |||
| @@ -379,13 +378,13 @@ class ResponseStreamCoordinator: | |||
| def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: | |||
| """Process a text segment. Returns (events, is_complete).""" | |||
| assert self.active_session is not None | |||
| current_response_node = self.graph.nodes[self.active_session.node_id] | |||
| assert self._active_session is not None | |||
| current_response_node = self._graph.nodes[self._active_session.node_id] | |||
| # Use get_or_create_execution_id to ensure we have a consistent ID | |||
| execution_id = self._get_or_create_execution_id(current_response_node.id) | |||
| is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1 | |||
| is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1 | |||
| event = self._create_stream_chunk_event( | |||
| node_id=current_response_node.id, | |||
| execution_id=execution_id, | |||
| @@ -396,29 +395,29 @@ class ResponseStreamCoordinator: | |||
| return [event] | |||
| def try_flush(self) -> list[NodeRunStreamChunkEvent]: | |||
| with self.lock: | |||
| if not self.active_session: | |||
| with self._lock: | |||
| if not self._active_session: | |||
| return [] | |||
| template = self.active_session.template | |||
| response_node_id = self.active_session.node_id | |||
| template = self._active_session.template | |||
| response_node_id = self._active_session.node_id | |||
| events: list[NodeRunStreamChunkEvent] = [] | |||
| # Process segments sequentially from current index | |||
| while self.active_session.index < len(template.segments): | |||
| segment = template.segments[self.active_session.index] | |||
| while self._active_session.index < len(template.segments): | |||
| segment = template.segments[self._active_session.index] | |||
| if isinstance(segment, VariableSegment): | |||
| # Check if the source node for this variable is skipped | |||
| # Only check for actual nodes, not special selectors (sys, env, conversation) | |||
| source_selector_prefix = segment.selector[0] if segment.selector else "" | |||
| if source_selector_prefix in self.graph.nodes: | |||
| source_node = self.graph.nodes[source_selector_prefix] | |||
| if source_selector_prefix in self._graph.nodes: | |||
| source_node = self._graph.nodes[source_selector_prefix] | |||
| if source_node.state == NodeState.SKIPPED: | |||
| # Skip this variable segment if the source node is skipped | |||
| self.active_session.index += 1 | |||
| self._active_session.index += 1 | |||
| continue | |||
| segment_events, is_complete = self._process_variable_segment(segment) | |||
| @@ -426,7 +425,7 @@ class ResponseStreamCoordinator: | |||
| # Only advance index if this variable segment is complete | |||
| if is_complete: | |||
| self.active_session.index += 1 | |||
| self._active_session.index += 1 | |||
| else: | |||
| # Wait for more data | |||
| break | |||
| @@ -434,9 +433,9 @@ class ResponseStreamCoordinator: | |||
| else: | |||
| segment_events = self._process_text_segment(segment) | |||
| events.extend(segment_events) | |||
| self.active_session.index += 1 | |||
| self._active_session.index += 1 | |||
| if self.active_session.is_complete(): | |||
| if self._active_session.is_complete(): | |||
| # End current session and get events from starting next session | |||
| next_session_events = self.end_session(response_node_id) | |||
| events.extend(next_session_events) | |||
| @@ -454,16 +453,16 @@ class ResponseStreamCoordinator: | |||
| Returns: | |||
| List of events from starting the next session | |||
| """ | |||
| with self.lock: | |||
| with self._lock: | |||
| events: list[NodeRunStreamChunkEvent] = [] | |||
| if self.active_session and self.active_session.node_id == node_id: | |||
| self.active_session = None | |||
| if self._active_session and self._active_session.node_id == node_id: | |||
| self._active_session = None | |||
| # Try to start next waiting session | |||
| if self.waiting_sessions: | |||
| next_session = self.waiting_sessions.popleft() | |||
| self.active_session = next_session | |||
| if self._waiting_sessions: | |||
| next_session = self._waiting_sessions.popleft() | |||
| self._active_session = next_session | |||
| # Immediately try to flush any available segments | |||
| events = self.try_flush() | |||
| @@ -46,8 +46,8 @@ class UnifiedStateManager: | |||
| graph: The workflow graph | |||
| ready_queue: Queue for nodes ready to execute | |||
| """ | |||
| self.graph = graph | |||
| self.ready_queue = ready_queue | |||
| self._graph = graph | |||
| self._ready_queue = ready_queue | |||
| self._lock = threading.RLock() | |||
| # Execution tracking state | |||
| @@ -66,8 +66,8 @@ class UnifiedStateManager: | |||
| node_id: The ID of the node to enqueue | |||
| """ | |||
| with self._lock: | |||
| self.graph.nodes[node_id].state = NodeState.TAKEN | |||
| self.ready_queue.put(node_id) | |||
| self._graph.nodes[node_id].state = NodeState.TAKEN | |||
| self._ready_queue.put(node_id) | |||
| def mark_node_skipped(self, node_id: str) -> None: | |||
| """ | |||
| @@ -77,7 +77,7 @@ class UnifiedStateManager: | |||
| node_id: The ID of the node to skip | |||
| """ | |||
| with self._lock: | |||
| self.graph.nodes[node_id].state = NodeState.SKIPPED | |||
| self._graph.nodes[node_id].state = NodeState.SKIPPED | |||
| def is_node_ready(self, node_id: str) -> bool: | |||
| """ | |||
| @@ -94,7 +94,7 @@ class UnifiedStateManager: | |||
| """ | |||
| with self._lock: | |||
| # Get all incoming edges to this node | |||
| incoming_edges = self.graph.get_incoming_edges(node_id) | |||
| incoming_edges = self._graph.get_incoming_edges(node_id) | |||
| # If no incoming edges, node is always ready | |||
| if not incoming_edges: | |||
| @@ -118,7 +118,7 @@ class UnifiedStateManager: | |||
| The current node state | |||
| """ | |||
| with self._lock: | |||
| return self.graph.nodes[node_id].state | |||
| return self._graph.nodes[node_id].state | |||
| # ============= Edge State Operations ============= | |||
| @@ -130,7 +130,7 @@ class UnifiedStateManager: | |||
| edge_id: The ID of the edge to mark | |||
| """ | |||
| with self._lock: | |||
| self.graph.edges[edge_id].state = NodeState.TAKEN | |||
| self._graph.edges[edge_id].state = NodeState.TAKEN | |||
| def mark_edge_skipped(self, edge_id: str) -> None: | |||
| """ | |||
| @@ -140,7 +140,7 @@ class UnifiedStateManager: | |||
| edge_id: The ID of the edge to mark | |||
| """ | |||
| with self._lock: | |||
| self.graph.edges[edge_id].state = NodeState.SKIPPED | |||
| self._graph.edges[edge_id].state = NodeState.SKIPPED | |||
| def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis: | |||
| """ | |||
| @@ -172,7 +172,7 @@ class UnifiedStateManager: | |||
| The current edge state | |||
| """ | |||
| with self._lock: | |||
| return self.graph.edges[edge_id].state | |||
| return self._graph.edges[edge_id].state | |||
| def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: | |||
| """ | |||
| @@ -186,7 +186,7 @@ class UnifiedStateManager: | |||
| A tuple of (selected_edges, unselected_edges) | |||
| """ | |||
| with self._lock: | |||
| outgoing_edges = self.graph.get_outgoing_edges(node_id) | |||
| outgoing_edges = self._graph.get_outgoing_edges(node_id) | |||
| selected_edges: list[Edge] = [] | |||
| unselected_edges: list[Edge] = [] | |||
| @@ -272,7 +272,7 @@ class UnifiedStateManager: | |||
| True if execution is complete | |||
| """ | |||
| with self._lock: | |||
| return self.ready_queue.empty() and len(self._executing_nodes) == 0 | |||
| return self._ready_queue.empty() and len(self._executing_nodes) == 0 | |||
| def get_queue_depth(self) -> int: | |||
| """ | |||
| @@ -281,7 +281,7 @@ class UnifiedStateManager: | |||
| Returns: | |||
| Number of nodes in the ready queue | |||
| """ | |||
| return self.ready_queue.qsize() | |||
| return self._ready_queue.qsize() | |||
| def get_execution_stats(self) -> dict[str, int]: | |||
| """ | |||
| @@ -291,12 +291,12 @@ class UnifiedStateManager: | |||
| Dictionary with execution statistics | |||
| """ | |||
| with self._lock: | |||
| taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN) | |||
| skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED) | |||
| unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN) | |||
| taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN) | |||
| skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED) | |||
| unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN) | |||
| return { | |||
| "queue_depth": self.ready_queue.qsize(), | |||
| "queue_depth": self._ready_queue.qsize(), | |||
| "executing": len(self._executing_nodes), | |||
| "taken_nodes": taken_nodes, | |||
| "skipped_nodes": skipped_nodes, | |||
| @@ -59,16 +59,16 @@ class Worker(threading.Thread): | |||
| on_active_callback: Optional callback when worker becomes active | |||
| """ | |||
| super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) | |||
| self.ready_queue = ready_queue | |||
| self.event_queue = event_queue | |||
| self.graph = graph | |||
| self.worker_id = worker_id | |||
| self.flask_app = flask_app | |||
| self.context_vars = context_vars | |||
| self._ready_queue = ready_queue | |||
| self._event_queue = event_queue | |||
| self._graph = graph | |||
| self._worker_id = worker_id | |||
| self._flask_app = flask_app | |||
| self._context_vars = context_vars | |||
| self._stop_event = threading.Event() | |||
| self.on_idle_callback = on_idle_callback | |||
| self.on_active_callback = on_active_callback | |||
| self.last_task_time = time.time() | |||
| self._on_idle_callback = on_idle_callback | |||
| self._on_active_callback = on_active_callback | |||
| self._last_task_time = time.time() | |||
| def stop(self) -> None: | |||
| """Signal the worker to stop processing.""" | |||
| @@ -85,22 +85,22 @@ class Worker(threading.Thread): | |||
| while not self._stop_event.is_set(): | |||
| # Try to get a node ID from the ready queue (with timeout) | |||
| try: | |||
| node_id = self.ready_queue.get(timeout=0.1) | |||
| node_id = self._ready_queue.get(timeout=0.1) | |||
| except queue.Empty: | |||
| # Notify that worker is idle | |||
| if self.on_idle_callback: | |||
| self.on_idle_callback(self.worker_id) | |||
| if self._on_idle_callback: | |||
| self._on_idle_callback(self._worker_id) | |||
| continue | |||
| # Notify that worker is active | |||
| if self.on_active_callback: | |||
| self.on_active_callback(self.worker_id) | |||
| if self._on_active_callback: | |||
| self._on_active_callback(self._worker_id) | |||
| self.last_task_time = time.time() | |||
| node = self.graph.nodes[node_id] | |||
| self._last_task_time = time.time() | |||
| node = self._graph.nodes[node_id] | |||
| try: | |||
| self._execute_node(node) | |||
| self.ready_queue.task_done() | |||
| self._ready_queue.task_done() | |||
| except Exception as e: | |||
| error_event = NodeRunFailedEvent( | |||
| id=str(uuid4()), | |||
| @@ -110,7 +110,7 @@ class Worker(threading.Thread): | |||
| error=str(e), | |||
| start_at=datetime.now(), | |||
| ) | |||
| self.event_queue.put(error_event) | |||
| self._event_queue.put(error_event) | |||
| def _execute_node(self, node: Node) -> None: | |||
| """ | |||
| @@ -120,19 +120,19 @@ class Worker(threading.Thread): | |||
| node: The node instance to execute | |||
| """ | |||
| # Execute the node with preserved context if Flask app is provided | |||
| if self.flask_app and self.context_vars: | |||
| if self._flask_app and self._context_vars: | |||
| with preserve_flask_contexts( | |||
| flask_app=self.flask_app, | |||
| context_vars=self.context_vars, | |||
| flask_app=self._flask_app, | |||
| context_vars=self._context_vars, | |||
| ): | |||
| # Execute the node | |||
| node_events = node.run() | |||
| for event in node_events: | |||
| # Forward event to dispatcher immediately for streaming | |||
| self.event_queue.put(event) | |||
| self._event_queue.put(event) | |||
| else: | |||
| # Execute without context preservation | |||
| node_events = node.run() | |||
| for event in node_events: | |||
| # Forward event to dispatcher immediately for streaming | |||
| self.event_queue.put(event) | |||
| self._event_queue.put(event) | |||
| @@ -56,20 +56,20 @@ class SimpleWorkerPool: | |||
| scale_up_threshold: Queue depth to trigger scale up | |||
| scale_down_idle_time: Seconds before scaling down idle workers | |||
| """ | |||
| self.ready_queue = ready_queue | |||
| self.event_queue = event_queue | |||
| self.graph = graph | |||
| self.flask_app = flask_app | |||
| self.context_vars = context_vars | |||
| self._ready_queue = ready_queue | |||
| self._event_queue = event_queue | |||
| self._graph = graph | |||
| self._flask_app = flask_app | |||
| self._context_vars = context_vars | |||
| # Scaling parameters with defaults | |||
| self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS | |||
| self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS | |||
| self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD | |||
| self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME | |||
| self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS | |||
| self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS | |||
| self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD | |||
| self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME | |||
| # Worker management | |||
| self.workers: list[Worker] = [] | |||
| self._workers: list[Worker] = [] | |||
| self._worker_counter = 0 | |||
| self._lock = threading.RLock() | |||
| self._running = False | |||
| @@ -89,13 +89,13 @@ class SimpleWorkerPool: | |||
| # Calculate initial worker count | |||
| if initial_count is None: | |||
| node_count = len(self.graph.nodes) | |||
| node_count = len(self._graph.nodes) | |||
| if node_count < 10: | |||
| initial_count = self.min_workers | |||
| initial_count = self._min_workers | |||
| elif node_count < 50: | |||
| initial_count = min(self.min_workers + 1, self.max_workers) | |||
| initial_count = min(self._min_workers + 1, self._max_workers) | |||
| else: | |||
| initial_count = min(self.min_workers + 2, self.max_workers) | |||
| initial_count = min(self._min_workers + 2, self._max_workers) | |||
| # Create initial workers | |||
| for _ in range(initial_count): | |||
| @@ -107,15 +107,15 @@ class SimpleWorkerPool: | |||
| self._running = False | |||
| # Stop all workers | |||
| for worker in self.workers: | |||
| for worker in self._workers: | |||
| worker.stop() | |||
| # Wait for workers to finish | |||
| for worker in self.workers: | |||
| for worker in self._workers: | |||
| if worker.is_alive(): | |||
| worker.join(timeout=10.0) | |||
| self.workers.clear() | |||
| self._workers.clear() | |||
| def _create_worker(self) -> None: | |||
| """Create and start a new worker.""" | |||
| @@ -123,16 +123,16 @@ class SimpleWorkerPool: | |||
| self._worker_counter += 1 | |||
| worker = Worker( | |||
| ready_queue=self.ready_queue, | |||
| event_queue=self.event_queue, | |||
| graph=self.graph, | |||
| ready_queue=self._ready_queue, | |||
| event_queue=self._event_queue, | |||
| graph=self._graph, | |||
| worker_id=worker_id, | |||
| flask_app=self.flask_app, | |||
| context_vars=self.context_vars, | |||
| flask_app=self._flask_app, | |||
| context_vars=self._context_vars, | |||
| ) | |||
| worker.start() | |||
| self.workers.append(worker) | |||
| self._workers.append(worker) | |||
| def check_and_scale(self) -> None: | |||
| """Check and perform scaling if needed.""" | |||
| @@ -140,17 +140,17 @@ class SimpleWorkerPool: | |||
| if not self._running: | |||
| return | |||
| current_count = len(self.workers) | |||
| queue_depth = self.ready_queue.qsize() | |||
| current_count = len(self._workers) | |||
| queue_depth = self._ready_queue.qsize() | |||
| # Simple scaling logic | |||
| if queue_depth > self.scale_up_threshold and current_count < self.max_workers: | |||
| if queue_depth > self._scale_up_threshold and current_count < self._max_workers: | |||
| self._create_worker() | |||
| def get_worker_count(self) -> int: | |||
| """Get current number of workers.""" | |||
| with self._lock: | |||
| return len(self.workers) | |||
| return len(self._workers) | |||
| def get_status(self) -> dict[str, int]: | |||
| """ | |||
| @@ -161,8 +161,8 @@ class SimpleWorkerPool: | |||
| """ | |||
| with self._lock: | |||
| return { | |||
| "total_workers": len(self.workers), | |||
| "queue_depth": self.ready_queue.qsize(), | |||
| "min_workers": self.min_workers, | |||
| "max_workers": self.max_workers, | |||
| "total_workers": len(self._workers), | |||
| "queue_depth": self._ready_queue.qsize(), | |||
| "min_workers": self._min_workers, | |||
| "max_workers": self._max_workers, | |||
| } | |||