
refactor(graph_engine): Merge worker management into one WorkerPool

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/2.0.0-beta.1
-LAN- 2 months ago
parent commit 64c1234724

api/.importlinter (+8 -9)

@@ -77,16 +77,15 @@ forbidden_modules =
     core.workflow.graph_engine.layers
     core.workflow.graph_engine.protocols
 
-[importlinter:contract:worker-management-layers]
-name = Worker Management Layers
-type = layers
-layers =
-    worker_pool
-    worker_factory
-    dynamic_scaler
-    activity_tracker
-containers =
-    core.workflow.graph_engine.worker_management
+[importlinter:contract:worker-management]
+name = Worker Management
+type = forbidden
+source_modules =
+    core.workflow.graph_engine.worker_management
+forbidden_modules =
+    core.workflow.graph_engine.orchestration
+    core.workflow.graph_engine.command_processing
+    core.workflow.graph_engine.event_management
 
 [importlinter:contract:error-handling-strategies]
 name = Error Handling Strategies

api/core/workflow/graph_engine/graph_engine.py (+11 -29)

@@ -13,7 +13,6 @@ from typing import final
 
 from flask import Flask, current_app
 
-from configs import dify_config
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities import GraphRuntimeState
 from core.workflow.enums import NodeExecutionType
@@ -40,7 +39,7 @@ from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
 from .state_management import UnifiedStateManager
-from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, WorkerPool
+from .worker_management import SimpleWorkerPool
 
 logger = logging.getLogger(__name__)
 
@@ -215,31 +214,17 @@ class GraphEngine:
 
         context_vars = contextvars.copy_context()
 
-        # Create worker management components
-        self._activity_tracker = ActivityTracker()
-        self._dynamic_scaler = DynamicScaler(
-            min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
-            max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
-            scale_up_threshold=(
-                self._scale_up_threshold
-                if self._scale_up_threshold is not None
-                else dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
-            ),
-            scale_down_idle_time=(
-                self._scale_down_idle_time
-                if self._scale_down_idle_time is not None
-                else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
-            ),
-        )
-        self._worker_factory = WorkerFactory(flask_app, context_vars)
-
-        self._worker_pool = WorkerPool(
+        # Create simple worker pool
+        self._worker_pool = SimpleWorkerPool(
             ready_queue=self.ready_queue,
             event_queue=self.event_queue,
             graph=self.graph,
-            worker_factory=self._worker_factory,
-            dynamic_scaler=self._dynamic_scaler,
-            activity_tracker=self._activity_tracker,
+            flask_app=flask_app,
+            context_vars=context_vars,
+            min_workers=self._min_workers,
+            max_workers=self._max_workers,
+            scale_up_threshold=self._scale_up_threshold,
+            scale_down_idle_time=self._scale_down_idle_time,
        )
 
     def _validate_graph_state_consistency(self) -> None:
@@ -316,11 +301,8 @@ class GraphEngine:
 
     def _start_execution(self) -> None:
         """Start execution subsystems."""
-        # Calculate initial worker count
-        initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)
-
-        # Start worker pool
-        self._worker_pool.start(initial_workers)
+        # Start worker pool (it calculates initial workers internally)
+        self._worker_pool.start()
 
         # Register response nodes
         for node in self.graph.nodes.values():
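One subtlety of this wiring change is worth noting: the deleted code fell back to dify_config defaults with explicit `is not None` checks, whereas SimpleWorkerPool (added later in this diff) resolves defaults with `or`, which also treats 0 as unset. A standalone sketch of the difference, using only the config names that appear in the diff:

# Sketch only: contrasts the two default-resolution styles seen in this diff.
from configs import dify_config

min_workers = 0  # caller explicitly requests a zero minimum

# Old style (removed DynamicScaler wiring): keeps the explicit 0.
resolved_old = min_workers if min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS

# New style (SimpleWorkerPool.__init__): 0 is falsy, so the config default wins.
resolved_new = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS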

api/core/workflow/graph_engine/orchestration/execution_coordinator.py (+3 -5)

@@ -8,7 +8,7 @@ from ..command_processing import CommandProcessor
 from ..domain import GraphExecution
 from ..event_management import EventCollector
 from ..state_management import UnifiedStateManager
-from ..worker_management import WorkerPool
+from ..worker_management import SimpleWorkerPool
 
 if TYPE_CHECKING:
     from ..event_management import EventHandlerRegistry
@@ -30,7 +30,7 @@ class ExecutionCoordinator:
         event_handler: "EventHandlerRegistry",
         event_collector: EventCollector,
         command_processor: CommandProcessor,
-        worker_pool: WorkerPool,
+        worker_pool: SimpleWorkerPool,
     ) -> None:
         """
         Initialize the execution coordinator.
@@ -56,9 +56,7 @@ class ExecutionCoordinator:
 
     def check_scaling(self) -> None:
         """Check and perform worker scaling if needed."""
-        queue_depth = self.state_manager.ready_queue.qsize()
-        executing_count = self.state_manager.get_executing_count()
-        self.worker_pool.check_scaling(queue_depth, executing_count)
+        self.worker_pool.check_and_scale()
 
     def is_execution_complete(self) -> bool:
         """

api/core/workflow/graph_engine/worker_management/__init__.py (+2 -8)

@@ -5,14 +5,8 @@ This package manages the worker pool, including creation,
 scaling, and activity tracking.
 """
 
-from .activity_tracker import ActivityTracker
-from .dynamic_scaler import DynamicScaler
-from .worker_factory import WorkerFactory
-from .worker_pool import WorkerPool
+from .simple_worker_pool import SimpleWorkerPool
 
 __all__ = [
-    "ActivityTracker",
-    "DynamicScaler",
-    "WorkerFactory",
-    "WorkerPool",
+    "SimpleWorkerPool",
 ]

api/core/workflow/graph_engine/worker_management/activity_tracker.py (+0 -76)

@@ -1,76 +0,0 @@
"""
Activity tracker for monitoring worker activity.
"""

import threading
import time
from typing import final


@final
class ActivityTracker:
"""
Tracks worker activity for scaling decisions.

This monitors which workers are active or idle to support
dynamic scaling decisions.
"""

def __init__(self, idle_threshold: float = 30.0) -> None:
"""
Initialize the activity tracker.

Args:
idle_threshold: Seconds before a worker is considered idle
"""
self.idle_threshold = idle_threshold
self._worker_activity: dict[int, tuple[bool, float]] = {}
self._lock = threading.RLock()

def track_activity(self, worker_id: int, is_active: bool) -> None:
"""
Track worker activity state.

Args:
worker_id: ID of the worker
is_active: Whether the worker is active
"""
with self._lock:
self._worker_activity[worker_id] = (is_active, time.time())

def get_idle_workers(self) -> list[int]:
"""
Get list of workers that have been idle too long.

Returns:
List of idle worker IDs
"""
current_time = time.time()
idle_workers: list[int] = []

with self._lock:
for worker_id, (is_active, last_change) in self._worker_activity.items():
if not is_active and (current_time - last_change) > self.idle_threshold:
idle_workers.append(worker_id)

return idle_workers

def remove_worker(self, worker_id: int) -> None:
"""
Remove a worker from tracking.

Args:
worker_id: ID of the worker to remove
"""
with self._lock:
self._worker_activity.pop(worker_id, None)

def get_active_count(self) -> int:
"""
Get count of currently active workers.

Returns:
Number of active workers
"""
with self._lock:
return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
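For reference, the removed tracker's idle test compared wall-clock time against the last recorded state change. A short sketch of its semantics exactly as the class above defines them; the worker IDs are illustrative:

tracker = ActivityTracker(idle_threshold=30.0)
tracker.track_activity(worker_id=0, is_active=True)   # worker 0 is busy
tracker.track_activity(worker_id=1, is_active=False)  # worker 1 just went idle

assert tracker.get_idle_workers() == []  # 30 s have not elapsed yet
assert tracker.get_active_count() == 1

# After more than 30 s without a state change, worker 1 would be reported:
# time.sleep(31); tracker.get_idle_workers() == [1]
tracker.remove_worker(1)  # forget a worker that has been scaled away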

api/core/workflow/graph_engine/worker_management/dynamic_scaler.py (+0 -101)

@@ -1,101 +0,0 @@
"""
Dynamic scaler for worker pool sizing.
"""

from typing import final

from core.workflow.graph import Graph


@final
class DynamicScaler:
"""
Manages dynamic scaling decisions for the worker pool.

This encapsulates the logic for when to scale up or down
based on workload and configuration.
"""

def __init__(
self,
min_workers: int = 2,
max_workers: int = 10,
scale_up_threshold: int = 5,
scale_down_idle_time: float = 30.0,
) -> None:
"""
Initialize the dynamic scaler.

Args:
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Idle time before scaling down
"""
self.min_workers = min_workers
self.max_workers = max_workers
self.scale_up_threshold = scale_up_threshold
self.scale_down_idle_time = scale_down_idle_time

def calculate_initial_workers(self, graph: Graph) -> int:
"""
Calculate initial worker count based on graph complexity.

Args:
graph: The workflow graph

Returns:
Initial number of workers to create
"""
node_count = len(graph.nodes)

# Simple heuristic: more nodes = more workers
if node_count < 10:
initial = self.min_workers
elif node_count < 50:
initial = min(4, self.max_workers)
elif node_count < 100:
initial = min(6, self.max_workers)
else:
initial = min(8, self.max_workers)

return max(self.min_workers, initial)

def should_scale_up(self, current_workers: int, queue_depth: int, executing_count: int) -> bool:
"""
Determine if scaling up is needed.

Args:
current_workers: Current number of workers
queue_depth: Number of nodes waiting
executing_count: Number of nodes executing

Returns:
True if should scale up
"""
if current_workers >= self.max_workers:
return False

# Scale up if queue is deep and workers are busy
if queue_depth > self.scale_up_threshold:
if executing_count >= current_workers * 0.8:
return True

return False

def should_scale_down(self, current_workers: int, idle_workers: list[int]) -> bool:
"""
Determine if scaling down is appropriate.

Args:
current_workers: Current number of workers
idle_workers: List of idle worker IDs

Returns:
True if should scale down
"""
if current_workers <= self.min_workers:
return False

# Scale down if we have idle workers
return len(idle_workers) > 0
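Worked through with the defaults above (min_workers=2, max_workers=10, scale_up_threshold=5), the removed heuristics behave as sketched below; the node counts and assertions just restate the code:

scaler = DynamicScaler()

# calculate_initial_workers() bands by graph size:
#   9 nodes -> 2 (min_workers), 30 -> min(4, 10) = 4,
#   80 -> min(6, 10) = 6, 150+ -> min(8, 10) = 8

# Scale up needs headroom, a queue deeper than 5, and >= 80% busy workers:
assert scaler.should_scale_up(current_workers=4, queue_depth=6, executing_count=4)
assert not scaler.should_scale_up(current_workers=10, queue_depth=50, executing_count=10)  # at max
assert not scaler.should_scale_up(current_workers=4, queue_depth=6, executing_count=2)  # not busy enough

# Scale down only above the floor, and only when someone is idle:
assert scaler.should_scale_down(current_workers=3, idle_workers=[7])
assert not scaler.should_scale_down(current_workers=2, idle_workers=[7])  # at min_workers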

api/core/workflow/graph_engine/worker_management/enhanced_worker_pool.py (+0 -332)

@@ -1,332 +0,0 @@
"""
Enhanced worker pool with integrated activity tracking and dynamic scaling.

This is a proposed simplification that merges WorkerPool, ActivityTracker,
and DynamicScaler into a single cohesive class.
"""

import queue
import threading
import time
from typing import TYPE_CHECKING, final

from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase

from ..worker import Worker

if TYPE_CHECKING:
from contextvars import Context

from flask import Flask


@final
class EnhancedWorkerPool:
"""
Enhanced worker pool with integrated features.

This class combines the responsibilities of:
- WorkerPool: Managing worker threads
- ActivityTracker: Tracking worker activity
- DynamicScaler: Making scaling decisions

Benefits:
- Simplified interface with fewer classes
- Direct integration of related features
- Reduced inter-class communication overhead
"""

def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""
Initialize the enhanced worker pool.

Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
"""
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.flask_app = flask_app
self.context_vars = context_vars

# Scaling parameters
self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME

# Worker management
self.workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False

# Activity tracking (integrated)
self._worker_activity: dict[int, tuple[bool, float]] = {}

# Scaling control
self._last_scale_check = time.time()
self._scale_check_interval = 1.0 # Check scaling every second

def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool with initial workers.

Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return

self._running = True

# Calculate initial worker count if not specified
if initial_count is None:
initial_count = self._calculate_initial_workers()

# Create initial workers
for _ in range(initial_count):
self._add_worker()

def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False

# Stop all workers
for worker in self.workers:
worker.stop()

# Wait for workers to finish
for worker in self.workers:
if worker.is_alive():
worker.join(timeout=10.0)

self.workers.clear()
self._worker_activity.clear()

def check_and_scale(self) -> None:
"""
Check and perform scaling if needed.

This method should be called periodically to adjust pool size.
"""
current_time = time.time()

# Rate limit scaling checks
if current_time - self._last_scale_check < self._scale_check_interval:
return

self._last_scale_check = current_time

with self._lock:
if not self._running:
return

current_count = len(self.workers)
queue_depth = self.ready_queue.qsize()

# Check for scale up
if self._should_scale_up(current_count, queue_depth):
self._add_worker()

# Check for scale down
idle_workers = self._get_idle_workers(current_time)
if idle_workers and self._should_scale_down(current_count):
# Remove the most idle worker
self._remove_worker(idle_workers[0])

# ============= Private Methods =============

def _calculate_initial_workers(self) -> int:
"""
Calculate initial number of workers based on graph complexity.

Returns:
Initial worker count
"""
# Simple heuristic: start with min_workers, scale based on graph size
node_count = len(self.graph.nodes)

if node_count < 10:
return self.min_workers
elif node_count < 50:
return min(self.min_workers + 1, self.max_workers)
else:
return min(self.min_workers + 2, self.max_workers)

def _should_scale_up(self, current_count: int, queue_depth: int) -> bool:
"""
Determine if pool should scale up.

Args:
current_count: Current number of workers
queue_depth: Current queue depth

Returns:
True if should scale up
"""
if current_count >= self.max_workers:
return False

# Scale up if queue is deep
if queue_depth > self.scale_up_threshold:
return True

# Scale up if all workers are busy and queue is not empty
active_count = self._get_active_count()
if active_count == current_count and queue_depth > 0:
return True

return False

def _should_scale_down(self, current_count: int) -> bool:
"""
Determine if pool should scale down.

Args:
current_count: Current number of workers

Returns:
True if should scale down
"""
return current_count > self.min_workers

def _add_worker(self) -> None:
"""Add a new worker to the pool."""
worker_id = self._worker_counter
self._worker_counter += 1

# Create worker with activity callbacks
worker = Worker(
ready_queue=self.ready_queue,
event_queue=self.event_queue,
graph=self.graph,
worker_id=worker_id,
flask_app=self.flask_app,
context_vars=self.context_vars,
on_idle_callback=self._on_worker_idle,
on_active_callback=self._on_worker_active,
)

worker.start()
self.workers.append(worker)
self._worker_activity[worker_id] = (False, time.time())

def _remove_worker(self, worker_id: int) -> None:
"""
Remove a specific worker from the pool.

Args:
worker_id: ID of worker to remove
"""
worker_to_remove = None
for worker in self.workers:
if worker.worker_id == worker_id:
worker_to_remove = worker
break

if worker_to_remove:
worker_to_remove.stop()
self.workers.remove(worker_to_remove)
self._worker_activity.pop(worker_id, None)

if worker_to_remove.is_alive():
worker_to_remove.join(timeout=1.0)

def _on_worker_idle(self, worker_id: int) -> None:
"""
Callback when worker becomes idle.

Args:
worker_id: ID of the idle worker
"""
with self._lock:
self._worker_activity[worker_id] = (False, time.time())

def _on_worker_active(self, worker_id: int) -> None:
"""
Callback when worker becomes active.

Args:
worker_id: ID of the active worker
"""
with self._lock:
self._worker_activity[worker_id] = (True, time.time())

def _get_idle_workers(self, current_time: float) -> list[int]:
"""
Get list of workers that have been idle too long.

Args:
current_time: Current timestamp

Returns:
List of idle worker IDs sorted by idle time (longest first)
"""
idle_workers: list[tuple[int, float]] = []

for worker_id, (is_active, last_change) in self._worker_activity.items():
if not is_active:
idle_time = current_time - last_change
if idle_time > self.scale_down_idle_time:
idle_workers.append((worker_id, idle_time))

# Sort by idle time (longest first)
idle_workers.sort(key=lambda x: x[1], reverse=True)
return [worker_id for worker_id, _ in idle_workers]

def _get_active_count(self) -> int:
"""
Get count of currently active workers.

Returns:
Number of active workers
"""
return sum(1 for is_active, _ in self._worker_activity.values() if is_active)

# ============= Public Status Methods =============

def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self.workers)

def get_status(self) -> dict[str, int]:
"""
Get pool status information.

Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self.workers),
"active_workers": self._get_active_count(),
"idle_workers": len(self.workers) - self._get_active_count(),
"queue_depth": self.ready_queue.qsize(),
"min_workers": self.min_workers,
"max_workers": self.max_workers,
}

api/core/workflow/graph_engine/worker_management/simple_worker_pool.py (+168 -0)

@@ -0,0 +1,168 @@
"""
Simple worker pool that consolidates functionality.

This is a simpler implementation that merges WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into a single class.
"""

import queue
import threading
from typing import TYPE_CHECKING, final

from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase

from ..worker import Worker

if TYPE_CHECKING:
from contextvars import Context

from flask import Flask


@final
class SimpleWorkerPool:
"""
Simple worker pool with integrated management.

This class consolidates all worker management functionality into
a single, simpler implementation without excessive abstraction.
"""

def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""
Initialize the simple worker pool.

Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
"""
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.flask_app = flask_app
self.context_vars = context_vars

# Scaling parameters with defaults
self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME

# Worker management
self.workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False

def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool.

Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return

self._running = True

# Calculate initial worker count
if initial_count is None:
node_count = len(self.graph.nodes)
if node_count < 10:
initial_count = self.min_workers
elif node_count < 50:
initial_count = min(self.min_workers + 1, self.max_workers)
else:
initial_count = min(self.min_workers + 2, self.max_workers)

# Create initial workers
for _ in range(initial_count):
self._create_worker()

def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False

# Stop all workers
for worker in self.workers:
worker.stop()

# Wait for workers to finish
for worker in self.workers:
if worker.is_alive():
worker.join(timeout=10.0)

self.workers.clear()

def _create_worker(self) -> None:
"""Create and start a new worker."""
worker_id = self._worker_counter
self._worker_counter += 1

worker = Worker(
ready_queue=self.ready_queue,
event_queue=self.event_queue,
graph=self.graph,
worker_id=worker_id,
flask_app=self.flask_app,
context_vars=self.context_vars,
)

worker.start()
self.workers.append(worker)

def check_and_scale(self) -> None:
"""Check and perform scaling if needed."""
with self._lock:
if not self._running:
return

current_count = len(self.workers)
queue_depth = self.ready_queue.qsize()

# Simple scaling logic
if queue_depth > self.scale_up_threshold and current_count < self.max_workers:
self._create_worker()

def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self.workers)

def get_status(self) -> dict[str, int]:
"""
Get pool status information.

Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self.workers),
"queue_depth": self.ready_queue.qsize(),
"min_workers": self.min_workers,
"max_workers": self.max_workers,
}
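A minimal usage sketch of the new pool, using only the interface defined above; the graph value stands in for an existing core.workflow.graph.Graph instance and is not constructed here:

import queue

from core.workflow.graph_engine.worker_management import SimpleWorkerPool
from core.workflow.graph_events import GraphNodeEventBase

ready_queue: queue.Queue[str] = queue.Queue()
event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

pool = SimpleWorkerPool(
    ready_queue=ready_queue,
    event_queue=event_queue,
    graph=graph,  # assumed pre-built workflow graph
)
pool.start()               # sizes itself from len(graph.nodes)
ready_queue.put("node_1")  # enqueue a ready node id
pool.check_and_scale()     # adds one worker only if depth > scale_up_threshold
print(pool.get_status())   # e.g. {"total_workers": 2, "queue_depth": 1, ...}
pool.stop()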

api/core/workflow/graph_engine/worker_management/worker_factory.py (+0 -76)

@@ -1,76 +0,0 @@
"""
Factory for creating worker instances.
"""

import contextvars
import queue
from collections.abc import Callable
from typing import final

from flask import Flask

from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase

from ..worker import Worker


@final
class WorkerFactory:
"""
Factory for creating worker instances with proper context.

This encapsulates worker creation logic and ensures all workers
are created with the necessary Flask and context variable setup.
"""

def __init__(
self,
flask_app: Flask | None,
context_vars: contextvars.Context,
) -> None:
"""
Initialize the worker factory.

Args:
flask_app: Flask application context
context_vars: Context variables to propagate
"""
self.flask_app = flask_app
self.context_vars = context_vars
self._next_worker_id = 0

def create_worker(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> Worker:
"""
Create a new worker instance.

Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
on_idle_callback: Callback when worker becomes idle
on_active_callback: Callback when worker becomes active

Returns:
Configured worker instance
"""
worker_id = self._next_worker_id
self._next_worker_id += 1

return Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=graph,
worker_id=worker_id,
flask_app=self.flask_app,
context_vars=self.context_vars,
on_idle_callback=on_idle_callback,
on_active_callback=on_active_callback,
)

api/core/workflow/graph_engine/worker_management/worker_pool.py (+0 -148)

@@ -1,148 +0,0 @@
"""
Worker pool management.
"""

import queue
import threading
from typing import final

from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase

from ..worker import Worker
from .activity_tracker import ActivityTracker
from .dynamic_scaler import DynamicScaler
from .worker_factory import WorkerFactory


@final
class WorkerPool:
"""
Manages a pool of worker threads for executing nodes.

This provides dynamic scaling, activity tracking, and lifecycle
management for worker threads.
"""

def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_factory: WorkerFactory,
dynamic_scaler: DynamicScaler,
activity_tracker: ActivityTracker,
) -> None:
"""
Initialize the worker pool.

Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
worker_factory: Factory for creating workers
dynamic_scaler: Scaler for dynamic sizing
activity_tracker: Tracker for worker activity
"""
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.worker_factory = worker_factory
self.dynamic_scaler = dynamic_scaler
self.activity_tracker = activity_tracker

self.workers: list[Worker] = []
self._lock = threading.RLock()
self._running = False

def start(self, initial_count: int) -> None:
"""
Start the worker pool with initial workers.

Args:
initial_count: Number of workers to start with
"""
with self._lock:
if self._running:
return

self._running = True

# Create initial workers
for _ in range(initial_count):
worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
worker.start()
self.workers.append(worker)

def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False

# Stop all workers
for worker in self.workers:
worker.stop()

# Wait for workers to finish
for worker in self.workers:
if worker.is_alive():
worker.join(timeout=10.0)

self.workers.clear()

def scale_up(self) -> None:
"""Add a worker to the pool if allowed."""
with self._lock:
if not self._running:
return

if len(self.workers) >= self.dynamic_scaler.max_workers:
return

worker = self.worker_factory.create_worker(self.ready_queue, self.event_queue, self.graph)
worker.start()
self.workers.append(worker)

def scale_down(self, worker_ids: list[int]) -> None:
"""
Remove specific workers from the pool.

Args:
worker_ids: IDs of workers to remove
"""
with self._lock:
if not self._running:
return

if len(self.workers) <= self.dynamic_scaler.min_workers:
return

workers_to_remove = [w for w in self.workers if w.worker_id in worker_ids]

for worker in workers_to_remove:
worker.stop()
self.workers.remove(worker)
if worker.is_alive():
worker.join(timeout=1.0)

def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self.workers)

def check_scaling(self, queue_depth: int, executing_count: int) -> None:
"""
Check and perform scaling if needed.

Args:
queue_depth: Current queue depth
executing_count: Number of executing nodes
"""
current_count = self.get_worker_count()

if self.dynamic_scaler.should_scale_up(current_count, queue_depth, executing_count):
self.scale_up()

idle_workers = self.activity_tracker.get_idle_workers()
if idle_workers:
self.scale_down(idle_workers)
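Taken together, the deletions replace a four-object collaboration with a single constructor call. A before/after wiring sketch, with arguments exactly as the two classes in this diff declare them; the queue, graph, flask_app, and context_vars values are assumed to exist:

# Before: GraphEngine wired four collaborators together.
activity_tracker = ActivityTracker()
dynamic_scaler = DynamicScaler(min_workers=2, max_workers=10)
worker_factory = WorkerFactory(flask_app, context_vars)
old_pool = WorkerPool(
    ready_queue=ready_queue,
    event_queue=event_queue,
    graph=graph,
    worker_factory=worker_factory,
    dynamic_scaler=dynamic_scaler,
    activity_tracker=activity_tracker,
)
old_pool.start(dynamic_scaler.calculate_initial_workers(graph))

# After: one object owns defaults, initial sizing, and scaling.
new_pool = SimpleWorkerPool(
    ready_queue=ready_queue,
    event_queue=event_queue,
    graph=graph,
    flask_app=flask_app,
    context_vars=context_vars,
)
new_pool.start()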
