Signed-off-by: -LAN- <laipz8200@outlook.com>tags/2.0.0-beta.1
| layers = | layers = | ||||
| graph_engine | graph_engine | ||||
| response_coordinator | response_coordinator | ||||
| output_registry | |||||
| containers = | containers = | ||||
| core.workflow.graph_engine | core.workflow.graph_engine | ||||
| from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator | from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator | ||||
| from .layers.base import Layer | from .layers.base import Layer | ||||
| from .orchestration import Dispatcher, ExecutionCoordinator | from .orchestration import Dispatcher, ExecutionCoordinator | ||||
| from .output_registry import OutputRegistry | |||||
| from .protocols.command_channel import CommandChannel | from .protocols.command_channel import CommandChannel | ||||
| from .response_coordinator import ResponseStreamCoordinator | from .response_coordinator import ResponseStreamCoordinator | ||||
| from .state_management import UnifiedStateManager | from .state_management import UnifiedStateManager | ||||
| self.state_manager = UnifiedStateManager(self.graph, self.ready_queue) | self.state_manager = UnifiedStateManager(self.graph, self.ready_queue) | ||||
| # Response coordination | # Response coordination | ||||
| self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool) | |||||
| self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph) | |||||
| self.response_coordinator = ResponseStreamCoordinator( | |||||
| variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph | |||||
| ) | |||||
| # Event management | # Event management | ||||
| self.event_collector = EventCollector() | self.event_collector = EventCollector() |
| """ | |||||
| OutputRegistry - Thread-safe storage for node outputs (streams and scalars) | |||||
| This component provides thread-safe storage and retrieval of node outputs, | |||||
| supporting both scalar values and streaming chunks with proper state management. | |||||
| """ | |||||
| from .registry import OutputRegistry | |||||
| __all__ = ["OutputRegistry"] |
| """ | |||||
| Main OutputRegistry implementation. | |||||
| This module contains the public OutputRegistry class that provides | |||||
| thread-safe storage for node outputs. | |||||
| """ | |||||
| from collections.abc import Sequence | |||||
| from threading import RLock | |||||
| from typing import TYPE_CHECKING, Any, Union, final | |||||
| from core.variables import Segment | |||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from .stream import Stream | |||||
| if TYPE_CHECKING: | |||||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||||
| @final | |||||
| class OutputRegistry: | |||||
| """ | |||||
| Thread-safe registry for storing and retrieving node outputs. | |||||
| Supports both scalar values and streaming chunks with proper state management. | |||||
| All operations are thread-safe using internal locking. | |||||
| """ | |||||
| def __init__(self, variable_pool: VariablePool) -> None: | |||||
| """Initialize empty registry with thread-safe storage.""" | |||||
| self._lock = RLock() | |||||
| self._scalars = variable_pool | |||||
| self._streams: dict[tuple[str, ...], Stream] = {} | |||||
| def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]: | |||||
| """Convert selector list to tuple key for internal storage.""" | |||||
| return tuple(selector) | |||||
| def set_scalar( | |||||
| self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]] | |||||
| ) -> None: | |||||
| """ | |||||
| Set a scalar value for the given selector. | |||||
| Args: | |||||
| selector: List of strings identifying the output location | |||||
| value: The scalar value to store | |||||
| """ | |||||
| with self._lock: | |||||
| self._scalars.add(selector, value) | |||||
| def get_scalar(self, selector: Sequence[str]) -> "Segment | None": | |||||
| """ | |||||
| Get a scalar value for the given selector. | |||||
| Args: | |||||
| selector: List of strings identifying the output location | |||||
| Returns: | |||||
| The stored Variable object, or None if not found | |||||
| """ | |||||
| with self._lock: | |||||
| return self._scalars.get(selector) | |||||
| def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None: | |||||
| """ | |||||
| Append a NodeRunStreamChunkEvent to the stream for the given selector. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| event: The NodeRunStreamChunkEvent to append | |||||
| Raises: | |||||
| ValueError: If the stream is already closed | |||||
| """ | |||||
| key = self._selector_to_key(selector) | |||||
| with self._lock: | |||||
| if key not in self._streams: | |||||
| self._streams[key] = Stream() | |||||
| try: | |||||
| self._streams[key].append(event) | |||||
| except ValueError: | |||||
| raise ValueError(f"Stream {'.'.join(selector)} is already closed") | |||||
| def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None": | |||||
| """ | |||||
| Pop the next unread NodeRunStreamChunkEvent from the stream. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| Returns: | |||||
| The next event, or None if no unread events available | |||||
| """ | |||||
| key = self._selector_to_key(selector) | |||||
| with self._lock: | |||||
| if key not in self._streams: | |||||
| return None | |||||
| return self._streams[key].pop_next() | |||||
| def has_unread(self, selector: Sequence[str]) -> bool: | |||||
| """ | |||||
| Check if the stream has unread events. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| Returns: | |||||
| True if there are unread events, False otherwise | |||||
| """ | |||||
| key = self._selector_to_key(selector) | |||||
| with self._lock: | |||||
| if key not in self._streams: | |||||
| return False | |||||
| return self._streams[key].has_unread() | |||||
| def close_stream(self, selector: Sequence[str]) -> None: | |||||
| """ | |||||
| Mark a stream as closed (no more chunks can be appended). | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| """ | |||||
| key = self._selector_to_key(selector) | |||||
| with self._lock: | |||||
| if key not in self._streams: | |||||
| self._streams[key] = Stream() | |||||
| self._streams[key].close() | |||||
| def stream_closed(self, selector: Sequence[str]) -> bool: | |||||
| """ | |||||
| Check if a stream is closed. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| Returns: | |||||
| True if the stream is closed, False otherwise | |||||
| """ | |||||
| key = self._selector_to_key(selector) | |||||
| with self._lock: | |||||
| if key not in self._streams: | |||||
| return False | |||||
| return self._streams[key].is_closed |
| """ | |||||
| Internal stream implementation for OutputRegistry. | |||||
| This module contains the private Stream class used internally by OutputRegistry | |||||
| to manage streaming data chunks. | |||||
| """ | |||||
| from typing import TYPE_CHECKING, final | |||||
| if TYPE_CHECKING: | |||||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||||
| @final | |||||
| class Stream: | |||||
| """ | |||||
| A stream that holds NodeRunStreamChunkEvent objects and tracks read position. | |||||
| This class encapsulates stream-specific data and operations, | |||||
| including event storage, read position tracking, and closed state. | |||||
| Note: This is an internal class not exposed in the public API. | |||||
| """ | |||||
| def __init__(self) -> None: | |||||
| """Initialize an empty stream.""" | |||||
| self.events: list[NodeRunStreamChunkEvent] = [] | |||||
| self.read_position: int = 0 | |||||
| self.is_closed: bool = False | |||||
| def append(self, event: "NodeRunStreamChunkEvent") -> None: | |||||
| """ | |||||
| Append a NodeRunStreamChunkEvent to the stream. | |||||
| Args: | |||||
| event: The NodeRunStreamChunkEvent to append | |||||
| Raises: | |||||
| ValueError: If the stream is already closed | |||||
| """ | |||||
| if self.is_closed: | |||||
| raise ValueError("Cannot append to a closed stream") | |||||
| self.events.append(event) | |||||
| def pop_next(self) -> "NodeRunStreamChunkEvent | None": | |||||
| """ | |||||
| Pop the next unread NodeRunStreamChunkEvent from the stream. | |||||
| Returns: | |||||
| The next event, or None if no unread events available | |||||
| """ | |||||
| if self.read_position >= len(self.events): | |||||
| return None | |||||
| event = self.events[self.read_position] | |||||
| self.read_position += 1 | |||||
| return event | |||||
| def has_unread(self) -> bool: | |||||
| """ | |||||
| Check if the stream has unread events. | |||||
| Returns: | |||||
| True if there are unread events, False otherwise | |||||
| """ | |||||
| return self.read_position < len(self.events) | |||||
| def close(self) -> None: | |||||
| """Mark the stream as closed (no more chunks can be appended).""" | |||||
| self.is_closed = True |
| from typing import TypeAlias, final | from typing import TypeAlias, final | ||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from core.workflow.enums import NodeExecutionType, NodeState | from core.workflow.enums import NodeExecutionType, NodeState | ||||
| from core.workflow.graph import Graph | from core.workflow.graph import Graph | ||||
| from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent | from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent | ||||
| from core.workflow.nodes.base.template import TextSegment, VariableSegment | from core.workflow.nodes.base.template import TextSegment, VariableSegment | ||||
| from ..output_registry import OutputRegistry | |||||
| from .path import Path | from .path import Path | ||||
| from .session import ResponseSession | from .session import ResponseSession | ||||
| Ensures ordered streaming of responses based on upstream node outputs and constants. | Ensures ordered streaming of responses based on upstream node outputs and constants. | ||||
| """ | """ | ||||
| def __init__(self, registry: OutputRegistry, graph: "Graph") -> None: | |||||
| def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: | |||||
| """ | """ | ||||
| Initialize coordinator with output registry. | |||||
| Initialize coordinator with variable pool. | |||||
| Args: | Args: | ||||
| registry: OutputRegistry instance for accessing node outputs | |||||
| variable_pool: VariablePool instance for accessing node variables | |||||
| graph: Graph instance for looking up node information | graph: Graph instance for looking up node information | ||||
| """ | """ | ||||
| self.registry = registry | |||||
| self.variable_pool = variable_pool | |||||
| self.graph = graph | self.graph = graph | ||||
| self.active_session: ResponseSession | None = None | self.active_session: ResponseSession | None = None | ||||
| self.waiting_sessions: deque[ResponseSession] = deque() | self.waiting_sessions: deque[ResponseSession] = deque() | ||||
| self.lock = RLock() | self.lock = RLock() | ||||
| # Internal stream management (replacing OutputRegistry) | |||||
| self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} | |||||
| self._stream_positions: dict[tuple[str, ...], int] = {} | |||||
| self._closed_streams: set[tuple[str, ...]] = set() | |||||
| # Track response nodes | # Track response nodes | ||||
| self._response_nodes: set[NodeID] = set() | self._response_nodes: set[NodeID] = set() | ||||
| ) -> Sequence[NodeRunStreamChunkEvent]: | ) -> Sequence[NodeRunStreamChunkEvent]: | ||||
| with self.lock: | with self.lock: | ||||
| if isinstance(event, NodeRunStreamChunkEvent): | if isinstance(event, NodeRunStreamChunkEvent): | ||||
| self.registry.append_chunk(event.selector, event) | |||||
| self._append_stream_chunk(event.selector, event) | |||||
| if event.is_final: | if event.is_final: | ||||
| self.registry.close_stream(event.selector) | |||||
| self._close_stream(event.selector) | |||||
| return self.try_flush() | return self.try_flush() | ||||
| else: | else: | ||||
| # Skip cause we share the same variable pool. | # Skip cause we share the same variable pool. | ||||
| # | # | ||||
| # for variable_name, variable_value in event.node_run_result.outputs.items(): | # for variable_name, variable_value in event.node_run_result.outputs.items(): | ||||
| # self.registry.set_scalar((event.node_id, variable_name), variable_value) | |||||
| # self.variable_pool.add((event.node_id, variable_name), variable_value) | |||||
| return self.try_flush() | return self.try_flush() | ||||
| return [] | return [] | ||||
| execution_id = self._get_or_create_execution_id(output_node_id) | execution_id = self._get_or_create_execution_id(output_node_id) | ||||
| # Stream all available chunks | # Stream all available chunks | ||||
| while self.registry.has_unread(segment.selector): | |||||
| if event := self.registry.pop_chunk(segment.selector): | |||||
| while self._has_unread_stream(segment.selector): | |||||
| if event := self._pop_stream_chunk(segment.selector): | |||||
| # For special selectors, we need to update the event to use | # For special selectors, we need to update the event to use | ||||
| # the active response node's information | # the active response node's information | ||||
| if self.active_session and source_selector_prefix not in self.graph.nodes: | if self.active_session and source_selector_prefix not in self.graph.nodes: | ||||
| events.append(event) | events.append(event) | ||||
| # Check if this is the last chunk by looking ahead | # Check if this is the last chunk by looking ahead | ||||
| stream_closed = self.registry.stream_closed(segment.selector) | |||||
| stream_closed = self._is_stream_closed(segment.selector) | |||||
| # Check if stream is closed to determine if segment is complete | # Check if stream is closed to determine if segment is complete | ||||
| if stream_closed: | if stream_closed: | ||||
| is_complete = True | is_complete = True | ||||
| elif value := self.registry.get_scalar(segment.selector): | |||||
| elif value := self.variable_pool.get(segment.selector): | |||||
| # Process scalar value | # Process scalar value | ||||
| is_last_segment = bool( | is_last_segment = bool( | ||||
| self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1 | self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1 | ||||
| events = self.try_flush() | events = self.try_flush() | ||||
| return events | return events | ||||
| # ============= Internal Stream Management Methods ============= | |||||
| def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: | |||||
| """ | |||||
| Append a stream chunk to the internal buffer. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| event: The NodeRunStreamChunkEvent to append | |||||
| Raises: | |||||
| ValueError: If the stream is already closed | |||||
| """ | |||||
| key = tuple(selector) | |||||
| if key in self._closed_streams: | |||||
| raise ValueError(f"Stream {'.'.join(selector)} is already closed") | |||||
| if key not in self._stream_buffers: | |||||
| self._stream_buffers[key] = [] | |||||
| self._stream_positions[key] = 0 | |||||
| self._stream_buffers[key].append(event) | |||||
| def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: | |||||
| """ | |||||
| Pop the next unread stream chunk from the buffer. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| Returns: | |||||
| The next event, or None if no unread events available | |||||
| """ | |||||
| key = tuple(selector) | |||||
| if key not in self._stream_buffers: | |||||
| return None | |||||
| position = self._stream_positions.get(key, 0) | |||||
| buffer = self._stream_buffers[key] | |||||
| if position >= len(buffer): | |||||
| return None | |||||
| event = buffer[position] | |||||
| self._stream_positions[key] = position + 1 | |||||
| return event | |||||
| def _has_unread_stream(self, selector: Sequence[str]) -> bool: | |||||
| """ | |||||
| Check if the stream has unread events. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| Returns: | |||||
| True if there are unread events, False otherwise | |||||
| """ | |||||
| key = tuple(selector) | |||||
| if key not in self._stream_buffers: | |||||
| return False | |||||
| position = self._stream_positions.get(key, 0) | |||||
| return position < len(self._stream_buffers[key]) | |||||
| def _close_stream(self, selector: Sequence[str]) -> None: | |||||
| """ | |||||
| Mark a stream as closed (no more chunks can be appended). | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| """ | |||||
| key = tuple(selector) | |||||
| self._closed_streams.add(key) | |||||
| def _is_stream_closed(self, selector: Sequence[str]) -> bool: | |||||
| """ | |||||
| Check if a stream is closed. | |||||
| Args: | |||||
| selector: List of strings identifying the stream location | |||||
| Returns: | |||||
| True if the stream is closed, False otherwise | |||||
| """ | |||||
| key = tuple(selector) | |||||
| return key in self._closed_streams |
| from uuid import uuid4 | |||||
| import pytest | |||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from core.workflow.enums import NodeType | |||||
| from core.workflow.graph_engine.output_registry import OutputRegistry | |||||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||||
| class TestOutputRegistry: | |||||
| def test_scalar_operations(self): | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Test setting and getting scalar | |||||
| registry.set_scalar(["node1", "output"], "test_value") | |||||
| segment = registry.get_scalar(["node1", "output"]) | |||||
| assert segment | |||||
| assert segment.text == "test_value" | |||||
| # Test getting non-existent scalar | |||||
| assert registry.get_scalar(["non_existent"]) is None | |||||
| def test_stream_operations(self): | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Create test events | |||||
| event1 = NodeRunStreamChunkEvent( | |||||
| id=str(uuid4()), | |||||
| node_id="node1", | |||||
| node_type=NodeType.LLM, | |||||
| selector=["node1", "stream"], | |||||
| chunk="chunk1", | |||||
| is_final=False, | |||||
| ) | |||||
| event2 = NodeRunStreamChunkEvent( | |||||
| id=str(uuid4()), | |||||
| node_id="node1", | |||||
| node_type=NodeType.LLM, | |||||
| selector=["node1", "stream"], | |||||
| chunk="chunk2", | |||||
| is_final=True, | |||||
| ) | |||||
| # Test appending events | |||||
| registry.append_chunk(["node1", "stream"], event1) | |||||
| registry.append_chunk(["node1", "stream"], event2) | |||||
| # Test has_unread | |||||
| assert registry.has_unread(["node1", "stream"]) is True | |||||
| # Test popping events | |||||
| popped_event1 = registry.pop_chunk(["node1", "stream"]) | |||||
| assert popped_event1 == event1 | |||||
| assert popped_event1.chunk == "chunk1" | |||||
| popped_event2 = registry.pop_chunk(["node1", "stream"]) | |||||
| assert popped_event2 == event2 | |||||
| assert popped_event2.chunk == "chunk2" | |||||
| assert registry.pop_chunk(["node1", "stream"]) is None | |||||
| # Test has_unread after popping all | |||||
| assert registry.has_unread(["node1", "stream"]) is False | |||||
| def test_stream_closing(self): | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Test stream is not closed initially | |||||
| assert registry.stream_closed(["node1", "stream"]) is False | |||||
| # Test closing stream | |||||
| registry.close_stream(["node1", "stream"]) | |||||
| assert registry.stream_closed(["node1", "stream"]) is True | |||||
| # Test appending to closed stream raises error | |||||
| event = NodeRunStreamChunkEvent( | |||||
| id=str(uuid4()), | |||||
| node_id="node1", | |||||
| node_type=NodeType.LLM, | |||||
| selector=["node1", "stream"], | |||||
| chunk="chunk", | |||||
| is_final=False, | |||||
| ) | |||||
| with pytest.raises(ValueError, match="Stream node1.stream is already closed"): | |||||
| registry.append_chunk(["node1", "stream"], event) | |||||
| def test_thread_safety(self): | |||||
| import threading | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| results = [] | |||||
| def append_chunks(thread_id: int): | |||||
| for i in range(100): | |||||
| event = NodeRunStreamChunkEvent( | |||||
| id=str(uuid4()), | |||||
| node_id="test_node", | |||||
| node_type=NodeType.LLM, | |||||
| selector=["stream"], | |||||
| chunk=f"thread{thread_id}_chunk{i}", | |||||
| is_final=False, | |||||
| ) | |||||
| registry.append_chunk(["stream"], event) | |||||
| # Start multiple threads | |||||
| threads = [] | |||||
| for i in range(5): | |||||
| thread = threading.Thread(target=append_chunks, args=(i,)) | |||||
| threads.append(thread) | |||||
| thread.start() | |||||
| # Wait for threads | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| # Verify all events are present | |||||
| events = [] | |||||
| while True: | |||||
| event = registry.pop_chunk(["stream"]) | |||||
| if event is None: | |||||
| break | |||||
| events.append(event) | |||||
| assert len(events) == 500 # 5 threads * 100 events each | |||||
| # Verify the events have the expected chunk content format | |||||
| chunk_texts = [e.chunk for e in events] | |||||
| for i in range(5): | |||||
| for j in range(100): | |||||
| assert f"thread{i}_chunk{j}" in chunk_texts |
| """Test cases for ResponseStreamCoordinator.""" | |||||
| from unittest.mock import Mock | |||||
| from core.variables import StringSegment | |||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from core.workflow.enums import NodeState, NodeType | |||||
| from core.workflow.graph import Graph | |||||
| from core.workflow.graph_engine.output_registry import OutputRegistry | |||||
| from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator | |||||
| from core.workflow.graph_engine.response_coordinator.session import ResponseSession | |||||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||||
| from core.workflow.nodes.base.node import Node | |||||
| from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment | |||||
| class TestResponseStreamCoordinator: | |||||
| """Test cases for ResponseStreamCoordinator.""" | |||||
| def test_skip_variable_segment_from_skipped_node(self): | |||||
| """Test that VariableSegments from skipped nodes are properly skipped during try_flush.""" | |||||
| # Create mock graph | |||||
| graph = Mock(spec=Graph) | |||||
| # Create mock nodes | |||||
| skipped_node = Mock(spec=Node) | |||||
| skipped_node.id = "skipped_node" | |||||
| skipped_node.state = NodeState.SKIPPED | |||||
| skipped_node.node_type = NodeType.LLM | |||||
| active_node = Mock(spec=Node) | |||||
| active_node.id = "active_node" | |||||
| active_node.state = NodeState.TAKEN | |||||
| active_node.node_type = NodeType.LLM | |||||
| response_node = Mock(spec=AnswerNode) | |||||
| response_node.id = "response_node" | |||||
| response_node.node_type = NodeType.ANSWER | |||||
| # Set up graph nodes dictionary | |||||
| graph.nodes = {"skipped_node": skipped_node, "active_node": active_node, "response_node": response_node} | |||||
| # Create output registry with variable pool | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Add some test data to registry for the active node | |||||
| registry.set_scalar(("active_node", "output"), StringSegment(value="Active output")) | |||||
| # Create RSC instance | |||||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||||
| # Create template with segments from both skipped and active nodes | |||||
| template = Template( | |||||
| segments=[ | |||||
| VariableSegment(selector=["skipped_node", "output"]), | |||||
| TextSegment(text=" - "), | |||||
| VariableSegment(selector=["active_node", "output"]), | |||||
| ] | |||||
| ) | |||||
| # Create and set active session | |||||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||||
| rsc.active_session = session | |||||
| # Execute try_flush | |||||
| events = rsc.try_flush() | |||||
| # Verify that: | |||||
| # 1. The skipped node's variable segment was skipped (index advanced) | |||||
| # 2. The text segment was processed | |||||
| # 3. The active node's variable segment was processed | |||||
| assert len(events) == 2 # TextSegment + VariableSegment from active_node | |||||
| # Check that the first event is the text segment | |||||
| assert events[0].chunk == " - " | |||||
| # Check that the second event is from the active node | |||||
| assert events[1].chunk == "Active output" | |||||
| assert events[1].selector == ["active_node", "output"] | |||||
| # Session should be complete | |||||
| assert session.is_complete() | |||||
| def test_process_variable_segment_from_non_skipped_node(self): | |||||
| """Test that VariableSegments from non-skipped nodes are processed normally.""" | |||||
| # Create mock graph | |||||
| graph = Mock(spec=Graph) | |||||
| # Create mock nodes | |||||
| active_node1 = Mock(spec=Node) | |||||
| active_node1.id = "node1" | |||||
| active_node1.state = NodeState.TAKEN | |||||
| active_node1.node_type = NodeType.LLM | |||||
| active_node2 = Mock(spec=Node) | |||||
| active_node2.id = "node2" | |||||
| active_node2.state = NodeState.TAKEN | |||||
| active_node2.node_type = NodeType.LLM | |||||
| response_node = Mock(spec=AnswerNode) | |||||
| response_node.id = "response_node" | |||||
| response_node.node_type = NodeType.ANSWER | |||||
| # Set up graph nodes dictionary | |||||
| graph.nodes = {"node1": active_node1, "node2": active_node2, "response_node": response_node} | |||||
| # Create output registry with variable pool | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Add test data to registry | |||||
| registry.set_scalar(("node1", "output"), StringSegment(value="Output 1")) | |||||
| registry.set_scalar(("node2", "output"), StringSegment(value="Output 2")) | |||||
| # Create RSC instance | |||||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||||
| # Create template with segments from active nodes | |||||
| template = Template( | |||||
| segments=[ | |||||
| VariableSegment(selector=["node1", "output"]), | |||||
| TextSegment(text=" | "), | |||||
| VariableSegment(selector=["node2", "output"]), | |||||
| ] | |||||
| ) | |||||
| # Create and set active session | |||||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||||
| rsc.active_session = session | |||||
| # Execute try_flush | |||||
| events = rsc.try_flush() | |||||
| # Verify all segments were processed | |||||
| assert len(events) == 3 | |||||
| # Check events in order | |||||
| assert events[0].chunk == "Output 1" | |||||
| assert events[0].selector == ["node1", "output"] | |||||
| assert events[1].chunk == " | " | |||||
| assert events[2].chunk == "Output 2" | |||||
| assert events[2].selector == ["node2", "output"] | |||||
| # Session should be complete | |||||
| assert session.is_complete() | |||||
| def test_mixed_skipped_and_active_nodes(self): | |||||
| """Test processing with a mix of skipped and active nodes.""" | |||||
| # Create mock graph | |||||
| graph = Mock(spec=Graph) | |||||
| # Create mock nodes with various states | |||||
| skipped_node1 = Mock(spec=Node) | |||||
| skipped_node1.id = "skip1" | |||||
| skipped_node1.state = NodeState.SKIPPED | |||||
| skipped_node1.node_type = NodeType.LLM | |||||
| active_node = Mock(spec=Node) | |||||
| active_node.id = "active" | |||||
| active_node.state = NodeState.TAKEN | |||||
| active_node.node_type = NodeType.LLM | |||||
| skipped_node2 = Mock(spec=Node) | |||||
| skipped_node2.id = "skip2" | |||||
| skipped_node2.state = NodeState.SKIPPED | |||||
| skipped_node2.node_type = NodeType.LLM | |||||
| response_node = Mock(spec=AnswerNode) | |||||
| response_node.id = "response_node" | |||||
| response_node.node_type = NodeType.ANSWER | |||||
| # Set up graph nodes dictionary | |||||
| graph.nodes = { | |||||
| "skip1": skipped_node1, | |||||
| "active": active_node, | |||||
| "skip2": skipped_node2, | |||||
| "response_node": response_node, | |||||
| } | |||||
| # Create output registry with variable pool | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Add data only for active node | |||||
| registry.set_scalar(("active", "result"), StringSegment(value="Active Result")) | |||||
| # Create RSC instance | |||||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||||
| # Create template with mixed segments | |||||
| template = Template( | |||||
| segments=[ | |||||
| TextSegment(text="Start: "), | |||||
| VariableSegment(selector=["skip1", "output"]), | |||||
| VariableSegment(selector=["active", "result"]), | |||||
| VariableSegment(selector=["skip2", "output"]), | |||||
| TextSegment(text=" :End"), | |||||
| ] | |||||
| ) | |||||
| # Create and set active session | |||||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||||
| rsc.active_session = session | |||||
| # Execute try_flush | |||||
| events = rsc.try_flush() | |||||
| # Should have: "Start: ", "Active Result", " :End" | |||||
| assert len(events) == 3 | |||||
| assert events[0].chunk == "Start: " | |||||
| assert events[1].chunk == "Active Result" | |||||
| assert events[1].selector == ["active", "result"] | |||||
| assert events[2].chunk == " :End" | |||||
| # Session should be complete | |||||
| assert session.is_complete() | |||||
| def test_all_variable_segments_skipped(self): | |||||
| """Test when all VariableSegments are from skipped nodes.""" | |||||
| # Create mock graph | |||||
| graph = Mock(spec=Graph) | |||||
| # Create all skipped nodes | |||||
| skipped_node1 = Mock(spec=Node) | |||||
| skipped_node1.id = "skip1" | |||||
| skipped_node1.state = NodeState.SKIPPED | |||||
| skipped_node1.node_type = NodeType.LLM | |||||
| skipped_node2 = Mock(spec=Node) | |||||
| skipped_node2.id = "skip2" | |||||
| skipped_node2.state = NodeState.SKIPPED | |||||
| skipped_node2.node_type = NodeType.LLM | |||||
| response_node = Mock(spec=AnswerNode) | |||||
| response_node.id = "response_node" | |||||
| response_node.node_type = NodeType.ANSWER | |||||
| # Set up graph nodes dictionary | |||||
| graph.nodes = {"skip1": skipped_node1, "skip2": skipped_node2, "response_node": response_node} | |||||
| # Create output registry (empty since nodes are skipped) with variable pool | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| # Create RSC instance | |||||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||||
| # Create template with only skipped segments | |||||
| template = Template( | |||||
| segments=[ | |||||
| VariableSegment(selector=["skip1", "output"]), | |||||
| VariableSegment(selector=["skip2", "output"]), | |||||
| TextSegment(text="Final text"), | |||||
| ] | |||||
| ) | |||||
| # Create and set active session | |||||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||||
| rsc.active_session = session | |||||
| # Execute try_flush | |||||
| events = rsc.try_flush() | |||||
| # Should only have the final text segment | |||||
| assert len(events) == 1 | |||||
| assert events[0].chunk == "Final text" | |||||
| # Session should be complete | |||||
| assert session.is_complete() | |||||
| def test_special_prefix_selectors(self): | |||||
| """Test that special prefix selectors (sys, env, conversation) are handled correctly.""" | |||||
| # Create mock graph | |||||
| graph = Mock(spec=Graph) | |||||
| # Create response node | |||||
| response_node = Mock(spec=AnswerNode) | |||||
| response_node.id = "response_node" | |||||
| response_node.node_type = NodeType.ANSWER | |||||
| # Set up graph nodes dictionary (no sys, env, conversation nodes) | |||||
| graph.nodes = {"response_node": response_node} | |||||
| # Create output registry with special selector data and variable pool | |||||
| variable_pool = VariablePool() | |||||
| registry = OutputRegistry(variable_pool) | |||||
| registry.set_scalar(("sys", "user_id"), StringSegment(value="user123")) | |||||
| registry.set_scalar(("env", "api_key"), StringSegment(value="key456")) | |||||
| registry.set_scalar(("conversation", "id"), StringSegment(value="conv789")) | |||||
| # Create RSC instance | |||||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||||
| # Create template with special selectors | |||||
| template = Template( | |||||
| segments=[ | |||||
| TextSegment(text="User: "), | |||||
| VariableSegment(selector=["sys", "user_id"]), | |||||
| TextSegment(text=", API: "), | |||||
| VariableSegment(selector=["env", "api_key"]), | |||||
| TextSegment(text=", Conv: "), | |||||
| VariableSegment(selector=["conversation", "id"]), | |||||
| ] | |||||
| ) | |||||
| # Create and set active session | |||||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||||
| rsc.active_session = session | |||||
| # Execute try_flush | |||||
| events = rsc.try_flush() | |||||
| # Should have all segments processed | |||||
| assert len(events) == 6 | |||||
| # Check text segments | |||||
| assert events[0].chunk == "User: " | |||||
| assert events[0].node_id == "response_node" | |||||
| # Check sys selector - should use response node's info | |||||
| assert events[1].chunk == "user123" | |||||
| assert events[1].selector == ["sys", "user_id"] | |||||
| assert events[1].node_id == "response_node" | |||||
| assert events[1].node_type == NodeType.ANSWER | |||||
| assert events[2].chunk == ", API: " | |||||
| # Check env selector - should use response node's info | |||||
| assert events[3].chunk == "key456" | |||||
| assert events[3].selector == ["env", "api_key"] | |||||
| assert events[3].node_id == "response_node" | |||||
| assert events[3].node_type == NodeType.ANSWER | |||||
| assert events[4].chunk == ", Conv: " | |||||
| # Check conversation selector - should use response node's info | |||||
| assert events[5].chunk == "conv789" | |||||
| assert events[5].selector == ["conversation", "id"] | |||||
| assert events[5].node_id == "response_node" | |||||
| assert events[5].node_type == NodeType.ANSWER | |||||
| # Session should be complete | |||||
| assert session.is_complete() |