
refactor(graph_engine): Correct private attributes and private methods naming

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/2.0.0-beta.1
-LAN- 2 months ago
Parent
Commit
0fdb1b2bc9

api/core/workflow/graph_engine/command_processing/command_processor.py (+4, -4)

@@ -39,8 +39,8 @@ class CommandProcessor:
             command_channel: Channel for receiving commands
             graph_execution: Graph execution aggregate
         """
-        self.command_channel = command_channel
-        self.graph_execution = graph_execution
+        self._command_channel = command_channel
+        self._graph_execution = graph_execution
         self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
 
     def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
@@ -56,7 +56,7 @@ class CommandProcessor:
     def process_commands(self) -> None:
         """Check for and process any pending commands."""
         try:
-            commands = self.command_channel.fetch_commands()
+            commands = self._command_channel.fetch_commands()
            for command in commands:
                self._handle_command(command)
        except Exception as e:
@@ -72,7 +72,7 @@ class CommandProcessor:
         handler = self._handlers.get(type(command))
         if handler:
             try:
-                handler.handle(command, self.graph_execution)
+                handler.handle(command, self._graph_execution)
             except Exception:
                 logger.exception("Error handling command %s", command.__class__.__name__)
         else:
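
For context, the processor dispatches by the command's concrete type via the `_handlers` map. A minimal standalone sketch of that pattern, with the command and handler classes invented here for illustration:

    import logging

    logger = logging.getLogger(__name__)

    class AbortCommand:
        """Stand-in for a GraphEngineCommand subclass."""

    class AbortHandler:
        """Stand-in following the CommandHandler protocol."""
        def handle(self, command: AbortCommand, graph_execution: object) -> None:
            print("abort requested")

    handlers: dict[type, AbortHandler] = {AbortCommand: AbortHandler()}

    def process(command: object) -> None:
        # Look up a handler by the command's concrete type, as _handle_command does.
        handler = handlers.get(type(command))
        if handler:
            handler.handle(command, graph_execution=None)
        else:
            logger.warning("No handler for %s", type(command).__name__)

    process(AbortCommand())  # prints "abort requested"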

api/core/workflow/graph_engine/error_handling/abort_strategy.py (+2, -0)

@@ -32,6 +32,8 @@ class AbortStrategy:
         Returns:
             None - signals abortion
         """
+        _ = graph
+        _ = retry_count
         logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
 
         # Return None to signal that execution should stop
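
The added `_ = graph` and `_ = retry_count` lines are the usual idiom for parameters a strategy accepts for protocol compatibility but deliberately ignores, presumably to keep strict lint/type checks quiet without changing the shared signature. A standalone sketch of the idiom:

    class AbortLikeStrategy:
        # All error strategies share one signature; this one only reads `event`.
        def handle(self, event: str, graph: object, retry_count: int) -> None:
            _ = graph        # intentionally unused, kept for the common interface
            _ = retry_count
            print(f"abort on: {event}")

    AbortLikeStrategy().handle("node failed", graph=None, retry_count=0)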

api/core/workflow/graph_engine/error_handling/default_value_strategy.py (+1, -0)

@@ -31,6 +31,7 @@ class DefaultValueStrategy:
         Returns:
             NodeRunExceptionEvent with default values
         """
+        _ = retry_count
         node = graph.nodes[event.node_id]
 
         outputs = {

api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py (+2, -0)

@@ -31,6 +31,8 @@ class FailBranchStrategy:
         Returns:
             NodeRunExceptionEvent to continue via fail branch
         """
+        _ = graph
+        _ = retry_count
         outputs = {
             "error_message": event.node_run_result.error,
             "error_type": event.node_run_result.error_type,

api/core/workflow/graph_engine/event_management/event_collector.py (+4, -4)

@@ -23,7 +23,7 @@ class ReadWriteLock:
 
     def acquire_read(self) -> None:
         """Acquire a read lock."""
-        self._read_ready.acquire()
+        _ = self._read_ready.acquire()
         try:
             self._readers += 1
         finally:
@@ -31,7 +31,7 @@
 
     def release_read(self) -> None:
         """Release a read lock."""
-        self._read_ready.acquire()
+        _ = self._read_ready.acquire()
         try:
             self._readers -= 1
             if self._readers == 0:
@@ -41,9 +41,9 @@
 
     def acquire_write(self) -> None:
         """Acquire a write lock."""
-        self._read_ready.acquire()
+        _ = self._read_ready.acquire()
         while self._readers > 0:
-            self._read_ready.wait()
+            _ = self._read_ready.wait()
 
     def release_write(self) -> None:
         """Release a write lock."""

api/core/workflow/graph_engine/event_management/event_emitter.py (+3, -3)

@@ -28,7 +28,7 @@ class EventEmitter:
         Args:
             event_collector: The collector to emit events from
         """
-        self.event_collector = event_collector
+        self._event_collector = event_collector
         self._execution_complete = threading.Event()
 
     def mark_complete(self) -> None:
@@ -44,9 +44,9 @@ class EventEmitter:
         """
         yielded_count = 0
 
-        while not self._execution_complete.is_set() or yielded_count < self.event_collector.event_count():
+        while not self._execution_complete.is_set() or yielded_count < self._event_collector.event_count():
             # Get new events since last yield
-            new_events = self.event_collector.get_new_events(yielded_count)
+            new_events = self._event_collector.get_new_events(yielded_count)
 
             # Yield any new events
             for event in new_events:
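
The emit loop above is a poll-and-drain generator: keep yielding until the completion flag is set and everything buffered has gone out. A simplified, list-backed sketch of the same shape (all names here are invented):

    import threading
    import time
    from collections.abc import Iterator

    def emit_events(buffered: list[str], complete: threading.Event) -> Iterator[str]:
        yielded = 0
        while not complete.is_set() or yielded < len(buffered):
            new_events = buffered[yielded:]  # events collected since the last yield
            for event in new_events:
                yield event
            yielded += len(new_events)
            if not new_events:
                time.sleep(0.01)  # brief pause instead of a busy spin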

api/core/workflow/graph_engine/graph_engine.py (+90, -79)

@@ -75,7 +75,7 @@ class GraphEngine:
         """Initialize the graph engine with separated concerns."""
 
         # Create domain models
-        self.execution_context = ExecutionContext(
+        self._execution_context = ExecutionContext(
             tenant_id=tenant_id,
             app_id=app_id,
             workflow_id=workflow_id,
@@ -87,13 +87,13 @@ class GraphEngine:
             max_execution_time=max_execution_time,
         )
 
-        self.graph_execution = GraphExecution(workflow_id=workflow_id)
+        self._graph_execution = GraphExecution(workflow_id=workflow_id)
 
         # Store core dependencies
-        self.graph = graph
-        self.graph_config = graph_config
-        self.graph_runtime_state = graph_runtime_state
-        self.command_channel = command_channel
+        self._graph = graph
+        self._graph_config = graph_config
+        self._graph_runtime_state = graph_runtime_state
+        self._command_channel = command_channel
 
         # Store worker management parameters
         self._min_workers = min_workers
@@ -102,8 +102,8 @@
         self._scale_down_idle_time = scale_down_idle_time
 
         # Initialize queues
-        self.ready_queue: queue.Queue[str] = queue.Queue()
-        self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
+        self._ready_queue: queue.Queue[str] = queue.Queue()
+        self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
 
         # Initialize subsystems
         self._initialize_subsystems()
@@ -118,55 +118,55 @@
         """Initialize all subsystems with proper dependency injection."""
 
         # Unified state management - single instance handles all state operations
-        self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
+        self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
 
         # Response coordination
-        self.response_coordinator = ResponseStreamCoordinator(
-            variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
+        self._response_coordinator = ResponseStreamCoordinator(
+            variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
         )
 
         # Event management
-        self.event_collector = EventCollector()
-        self.event_emitter = EventEmitter(self.event_collector)
+        self._event_collector = EventCollector()
+        self._event_emitter = EventEmitter(self._event_collector)
 
         # Error handling
-        self.error_handler = ErrorHandler(self.graph, self.graph_execution)
+        self._error_handler = ErrorHandler(self._graph, self._graph_execution)
 
         # Graph traversal
-        self.node_readiness_checker = NodeReadinessChecker(self.graph)
-        self.edge_processor = EdgeProcessor(
-            graph=self.graph,
-            state_manager=self.state_manager,
-            response_coordinator=self.response_coordinator,
+        self._node_readiness_checker = NodeReadinessChecker(self._graph)
+        self._edge_processor = EdgeProcessor(
+            graph=self._graph,
+            state_manager=self._state_manager,
+            response_coordinator=self._response_coordinator,
         )
-        self.skip_propagator = SkipPropagator(
-            graph=self.graph,
-            state_manager=self.state_manager,
+        self._skip_propagator = SkipPropagator(
+            graph=self._graph,
+            state_manager=self._state_manager,
         )
-        self.branch_handler = BranchHandler(
-            graph=self.graph,
-            edge_processor=self.edge_processor,
-            skip_propagator=self.skip_propagator,
-            state_manager=self.state_manager,
+        self._branch_handler = BranchHandler(
+            graph=self._graph,
+            edge_processor=self._edge_processor,
+            skip_propagator=self._skip_propagator,
+            state_manager=self._state_manager,
         )
 
         # Event handler registry with all dependencies
-        self.event_handler_registry = EventHandlerRegistry(
-            graph=self.graph,
-            graph_runtime_state=self.graph_runtime_state,
-            graph_execution=self.graph_execution,
-            response_coordinator=self.response_coordinator,
-            event_collector=self.event_collector,
-            branch_handler=self.branch_handler,
-            edge_processor=self.edge_processor,
-            state_manager=self.state_manager,
-            error_handler=self.error_handler,
+        self._event_handler_registry = EventHandlerRegistry(
+            graph=self._graph,
+            graph_runtime_state=self._graph_runtime_state,
+            graph_execution=self._graph_execution,
+            response_coordinator=self._response_coordinator,
+            event_collector=self._event_collector,
+            branch_handler=self._branch_handler,
+            edge_processor=self._edge_processor,
+            state_manager=self._state_manager,
+            error_handler=self._error_handler,
        )
 
         # Command processing
-        self.command_processor = CommandProcessor(
-            command_channel=self.command_channel,
-            graph_execution=self.graph_execution,
+        self._command_processor = CommandProcessor(
+            command_channel=self._command_channel,
+            graph_execution=self._graph_execution,
         )
         self._setup_command_handlers()
@@ -174,29 +174,29 @@
         self._setup_worker_management()
 
         # Orchestration
-        self.execution_coordinator = ExecutionCoordinator(
-            graph_execution=self.graph_execution,
-            state_manager=self.state_manager,
-            event_handler=self.event_handler_registry,
-            event_collector=self.event_collector,
-            command_processor=self.command_processor,
+        self._execution_coordinator = ExecutionCoordinator(
+            graph_execution=self._graph_execution,
+            state_manager=self._state_manager,
+            event_handler=self._event_handler_registry,
+            event_collector=self._event_collector,
+            command_processor=self._command_processor,
             worker_pool=self._worker_pool,
         )
 
-        self.dispatcher = Dispatcher(
-            event_queue=self.event_queue,
-            event_handler=self.event_handler_registry,
-            event_collector=self.event_collector,
-            execution_coordinator=self.execution_coordinator,
-            max_execution_time=self.execution_context.max_execution_time,
-            event_emitter=self.event_emitter,
+        self._dispatcher = Dispatcher(
+            event_queue=self._event_queue,
+            event_handler=self._event_handler_registry,
+            event_collector=self._event_collector,
+            execution_coordinator=self._execution_coordinator,
+            max_execution_time=self._execution_context.max_execution_time,
+            event_emitter=self._event_emitter,
         )
 
     def _setup_command_handlers(self) -> None:
         """Configure command handlers."""
         # Create handler instance that follows the protocol
         abort_handler = AbortCommandHandler()
-        self.command_processor.register_handler(
+        self._command_processor.register_handler(
             AbortCommand,
             abort_handler,
         )
@@ -216,9 +216,9 @@
 
         # Create simple worker pool
         self._worker_pool = SimpleWorkerPool(
-            ready_queue=self.ready_queue,
-            event_queue=self.event_queue,
-            graph=self.graph,
+            ready_queue=self._ready_queue,
+            event_queue=self._event_queue,
+            graph=self._graph,
             flask_app=flask_app,
             context_vars=context_vars,
             min_workers=self._min_workers,
@@ -229,8 +229,8 @@
 
     def _validate_graph_state_consistency(self) -> None:
         """Validate that all nodes share the same GraphRuntimeState."""
-        expected_state_id = id(self.graph_runtime_state)
-        for node in self.graph.nodes.values():
+        expected_state_id = id(self._graph_runtime_state)
+        for node in self._graph.nodes.values():
            if id(node.graph_runtime_state) != expected_state_id:
                raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
@@ -251,7 +251,7 @@
         self._initialize_layers()
 
         # Start execution
-        self.graph_execution.start()
+        self._graph_execution.start()
         start_event = GraphRunStartedEvent()
         yield start_event
@@ -259,23 +259,23 @@
             self._start_execution()
 
             # Yield events as they occur
-            yield from self.event_emitter.emit_events()
+            yield from self._event_emitter.emit_events()
 
             # Handle completion
-            if self.graph_execution.aborted:
+            if self._graph_execution.aborted:
                 abort_reason = "Workflow execution aborted by user command"
-                if self.graph_execution.error:
-                    abort_reason = str(self.graph_execution.error)
+                if self._graph_execution.error:
+                    abort_reason = str(self._graph_execution.error)
                 yield GraphRunAbortedEvent(
                     reason=abort_reason,
-                    outputs=self.graph_runtime_state.outputs,
+                    outputs=self._graph_runtime_state.outputs,
                 )
-            elif self.graph_execution.has_error:
-                if self.graph_execution.error:
-                    raise self.graph_execution.error
+            elif self._graph_execution.has_error:
+                if self._graph_execution.error:
+                    raise self._graph_execution.error
             else:
                 yield GraphRunSucceededEvent(
-                    outputs=self.graph_runtime_state.outputs,
+                    outputs=self._graph_runtime_state.outputs,
                 )
 
         except Exception as e:
@@ -287,10 +287,10 @@
 
     def _initialize_layers(self) -> None:
         """Initialize layers with context."""
-        self.event_collector.set_layers(self._layers)
+        self._event_collector.set_layers(self._layers)
         for layer in self._layers:
             try:
-                layer.initialize(self.graph_runtime_state, self.command_channel)
+                layer.initialize(self._graph_runtime_state, self._command_channel)
             except Exception as e:
                 logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
@@ -305,21 +305,21 @@
         self._worker_pool.start()
 
         # Register response nodes
-        for node in self.graph.nodes.values():
+        for node in self._graph.nodes.values():
             if node.execution_type == NodeExecutionType.RESPONSE:
-                self.response_coordinator.register(node.id)
+                self._response_coordinator.register(node.id)
 
         # Enqueue root node
-        root_node = self.graph.root_node
-        self.state_manager.enqueue_node(root_node.id)
-        self.state_manager.start_execution(root_node.id)
+        root_node = self._graph.root_node
+        self._state_manager.enqueue_node(root_node.id)
+        self._state_manager.start_execution(root_node.id)
 
         # Start dispatcher
-        self.dispatcher.start()
+        self._dispatcher.start()
 
     def _stop_execution(self) -> None:
         """Stop execution subsystems."""
-        self.dispatcher.stop()
+        self._dispatcher.stop()
         self._worker_pool.stop()
         # Don't mark complete here as the dispatcher already does it
@@ -328,6 +328,17 @@
 
         for layer in self._layers:
             try:
-                layer.on_graph_end(self.graph_execution.error)
+                layer.on_graph_end(self._graph_execution.error)
             except Exception as e:
                 logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
+
+    # Public property accessors for attributes that need external access
+    @property
+    def graph_runtime_state(self) -> GraphRuntimeState:
+        """Get the graph runtime state."""
+        return self._graph_runtime_state
+
+    @property
+    def graph(self) -> Graph:
+        """Get the graph."""
+        return self._graph
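
The two @property accessors added at the end are what keeps this rename non-breaking for callers: the attribute goes private, and a read-only property re-exposes it. In miniature, independent of the engine's real types:

    class Engine:
        def __init__(self, graph: object) -> None:
            self._graph = graph  # internal code now reads self._graph

        @property
        def graph(self) -> object:
            """Read-only view for external callers."""
            return self._graph

    engine = Engine(graph="g")
    print(engine.graph)   # still works for consumers
    # engine.graph = "h"  # would raise AttributeError: no setter is defined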

api/core/workflow/graph_engine/graph_traversal/branch_handler.py (+8, -8)

@@ -38,10 +38,10 @@ class BranchHandler:
             skip_propagator: Propagator for skip states
             state_manager: Unified state manager
         """
-        self.graph = graph
-        self.edge_processor = edge_processor
-        self.skip_propagator = skip_propagator
-        self.state_manager = state_manager
+        self._graph = graph
+        self._edge_processor = edge_processor
+        self._skip_propagator = skip_propagator
+        self._state_manager = state_manager
 
     def handle_branch_completion(
         self, node_id: str, selected_handle: str | None
@@ -63,13 +63,13 @@
             raise ValueError(f"Branch node {node_id} completed without selecting a branch")
 
         # Categorize edges into selected and unselected
-        _, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
+        _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
 
         # Skip all unselected paths
-        self.skip_propagator.skip_branch_paths(unselected_edges)
+        self._skip_propagator.skip_branch_paths(unselected_edges)
 
         # Process selected edges and get ready nodes and streaming events
-        return self.edge_processor.process_node_success(node_id, selected_handle)
+        return self._edge_processor.process_node_success(node_id, selected_handle)
 
     def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
         """
@@ -82,6 +82,6 @@
         Returns:
             True if the selection is valid
         """
-        outgoing_edges = self.graph.get_outgoing_edges(node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(node_id)
         valid_handles = {edge.source_handle for edge in outgoing_edges}
         return selected_handle in valid_handles

api/core/workflow/graph_engine/graph_traversal/edge_processor.py (+10, -10)

@@ -36,9 +36,9 @@ class EdgeProcessor:
             state_manager: Unified state manager
             response_coordinator: Response stream coordinator
         """
-        self.graph = graph
-        self.state_manager = state_manager
-        self.response_coordinator = response_coordinator
+        self._graph = graph
+        self._state_manager = state_manager
+        self._response_coordinator = response_coordinator
 
     def process_node_success(
         self, node_id: str, selected_handle: str | None = None
@@ -53,7 +53,7 @@
         Returns:
             Tuple of (list of downstream node IDs that are now ready, list of streaming events)
         """
-        node = self.graph.nodes[node_id]
+        node = self._graph.nodes[node_id]
 
         if node.execution_type == NodeExecutionType.BRANCH:
             return self._process_branch_node_edges(node_id, selected_handle)
@@ -72,7 +72,7 @@
         """
         ready_nodes: list[str] = []
         all_streaming_events: list[NodeRunStreamChunkEvent] = []
-        outgoing_edges = self.graph.get_outgoing_edges(node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(node_id)
 
         for edge in outgoing_edges:
             nodes, events = self._process_taken_edge(edge)
@@ -104,7 +104,7 @@
         all_streaming_events: list[NodeRunStreamChunkEvent] = []
 
         # Categorize edges
-        selected_edges, unselected_edges = self.state_manager.categorize_branch_edges(node_id, selected_handle)
+        selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
 
         # Process unselected edges first (mark as skipped)
         for edge in unselected_edges:
@@ -129,14 +129,14 @@
             Tuple of (list containing downstream node ID if it's ready, list of streaming events)
         """
         # Mark edge as taken
-        self.state_manager.mark_edge_taken(edge.id)
+        self._state_manager.mark_edge_taken(edge.id)
 
         # Notify response coordinator and get streaming events
-        streaming_events = self.response_coordinator.on_edge_taken(edge.id)
+        streaming_events = self._response_coordinator.on_edge_taken(edge.id)
 
         # Check if downstream node is ready
         ready_nodes: list[str] = []
-        if self.state_manager.is_node_ready(edge.head):
+        if self._state_manager.is_node_ready(edge.head):
             ready_nodes.append(edge.head)
 
         return ready_nodes, streaming_events
@@ -148,4 +148,4 @@
         Args:
             edge: The edge to skip
         """
-        self.state_manager.mark_edge_skipped(edge.id)
+        self._state_manager.mark_edge_skipped(edge.id)

api/core/workflow/graph_engine/graph_traversal/node_readiness.py (+3, -3)

@@ -24,7 +24,7 @@ class NodeReadinessChecker:
         Args:
             graph: The workflow graph
         """
-        self.graph = graph
+        self._graph = graph
 
     def is_node_ready(self, node_id: str) -> bool:
         """
@@ -40,7 +40,7 @@
         Returns:
             True if the node is ready for execution
         """
-        incoming_edges = self.graph.get_incoming_edges(node_id)
+        incoming_edges = self._graph.get_incoming_edges(node_id)
 
         # No dependencies means always ready
         if not incoming_edges:
@@ -75,7 +75,7 @@
             List of node IDs that are now ready
         """
         ready_nodes: list[str] = []
-        outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(from_node_id)
 
         for edge in outgoing_edges:
             if edge.state == NodeState.TAKEN:

api/core/workflow/graph_engine/graph_traversal/skip_propagator.py (+10, -10)

@@ -31,8 +31,8 @@ class SkipPropagator:
             graph: The workflow graph
             state_manager: Unified state manager
         """
-        self.graph = graph
-        self.state_manager = state_manager
+        self._graph = graph
+        self._state_manager = state_manager
 
     def propagate_skip_from_edge(self, edge_id: str) -> None:
         """
@@ -46,11 +46,11 @@
         Args:
             edge_id: The ID of the skipped edge to start from
         """
-        downstream_node_id = self.graph.edges[edge_id].head
-        incoming_edges = self.graph.get_incoming_edges(downstream_node_id)
+        downstream_node_id = self._graph.edges[edge_id].head
+        incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
 
         # Analyze edge states
-        edge_states = self.state_manager.analyze_edge_states(incoming_edges)
+        edge_states = self._state_manager.analyze_edge_states(incoming_edges)
 
         # Stop if there are unknown edges (not yet processed)
         if edge_states["has_unknown"]:
@@ -59,7 +59,7 @@
         # If any edge is taken, node may still execute
         if edge_states["has_taken"]:
             # Enqueue node
-            self.state_manager.enqueue_node(downstream_node_id)
+            self._state_manager.enqueue_node(downstream_node_id)
             return
 
         # All edges are skipped, propagate skip to this node
@@ -74,12 +74,12 @@
             node_id: The ID of the node to skip
         """
         # Mark node as skipped
-        self.state_manager.mark_node_skipped(node_id)
+        self._state_manager.mark_node_skipped(node_id)
 
         # Mark all outgoing edges as skipped and propagate
-        outgoing_edges = self.graph.get_outgoing_edges(node_id)
+        outgoing_edges = self._graph.get_outgoing_edges(node_id)
         for edge in outgoing_edges:
-            self.state_manager.mark_edge_skipped(edge.id)
+            self._state_manager.mark_edge_skipped(edge.id)
             # Recursively propagate skip
             self.propagate_skip_from_edge(edge.id)
 
@@ -91,5 +91,5 @@
             unselected_edges: List of edges not taken by the branch
         """
         for edge in unselected_edges:
-            self.state_manager.mark_edge_skipped(edge.id)
+            self._state_manager.mark_edge_skipped(edge.id)
             self.propagate_skip_from_edge(edge.id)

api/core/workflow/graph_engine/orchestration/dispatcher.py (+16, -16)

@@ -48,12 +48,12 @@ class Dispatcher:
             max_execution_time: Maximum execution time in seconds
             event_emitter: Optional event emitter to signal completion
         """
-        self.event_queue = event_queue
-        self.event_handler = event_handler
-        self.event_collector = event_collector
-        self.execution_coordinator = execution_coordinator
-        self.max_execution_time = max_execution_time
-        self.event_emitter = event_emitter
+        self._event_queue = event_queue
+        self._event_handler = event_handler
+        self._event_collector = event_collector
+        self._execution_coordinator = execution_coordinator
+        self._max_execution_time = max_execution_time
+        self._event_emitter = event_emitter
 
         self._thread: threading.Thread | None = None
         self._stop_event = threading.Event()
@@ -80,28 +80,28 @@
        try:
            while not self._stop_event.is_set():
                # Check for commands
-                self.execution_coordinator.check_commands()
+                self._execution_coordinator.check_commands()
 
                # Check for scaling
-                self.execution_coordinator.check_scaling()
+                self._execution_coordinator.check_scaling()
 
                # Process events
                try:
-                    event = self.event_queue.get(timeout=0.1)
+                    event = self._event_queue.get(timeout=0.1)
                    # Route to the event handler
-                    self.event_handler.handle_event(event)
-                    self.event_queue.task_done()
+                    self._event_handler.handle_event(event)
+                    self._event_queue.task_done()
                except queue.Empty:
                    # Check if execution is complete
-                    if self.execution_coordinator.is_execution_complete():
+                    if self._execution_coordinator.is_execution_complete():
                        break
 
        except Exception as e:
            logger.exception("Dispatcher error")
-            self.execution_coordinator.mark_failed(e)
+            self._execution_coordinator.mark_failed(e)
 
        finally:
-            self.execution_coordinator.mark_complete()
+            self._execution_coordinator.mark_complete()
            # Signal the event emitter that execution is complete
-            if self.event_emitter:
-                self.event_emitter.mark_complete()
+            if self._event_emitter:
+                self._event_emitter.mark_complete()
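
The dispatch loop is a single-consumer pattern: a short queue timeout keeps the loop responsive to command and scaling checks, and completion is only tested when the queue is momentarily empty, so no handled event can be lost on exit. A stripped-down sketch (the callables are stand-ins for the coordinator):

    import queue
    import threading
    from collections.abc import Callable

    def dispatch_loop(
        event_queue: "queue.Queue[str]",
        handle_event: Callable[[str], None],
        is_complete: Callable[[], bool],
        stop_event: threading.Event,
    ) -> None:
        while not stop_event.is_set():
            try:
                event = event_queue.get(timeout=0.1)  # short timeout, stays responsive
                handle_event(event)
                event_queue.task_done()
            except queue.Empty:
                # Only safe to exit when no event is waiting to be handled
                if is_complete():
                    break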

api/core/workflow/graph_engine/orchestration/execution_coordinator.py (+13, -13)

@@ -43,20 +43,20 @@ class ExecutionCoordinator:
             command_processor: Processor for commands
             worker_pool: Pool of workers
         """
-        self.graph_execution = graph_execution
-        self.state_manager = state_manager
-        self.event_handler = event_handler
-        self.event_collector = event_collector
-        self.command_processor = command_processor
-        self.worker_pool = worker_pool
+        self._graph_execution = graph_execution
+        self._state_manager = state_manager
+        self._event_handler = event_handler
+        self._event_collector = event_collector
+        self._command_processor = command_processor
+        self._worker_pool = worker_pool
 
     def check_commands(self) -> None:
         """Process any pending commands."""
-        self.command_processor.process_commands()
+        self._command_processor.process_commands()
 
     def check_scaling(self) -> None:
         """Check and perform worker scaling if needed."""
-        self.worker_pool.check_and_scale()
+        self._worker_pool.check_and_scale()
 
     def is_execution_complete(self) -> bool:
         """
@@ -66,16 +66,16 @@
             True if execution is complete
         """
         # Check if aborted or failed
-        if self.graph_execution.aborted or self.graph_execution.has_error:
+        if self._graph_execution.aborted or self._graph_execution.has_error:
             return True
 
         # Complete if no work remains
-        return self.state_manager.is_execution_complete()
+        return self._state_manager.is_execution_complete()
 
     def mark_complete(self) -> None:
         """Mark execution as complete."""
-        if not self.graph_execution.completed:
-            self.graph_execution.complete()
+        if not self._graph_execution.completed:
+            self._graph_execution.complete()
 
     def mark_failed(self, error: Exception) -> None:
         """
@@ -84,4 +84,4 @@
         Args:
             error: The error that caused failure
         """
-        self.graph_execution.fail(error)
+        self._graph_execution.fail(error)

api/core/workflow/graph_engine/response_coordinator/coordinator.py (+50, -51)

@@ -44,11 +44,11 @@ class ResponseStreamCoordinator:
             variable_pool: VariablePool instance for accessing node variables
             graph: Graph instance for looking up node information
         """
-        self.variable_pool = variable_pool
-        self.graph = graph
-        self.active_session: ResponseSession | None = None
-        self.waiting_sessions: deque[ResponseSession] = deque()
-        self.lock = RLock()
+        self._variable_pool = variable_pool
+        self._graph = graph
+        self._active_session: ResponseSession | None = None
+        self._waiting_sessions: deque[ResponseSession] = deque()
+        self._lock = RLock()
 
         # Internal stream management (replacing OutputRegistry)
         self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
@@ -68,7 +68,7 @@
         self._response_sessions: dict[NodeID, ResponseSession] = {}  # node_id -> session
 
     def register(self, response_node_id: NodeID) -> None:
-        with self.lock:
+        with self._lock:
             self._response_nodes.add(response_node_id)
 
             # Build and save paths map for this response node
@@ -76,7 +76,7 @@
             self._paths_maps[response_node_id] = paths_map
 
             # Create and store response session for this node
-            response_node = self.graph.nodes[response_node_id]
+            response_node = self._graph.nodes[response_node_id]
             session = ResponseSession.from_node(response_node)
             self._response_sessions[response_node_id] = session
@@ -87,7 +87,7 @@
             node_id: The ID of the node
             execution_id: The execution ID from NodeRunStartedEvent
         """
-        with self.lock:
+        with self._lock:
             self._node_execution_ids[node_id] = execution_id
 
    def _get_or_create_execution_id(self, node_id: NodeID) -> str:
@@ -99,7 +99,7 @@
        Returns:
            The execution ID for the node
        """
-        with self.lock:
+        with self._lock:
            if node_id not in self._node_execution_ids:
                self._node_execution_ids[node_id] = str(uuid4())
            return self._node_execution_ids[node_id]
@@ -116,14 +116,14 @@
            List of Path objects, where each path contains branch edge IDs
        """
        # Get root node ID
-        root_node_id = self.graph.root_node.id
+        root_node_id = self._graph.root_node.id
 
        # If root is the response node, return empty path
        if root_node_id == response_node_id:
            return [Path()]
 
        # Extract variable selectors from the response node's template
-        response_node = self.graph.nodes[response_node_id]
+        response_node = self._graph.nodes[response_node_id]
        response_session = ResponseSession.from_node(response_node)
        template = response_session.template
@@ -149,7 +149,7 @@
            visited.add(current_node_id)
 
            # Explore outgoing edges
-            outgoing_edges = self.graph.get_outgoing_edges(current_node_id)
+            outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
            for edge in outgoing_edges:
                edge_id = edge.id
                next_node_id = edge.head
@@ -168,8 +168,8 @@
        for path in all_complete_paths:
            blocking_edges: list[str] = []
            for edge_id in path:
-                edge = self.graph.edges[edge_id]
-                source_node = self.graph.nodes[edge.tail]
+                edge = self._graph.edges[edge_id]
+                source_node = self._graph.nodes[edge.tail]
 
                # Check if node is a branch/container (original behavior)
                if source_node.execution_type in {
@@ -199,7 +199,7 @@
        """
        events: list[NodeRunStreamChunkEvent] = []
 
-        with self.lock:
+        with self._lock:
            # Check each response node in order
            for response_node_id in self._response_nodes:
                if response_node_id not in self._paths_maps:
@@ -245,21 +245,21 @@
                # Remove from map to ensure it won't be activated again
                del self._response_sessions[node_id]
 
-                if self.active_session is None:
-                    self.active_session = session
+                if self._active_session is None:
+                    self._active_session = session
 
                    # Try to flush immediately
                    events.extend(self.try_flush())
                else:
                    # Queue the session if another is active
-                    self.waiting_sessions.append(session)
+                    self._waiting_sessions.append(session)
 
        return events
 
    def intercept_event(
        self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
    ) -> Sequence[NodeRunStreamChunkEvent]:
-        with self.lock:
+        with self._lock:
            if isinstance(event, NodeRunStreamChunkEvent):
                self._append_stream_chunk(event.selector, event)
                if event.is_final:
@@ -269,9 +269,8 @@
            # Skip cause we share the same variable pool.
            #
            # for variable_name, variable_value in event.node_run_result.outputs.items():
-            #     self.variable_pool.add((event.node_id, variable_name), variable_value)
+            #     self._variable_pool.add((event.node_id, variable_name), variable_value)
            return self.try_flush()
        return []
 
    def _create_stream_chunk_event(
        self,
@@ -287,9 +286,9 @@
        active response node's information since these are not actual node IDs.
        """
        # Check if this is a special selector that doesn't correspond to a node
-        if selector and selector[0] not in self.graph.nodes and self.active_session:
+        if selector and selector[0] not in self._graph.nodes and self._active_session:
            # Use the active response node for special selectors
-            response_node = self.graph.nodes[self.active_session.node_id]
+            response_node = self._graph.nodes[self._active_session.node_id]
            return NodeRunStreamChunkEvent(
                id=execution_id,
                node_id=response_node.id,
@@ -300,7 +299,7 @@
            )
 
        # Standard case: selector refers to an actual node
-        node = self.graph.nodes[node_id]
+        node = self._graph.nodes[node_id]
        return NodeRunStreamChunkEvent(
            id=execution_id,
            node_id=node.id,
@@ -323,9 +322,9 @@
        # Determine which node to attribute the output to
        # For special selectors (sys, env, conversation), use the active response node
        # For regular selectors, use the source node
-        if self.active_session and source_selector_prefix not in self.graph.nodes:
+        if self._active_session and source_selector_prefix not in self._graph.nodes:
            # Special selector - use active response node
-            output_node_id = self.active_session.node_id
+            output_node_id = self._active_session.node_id
        else:
            # Regular node selector
            output_node_id = source_selector_prefix
@@ -336,8 +335,8 @@
        if event := self._pop_stream_chunk(segment.selector):
            # For special selectors, we need to update the event to use
            # the active response node's information
-            if self.active_session and source_selector_prefix not in self.graph.nodes:
-                response_node = self.graph.nodes[self.active_session.node_id]
+            if self._active_session and source_selector_prefix not in self._graph.nodes:
+                response_node = self._graph.nodes[self._active_session.node_id]
                # Create a new event with the response node's information
                # but keep the original selector
                updated_event = NodeRunStreamChunkEvent(
@@ -359,10 +358,10 @@
            if stream_closed:
                is_complete = True
 
-        elif value := self.variable_pool.get(segment.selector):
+        elif value := self._variable_pool.get(segment.selector):
            # Process scalar value
            is_last_segment = bool(
-                self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
+                self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
            )
            events.append(
                self._create_stream_chunk_event(
@@ -379,13 +378,13 @@
 
    def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
        """Process a text segment. Returns (events, is_complete)."""
-        assert self.active_session is not None
-        current_response_node = self.graph.nodes[self.active_session.node_id]
+        assert self._active_session is not None
+        current_response_node = self._graph.nodes[self._active_session.node_id]
 
        # Use get_or_create_execution_id to ensure we have a consistent ID
        execution_id = self._get_or_create_execution_id(current_response_node.id)
 
-        is_last_segment = self.active_session.index == len(self.active_session.template.segments) - 1
+        is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
        event = self._create_stream_chunk_event(
            node_id=current_response_node.id,
            execution_id=execution_id,
@@ -396,29 +395,29 @@
        return [event]
 
    def try_flush(self) -> list[NodeRunStreamChunkEvent]:
-        with self.lock:
-            if not self.active_session:
+        with self._lock:
+            if not self._active_session:
                return []
 
-            template = self.active_session.template
-            response_node_id = self.active_session.node_id
+            template = self._active_session.template
+            response_node_id = self._active_session.node_id
 
            events: list[NodeRunStreamChunkEvent] = []
 
            # Process segments sequentially from current index
-            while self.active_session.index < len(template.segments):
-                segment = template.segments[self.active_session.index]
+            while self._active_session.index < len(template.segments):
+                segment = template.segments[self._active_session.index]
 
                if isinstance(segment, VariableSegment):
                    # Check if the source node for this variable is skipped
                    # Only check for actual nodes, not special selectors (sys, env, conversation)
                    source_selector_prefix = segment.selector[0] if segment.selector else ""
-                    if source_selector_prefix in self.graph.nodes:
-                        source_node = self.graph.nodes[source_selector_prefix]
+                    if source_selector_prefix in self._graph.nodes:
+                        source_node = self._graph.nodes[source_selector_prefix]
 
                        if source_node.state == NodeState.SKIPPED:
                            # Skip this variable segment if the source node is skipped
-                            self.active_session.index += 1
+                            self._active_session.index += 1
                            continue
 
                    segment_events, is_complete = self._process_variable_segment(segment)
@@ -426,7 +425,7 @@
 
                    # Only advance index if this variable segment is complete
                    if is_complete:
-                        self.active_session.index += 1
+                        self._active_session.index += 1
                    else:
                        # Wait for more data
                        break
@@ -434,9 +433,9 @@
                else:
                    segment_events = self._process_text_segment(segment)
                    events.extend(segment_events)
-                    self.active_session.index += 1
+                    self._active_session.index += 1
 
            if self.active_session.is_complete():
-            if self._active_session.is_complete():
                # End current session and get events from starting next session
                next_session_events = self.end_session(response_node_id)
                events.extend(next_session_events)
@@ -454,16 +453,16 @@
        Returns:
            List of events from starting the next session
        """
-        with self.lock:
+        with self._lock:
            events: list[NodeRunStreamChunkEvent] = []
 
-            if self.active_session and self.active_session.node_id == node_id:
-                self.active_session = None
+            if self._active_session and self._active_session.node_id == node_id:
+                self._active_session = None
 
                # Try to start next waiting session
-                if self.waiting_sessions:
-                    next_session = self.waiting_sessions.popleft()
-                    self.active_session = next_session
+                if self._waiting_sessions:
+                    next_session = self._waiting_sessions.popleft()
+                    self._active_session = next_session
 
                    # Immediately try to flush any available segments
                    events = self.try_flush()
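
At its core the coordinator enforces "one active response session at a time; later arrivals wait in FIFO order, all under one RLock". A minimal sketch of just that handoff, with names invented here:

    import threading
    from collections import deque

    class SessionGate:
        """One active session at a time; later arrivals wait in FIFO order."""

        def __init__(self) -> None:
            self._lock = threading.RLock()
            self._active: str | None = None
            self._waiting: deque[str] = deque()

        def open(self, session_id: str) -> bool:
            # Returns True if the session became active immediately.
            with self._lock:
                if self._active is None:
                    self._active = session_id
                    return True
                self._waiting.append(session_id)
                return False

        def close(self, session_id: str) -> str | None:
            # Ends the active session and promotes the next waiter, if any.
            with self._lock:
                if self._active == session_id:
                    self._active = self._waiting.popleft() if self._waiting else None
                return self._active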

api/core/workflow/graph_engine/state_management/unified_state_manager.py (+17, -17)

@@ -46,8 +46,8 @@ class UnifiedStateManager:
             graph: The workflow graph
             ready_queue: Queue for nodes ready to execute
         """
-        self.graph = graph
-        self.ready_queue = ready_queue
+        self._graph = graph
+        self._ready_queue = ready_queue
         self._lock = threading.RLock()
 
         # Execution tracking state
@@ -66,8 +66,8 @@
             node_id: The ID of the node to enqueue
         """
         with self._lock:
-            self.graph.nodes[node_id].state = NodeState.TAKEN
-            self.ready_queue.put(node_id)
+            self._graph.nodes[node_id].state = NodeState.TAKEN
+            self._ready_queue.put(node_id)
 
     def mark_node_skipped(self, node_id: str) -> None:
         """
@@ -77,7 +77,7 @@
             node_id: The ID of the node to skip
         """
         with self._lock:
-            self.graph.nodes[node_id].state = NodeState.SKIPPED
+            self._graph.nodes[node_id].state = NodeState.SKIPPED
 
     def is_node_ready(self, node_id: str) -> bool:
         """
@@ -94,7 +94,7 @@
         """
         with self._lock:
             # Get all incoming edges to this node
-            incoming_edges = self.graph.get_incoming_edges(node_id)
+            incoming_edges = self._graph.get_incoming_edges(node_id)
 
             # If no incoming edges, node is always ready
             if not incoming_edges:
@@ -118,7 +118,7 @@
             The current node state
         """
         with self._lock:
-            return self.graph.nodes[node_id].state
+            return self._graph.nodes[node_id].state
 
     # ============= Edge State Operations =============
 
@@ -130,7 +130,7 @@
             edge_id: The ID of the edge to mark
         """
         with self._lock:
-            self.graph.edges[edge_id].state = NodeState.TAKEN
+            self._graph.edges[edge_id].state = NodeState.TAKEN
 
     def mark_edge_skipped(self, edge_id: str) -> None:
         """
@@ -140,7 +140,7 @@
             edge_id: The ID of the edge to mark
         """
         with self._lock:
-            self.graph.edges[edge_id].state = NodeState.SKIPPED
+            self._graph.edges[edge_id].state = NodeState.SKIPPED
 
     def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
         """
@@ -172,7 +172,7 @@
             The current edge state
         """
         with self._lock:
-            return self.graph.edges[edge_id].state
+            return self._graph.edges[edge_id].state
 
     def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
         """
@@ -186,7 +186,7 @@
             A tuple of (selected_edges, unselected_edges)
         """
         with self._lock:
-            outgoing_edges = self.graph.get_outgoing_edges(node_id)
+            outgoing_edges = self._graph.get_outgoing_edges(node_id)
             selected_edges: list[Edge] = []
             unselected_edges: list[Edge] = []
 
@@ -272,7 +272,7 @@
             True if execution is complete
         """
         with self._lock:
-            return self.ready_queue.empty() and len(self._executing_nodes) == 0
+            return self._ready_queue.empty() and len(self._executing_nodes) == 0
 
     def get_queue_depth(self) -> int:
         """
@@ -281,7 +281,7 @@
         Returns:
             Number of nodes in the ready queue
         """
-        return self.ready_queue.qsize()
+        return self._ready_queue.qsize()
 
     def get_execution_stats(self) -> dict[str, int]:
         """
@@ -291,12 +291,12 @@
             Dictionary with execution statistics
         """
         with self._lock:
-            taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN)
-            skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED)
-            unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN)
+            taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
+            skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
+            unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
 
             return {
-                "queue_depth": self.ready_queue.qsize(),
+                "queue_depth": self._ready_queue.qsize(),
                 "executing": len(self._executing_nodes),
                 "taken_nodes": taken_nodes,
                 "skipped_nodes": skipped_nodes,

api/core/workflow/graph_engine/worker.py (+23, -23)

@@ -59,16 +59,16 @@ class Worker(threading.Thread):
             on_active_callback: Optional callback when worker becomes active
         """
         super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
-        self.ready_queue = ready_queue
-        self.event_queue = event_queue
-        self.graph = graph
-        self.worker_id = worker_id
-        self.flask_app = flask_app
-        self.context_vars = context_vars
+        self._ready_queue = ready_queue
+        self._event_queue = event_queue
+        self._graph = graph
+        self._worker_id = worker_id
+        self._flask_app = flask_app
+        self._context_vars = context_vars
         self._stop_event = threading.Event()
-        self.on_idle_callback = on_idle_callback
-        self.on_active_callback = on_active_callback
-        self.last_task_time = time.time()
+        self._on_idle_callback = on_idle_callback
+        self._on_active_callback = on_active_callback
+        self._last_task_time = time.time()
 
     def stop(self) -> None:
         """Signal the worker to stop processing."""
@@ -85,22 +85,22 @@
        while not self._stop_event.is_set():
            # Try to get a node ID from the ready queue (with timeout)
            try:
-                node_id = self.ready_queue.get(timeout=0.1)
+                node_id = self._ready_queue.get(timeout=0.1)
            except queue.Empty:
                # Notify that worker is idle
-                if self.on_idle_callback:
-                    self.on_idle_callback(self.worker_id)
+                if self._on_idle_callback:
+                    self._on_idle_callback(self._worker_id)
                continue
 
            # Notify that worker is active
-            if self.on_active_callback:
-                self.on_active_callback(self.worker_id)
+            if self._on_active_callback:
+                self._on_active_callback(self._worker_id)
 
-            self.last_task_time = time.time()
-            node = self.graph.nodes[node_id]
+            self._last_task_time = time.time()
+            node = self._graph.nodes[node_id]
            try:
                self._execute_node(node)
-                self.ready_queue.task_done()
+                self._ready_queue.task_done()
            except Exception as e:
                error_event = NodeRunFailedEvent(
                    id=str(uuid4()),
@@ -110,7 +110,7 @@
                    error=str(e),
                    start_at=datetime.now(),
                )
-                self.event_queue.put(error_event)
+                self._event_queue.put(error_event)
 
    def _execute_node(self, node: Node) -> None:
        """
@@ -120,19 +120,19 @@
            node: The node instance to execute
        """
        # Execute the node with preserved context if Flask app is provided
-        if self.flask_app and self.context_vars:
+        if self._flask_app and self._context_vars:
            with preserve_flask_contexts(
-                flask_app=self.flask_app,
-                context_vars=self.context_vars,
+                flask_app=self._flask_app,
+                context_vars=self._context_vars,
            ):
                # Execute the node
                node_events = node.run()
                for event in node_events:
                    # Forward event to dispatcher immediately for streaming
-                    self.event_queue.put(event)
+                    self._event_queue.put(event)
        else:
            # Execute without context preservation
            node_events = node.run()
            for event in node_events:
                # Forward event to dispatcher immediately for streaming
-                self.event_queue.put(event)
+                self._event_queue.put(event)
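
Workers run nodes on background threads but still need the caller's context; judging by its name, `preserve_flask_contexts` restores both the Flask app context and the captured ContextVars. With only the stdlib, the contextvars half of that looks roughly like this (an illustrative stand-in, not the project's helper):

    import contextvars
    import queue
    import threading

    request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="-")

    def worker_loop(tasks: "queue.Queue[str]", ctx: contextvars.Context, stop: threading.Event) -> None:
        while not stop.is_set():
            try:
                task = tasks.get(timeout=0.1)
            except queue.Empty:
                continue
            # Run the task inside the captured context so ContextVars set by
            # the parent thread (e.g. a request ID) are visible here.
            ctx.run(lambda: print(f"[{request_id.get()}] ran {task}"))
            tasks.task_done()

    # On the parent thread, capture the context before starting the worker:
    # ctx = contextvars.copy_context()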

api/core/workflow/graph_engine/worker_management/simple_worker_pool.py (+31, -31)

@@ -56,20 +56,20 @@ class SimpleWorkerPool:
             scale_up_threshold: Queue depth to trigger scale up
             scale_down_idle_time: Seconds before scaling down idle workers
         """
-        self.ready_queue = ready_queue
-        self.event_queue = event_queue
-        self.graph = graph
-        self.flask_app = flask_app
-        self.context_vars = context_vars
+        self._ready_queue = ready_queue
+        self._event_queue = event_queue
+        self._graph = graph
+        self._flask_app = flask_app
+        self._context_vars = context_vars
 
         # Scaling parameters with defaults
-        self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
-        self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
-        self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
-        self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
+        self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
+        self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
+        self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
+        self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
 
         # Worker management
-        self.workers: list[Worker] = []
+        self._workers: list[Worker] = []
         self._worker_counter = 0
         self._lock = threading.RLock()
         self._running = False
@@ -89,13 +89,13 @@
 
         # Calculate initial worker count
         if initial_count is None:
-            node_count = len(self.graph.nodes)
+            node_count = len(self._graph.nodes)
             if node_count < 10:
-                initial_count = self.min_workers
+                initial_count = self._min_workers
             elif node_count < 50:
-                initial_count = min(self.min_workers + 1, self.max_workers)
+                initial_count = min(self._min_workers + 1, self._max_workers)
             else:
-                initial_count = min(self.min_workers + 2, self.max_workers)
+                initial_count = min(self._min_workers + 2, self._max_workers)
 
         # Create initial workers
         for _ in range(initial_count):
@@ -107,15 +107,15 @@
         self._running = False
 
         # Stop all workers
-        for worker in self.workers:
+        for worker in self._workers:
             worker.stop()
 
         # Wait for workers to finish
-        for worker in self.workers:
+        for worker in self._workers:
             if worker.is_alive():
                 worker.join(timeout=10.0)
 
-        self.workers.clear()
+        self._workers.clear()
 
     def _create_worker(self) -> None:
         """Create and start a new worker."""
@@ -123,16 +123,16 @@
         self._worker_counter += 1
 
         worker = Worker(
-            ready_queue=self.ready_queue,
-            event_queue=self.event_queue,
-            graph=self.graph,
+            ready_queue=self._ready_queue,
+            event_queue=self._event_queue,
+            graph=self._graph,
             worker_id=worker_id,
-            flask_app=self.flask_app,
-            context_vars=self.context_vars,
+            flask_app=self._flask_app,
+            context_vars=self._context_vars,
         )
 
         worker.start()
-        self.workers.append(worker)
+        self._workers.append(worker)
 
     def check_and_scale(self) -> None:
         """Check and perform scaling if needed."""
@@ -140,17 +140,17 @@
         if not self._running:
             return
 
-        current_count = len(self.workers)
-        queue_depth = self.ready_queue.qsize()
+        current_count = len(self._workers)
+        queue_depth = self._ready_queue.qsize()
 
         # Simple scaling logic
-        if queue_depth > self.scale_up_threshold and current_count < self.max_workers:
+        if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
             self._create_worker()
 
     def get_worker_count(self) -> int:
         """Get current number of workers."""
         with self._lock:
-            return len(self.workers)
+            return len(self._workers)
 
     def get_status(self) -> dict[str, int]:
         """
@@ -161,8 +161,8 @@
         """
         with self._lock:
             return {
-                "total_workers": len(self.workers),
-                "queue_depth": self.ready_queue.qsize(),
-                "min_workers": self.min_workers,
-                "max_workers": self.max_workers,
+                "total_workers": len(self._workers),
+                "queue_depth": self._ready_queue.qsize(),
+                "min_workers": self._min_workers,
+                "max_workers": self._max_workers,
             }
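
The pool sizes itself from the graph and afterwards only scales up, one worker per check, when the ready queue backs up; in the code shown, nothing scales workers back down even though scale_down_idle_time is stored. The start() sizing heuristic, pulled out for illustration:

    def initial_worker_count(node_count: int, min_workers: int, max_workers: int) -> int:
        # Same tiered heuristic as SimpleWorkerPool.start() above.
        if node_count < 10:
            return min_workers
        if node_count < 50:
            return min(min_workers + 1, max_workers)
        return min(min_workers + 2, max_workers)

    # e.g. with min_workers=1, max_workers=10:
    #   5 nodes -> 1 worker, 30 nodes -> 2 workers, 80 nodes -> 3 workers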
