Signed-off-by: -LAN- <laipz8200@outlook.com>tags/2.0.0-beta.1
| @@ -37,7 +37,6 @@ type = layers | |||
| layers = | |||
| graph_engine | |||
| response_coordinator | |||
| output_registry | |||
| containers = | |||
| core.workflow.graph_engine | |||
| @@ -35,7 +35,6 @@ from .event_management import EventCollector, EventEmitter, EventHandlerRegistry | |||
| from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator | |||
| from .layers.base import Layer | |||
| from .orchestration import Dispatcher, ExecutionCoordinator | |||
| from .output_registry import OutputRegistry | |||
| from .protocols.command_channel import CommandChannel | |||
| from .response_coordinator import ResponseStreamCoordinator | |||
| from .state_management import UnifiedStateManager | |||
| @@ -122,8 +121,9 @@ class GraphEngine: | |||
| self.state_manager = UnifiedStateManager(self.graph, self.ready_queue) | |||
| # Response coordination | |||
| self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool) | |||
| self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph) | |||
| self.response_coordinator = ResponseStreamCoordinator( | |||
| variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph | |||
| ) | |||
| # Event management | |||
| self.event_collector = EventCollector() | |||
| @@ -1,10 +0,0 @@ | |||
| """ | |||
| OutputRegistry - Thread-safe storage for node outputs (streams and scalars) | |||
| This component provides thread-safe storage and retrieval of node outputs, | |||
| supporting both scalar values and streaming chunks with proper state management. | |||
| """ | |||
| from .registry import OutputRegistry | |||
| __all__ = ["OutputRegistry"] | |||
| @@ -1,148 +0,0 @@ | |||
| """ | |||
| Main OutputRegistry implementation. | |||
| This module contains the public OutputRegistry class that provides | |||
| thread-safe storage for node outputs. | |||
| """ | |||
| from collections.abc import Sequence | |||
| from threading import RLock | |||
| from typing import TYPE_CHECKING, Any, Union, final | |||
| from core.variables import Segment | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from .stream import Stream | |||
| if TYPE_CHECKING: | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||
| @final | |||
| class OutputRegistry: | |||
| """ | |||
| Thread-safe registry for storing and retrieving node outputs. | |||
| Supports both scalar values and streaming chunks with proper state management. | |||
| All operations are thread-safe using internal locking. | |||
| """ | |||
| def __init__(self, variable_pool: VariablePool) -> None: | |||
| """Initialize empty registry with thread-safe storage.""" | |||
| self._lock = RLock() | |||
| self._scalars = variable_pool | |||
| self._streams: dict[tuple[str, ...], Stream] = {} | |||
| def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]: | |||
| """Convert selector list to tuple key for internal storage.""" | |||
| return tuple(selector) | |||
| def set_scalar( | |||
| self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]] | |||
| ) -> None: | |||
| """ | |||
| Set a scalar value for the given selector. | |||
| Args: | |||
| selector: List of strings identifying the output location | |||
| value: The scalar value to store | |||
| """ | |||
| with self._lock: | |||
| self._scalars.add(selector, value) | |||
| def get_scalar(self, selector: Sequence[str]) -> "Segment | None": | |||
| """ | |||
| Get a scalar value for the given selector. | |||
| Args: | |||
| selector: List of strings identifying the output location | |||
| Returns: | |||
| The stored Variable object, or None if not found | |||
| """ | |||
| with self._lock: | |||
| return self._scalars.get(selector) | |||
| def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None: | |||
| """ | |||
| Append a NodeRunStreamChunkEvent to the stream for the given selector. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| event: The NodeRunStreamChunkEvent to append | |||
| Raises: | |||
| ValueError: If the stream is already closed | |||
| """ | |||
| key = self._selector_to_key(selector) | |||
| with self._lock: | |||
| if key not in self._streams: | |||
| self._streams[key] = Stream() | |||
| try: | |||
| self._streams[key].append(event) | |||
| except ValueError: | |||
| raise ValueError(f"Stream {'.'.join(selector)} is already closed") | |||
| def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None": | |||
| """ | |||
| Pop the next unread NodeRunStreamChunkEvent from the stream. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| Returns: | |||
| The next event, or None if no unread events available | |||
| """ | |||
| key = self._selector_to_key(selector) | |||
| with self._lock: | |||
| if key not in self._streams: | |||
| return None | |||
| return self._streams[key].pop_next() | |||
| def has_unread(self, selector: Sequence[str]) -> bool: | |||
| """ | |||
| Check if the stream has unread events. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| Returns: | |||
| True if there are unread events, False otherwise | |||
| """ | |||
| key = self._selector_to_key(selector) | |||
| with self._lock: | |||
| if key not in self._streams: | |||
| return False | |||
| return self._streams[key].has_unread() | |||
| def close_stream(self, selector: Sequence[str]) -> None: | |||
| """ | |||
| Mark a stream as closed (no more chunks can be appended). | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| """ | |||
| key = self._selector_to_key(selector) | |||
| with self._lock: | |||
| if key not in self._streams: | |||
| self._streams[key] = Stream() | |||
| self._streams[key].close() | |||
| def stream_closed(self, selector: Sequence[str]) -> bool: | |||
| """ | |||
| Check if a stream is closed. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| Returns: | |||
| True if the stream is closed, False otherwise | |||
| """ | |||
| key = self._selector_to_key(selector) | |||
| with self._lock: | |||
| if key not in self._streams: | |||
| return False | |||
| return self._streams[key].is_closed | |||
| @@ -1,70 +0,0 @@ | |||
| """ | |||
| Internal stream implementation for OutputRegistry. | |||
| This module contains the private Stream class used internally by OutputRegistry | |||
| to manage streaming data chunks. | |||
| """ | |||
| from typing import TYPE_CHECKING, final | |||
| if TYPE_CHECKING: | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||
| @final | |||
| class Stream: | |||
| """ | |||
| A stream that holds NodeRunStreamChunkEvent objects and tracks read position. | |||
| This class encapsulates stream-specific data and operations, | |||
| including event storage, read position tracking, and closed state. | |||
| Note: This is an internal class not exposed in the public API. | |||
| """ | |||
| def __init__(self) -> None: | |||
| """Initialize an empty stream.""" | |||
| self.events: list[NodeRunStreamChunkEvent] = [] | |||
| self.read_position: int = 0 | |||
| self.is_closed: bool = False | |||
| def append(self, event: "NodeRunStreamChunkEvent") -> None: | |||
| """ | |||
| Append a NodeRunStreamChunkEvent to the stream. | |||
| Args: | |||
| event: The NodeRunStreamChunkEvent to append | |||
| Raises: | |||
| ValueError: If the stream is already closed | |||
| """ | |||
| if self.is_closed: | |||
| raise ValueError("Cannot append to a closed stream") | |||
| self.events.append(event) | |||
| def pop_next(self) -> "NodeRunStreamChunkEvent | None": | |||
| """ | |||
| Pop the next unread NodeRunStreamChunkEvent from the stream. | |||
| Returns: | |||
| The next event, or None if no unread events available | |||
| """ | |||
| if self.read_position >= len(self.events): | |||
| return None | |||
| event = self.events[self.read_position] | |||
| self.read_position += 1 | |||
| return event | |||
| def has_unread(self) -> bool: | |||
| """ | |||
| Check if the stream has unread events. | |||
| Returns: | |||
| True if there are unread events, False otherwise | |||
| """ | |||
| return self.read_position < len(self.events) | |||
| def close(self) -> None: | |||
| """Mark the stream as closed (no more chunks can be appended).""" | |||
| self.is_closed = True | |||
| @@ -12,12 +12,12 @@ from threading import RLock | |||
| from typing import TypeAlias, final | |||
| from uuid import uuid4 | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import NodeExecutionType, NodeState | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent | |||
| from core.workflow.nodes.base.template import TextSegment, VariableSegment | |||
| from ..output_registry import OutputRegistry | |||
| from .path import Path | |||
| from .session import ResponseSession | |||
| @@ -36,20 +36,25 @@ class ResponseStreamCoordinator: | |||
| Ensures ordered streaming of responses based on upstream node outputs and constants. | |||
| """ | |||
| def __init__(self, registry: OutputRegistry, graph: "Graph") -> None: | |||
| def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: | |||
| """ | |||
| Initialize coordinator with output registry. | |||
| Initialize coordinator with variable pool. | |||
| Args: | |||
| registry: OutputRegistry instance for accessing node outputs | |||
| variable_pool: VariablePool instance for accessing node variables | |||
| graph: Graph instance for looking up node information | |||
| """ | |||
| self.registry = registry | |||
| self.variable_pool = variable_pool | |||
| self.graph = graph | |||
| self.active_session: ResponseSession | None = None | |||
| self.waiting_sessions: deque[ResponseSession] = deque() | |||
| self.lock = RLock() | |||
| # Internal stream management (replacing OutputRegistry) | |||
| self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} | |||
| self._stream_positions: dict[tuple[str, ...], int] = {} | |||
| self._closed_streams: set[tuple[str, ...]] = set() | |||
| # Track response nodes | |||
| self._response_nodes: set[NodeID] = set() | |||
| @@ -256,15 +261,15 @@ class ResponseStreamCoordinator: | |||
| ) -> Sequence[NodeRunStreamChunkEvent]: | |||
| with self.lock: | |||
| if isinstance(event, NodeRunStreamChunkEvent): | |||
| self.registry.append_chunk(event.selector, event) | |||
| self._append_stream_chunk(event.selector, event) | |||
| if event.is_final: | |||
| self.registry.close_stream(event.selector) | |||
| self._close_stream(event.selector) | |||
| return self.try_flush() | |||
| else: | |||
| # Skip cause we share the same variable pool. | |||
| # | |||
| # for variable_name, variable_value in event.node_run_result.outputs.items(): | |||
| # self.registry.set_scalar((event.node_id, variable_name), variable_value) | |||
| # self.variable_pool.add((event.node_id, variable_name), variable_value) | |||
| return self.try_flush() | |||
| return [] | |||
| @@ -327,8 +332,8 @@ class ResponseStreamCoordinator: | |||
| execution_id = self._get_or_create_execution_id(output_node_id) | |||
| # Stream all available chunks | |||
| while self.registry.has_unread(segment.selector): | |||
| if event := self.registry.pop_chunk(segment.selector): | |||
| while self._has_unread_stream(segment.selector): | |||
| if event := self._pop_stream_chunk(segment.selector): | |||
| # For special selectors, we need to update the event to use | |||
| # the active response node's information | |||
| if self.active_session and source_selector_prefix not in self.graph.nodes: | |||
| @@ -349,12 +354,12 @@ class ResponseStreamCoordinator: | |||
| events.append(event) | |||
| # Check if this is the last chunk by looking ahead | |||
| stream_closed = self.registry.stream_closed(segment.selector) | |||
| stream_closed = self._is_stream_closed(segment.selector) | |||
| # Check if stream is closed to determine if segment is complete | |||
| if stream_closed: | |||
| is_complete = True | |||
| elif value := self.registry.get_scalar(segment.selector): | |||
| elif value := self.variable_pool.get(segment.selector): | |||
| # Process scalar value | |||
| is_last_segment = bool( | |||
| self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1 | |||
| @@ -464,3 +469,93 @@ class ResponseStreamCoordinator: | |||
| events = self.try_flush() | |||
| return events | |||
| # ============= Internal Stream Management Methods ============= | |||
| def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: | |||
| """ | |||
| Append a stream chunk to the internal buffer. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| event: The NodeRunStreamChunkEvent to append | |||
| Raises: | |||
| ValueError: If the stream is already closed | |||
| """ | |||
| key = tuple(selector) | |||
| if key in self._closed_streams: | |||
| raise ValueError(f"Stream {'.'.join(selector)} is already closed") | |||
| if key not in self._stream_buffers: | |||
| self._stream_buffers[key] = [] | |||
| self._stream_positions[key] = 0 | |||
| self._stream_buffers[key].append(event) | |||
| def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: | |||
| """ | |||
| Pop the next unread stream chunk from the buffer. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| Returns: | |||
| The next event, or None if no unread events available | |||
| """ | |||
| key = tuple(selector) | |||
| if key not in self._stream_buffers: | |||
| return None | |||
| position = self._stream_positions.get(key, 0) | |||
| buffer = self._stream_buffers[key] | |||
| if position >= len(buffer): | |||
| return None | |||
| event = buffer[position] | |||
| self._stream_positions[key] = position + 1 | |||
| return event | |||
| def _has_unread_stream(self, selector: Sequence[str]) -> bool: | |||
| """ | |||
| Check if the stream has unread events. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| Returns: | |||
| True if there are unread events, False otherwise | |||
| """ | |||
| key = tuple(selector) | |||
| if key not in self._stream_buffers: | |||
| return False | |||
| position = self._stream_positions.get(key, 0) | |||
| return position < len(self._stream_buffers[key]) | |||
| def _close_stream(self, selector: Sequence[str]) -> None: | |||
| """ | |||
| Mark a stream as closed (no more chunks can be appended). | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| """ | |||
| key = tuple(selector) | |||
| self._closed_streams.add(key) | |||
| def _is_stream_closed(self, selector: Sequence[str]) -> bool: | |||
| """ | |||
| Check if a stream is closed. | |||
| Args: | |||
| selector: List of strings identifying the stream location | |||
| Returns: | |||
| True if the stream is closed, False otherwise | |||
| """ | |||
| key = tuple(selector) | |||
| return key in self._closed_streams | |||
| @@ -1,135 +0,0 @@ | |||
| from uuid import uuid4 | |||
| import pytest | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import NodeType | |||
| from core.workflow.graph_engine.output_registry import OutputRegistry | |||
| from core.workflow.graph_events import NodeRunStreamChunkEvent | |||
| class TestOutputRegistry: | |||
| def test_scalar_operations(self): | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Test setting and getting scalar | |||
| registry.set_scalar(["node1", "output"], "test_value") | |||
| segment = registry.get_scalar(["node1", "output"]) | |||
| assert segment | |||
| assert segment.text == "test_value" | |||
| # Test getting non-existent scalar | |||
| assert registry.get_scalar(["non_existent"]) is None | |||
| def test_stream_operations(self): | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Create test events | |||
| event1 = NodeRunStreamChunkEvent( | |||
| id=str(uuid4()), | |||
| node_id="node1", | |||
| node_type=NodeType.LLM, | |||
| selector=["node1", "stream"], | |||
| chunk="chunk1", | |||
| is_final=False, | |||
| ) | |||
| event2 = NodeRunStreamChunkEvent( | |||
| id=str(uuid4()), | |||
| node_id="node1", | |||
| node_type=NodeType.LLM, | |||
| selector=["node1", "stream"], | |||
| chunk="chunk2", | |||
| is_final=True, | |||
| ) | |||
| # Test appending events | |||
| registry.append_chunk(["node1", "stream"], event1) | |||
| registry.append_chunk(["node1", "stream"], event2) | |||
| # Test has_unread | |||
| assert registry.has_unread(["node1", "stream"]) is True | |||
| # Test popping events | |||
| popped_event1 = registry.pop_chunk(["node1", "stream"]) | |||
| assert popped_event1 == event1 | |||
| assert popped_event1.chunk == "chunk1" | |||
| popped_event2 = registry.pop_chunk(["node1", "stream"]) | |||
| assert popped_event2 == event2 | |||
| assert popped_event2.chunk == "chunk2" | |||
| assert registry.pop_chunk(["node1", "stream"]) is None | |||
| # Test has_unread after popping all | |||
| assert registry.has_unread(["node1", "stream"]) is False | |||
| def test_stream_closing(self): | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Test stream is not closed initially | |||
| assert registry.stream_closed(["node1", "stream"]) is False | |||
| # Test closing stream | |||
| registry.close_stream(["node1", "stream"]) | |||
| assert registry.stream_closed(["node1", "stream"]) is True | |||
| # Test appending to closed stream raises error | |||
| event = NodeRunStreamChunkEvent( | |||
| id=str(uuid4()), | |||
| node_id="node1", | |||
| node_type=NodeType.LLM, | |||
| selector=["node1", "stream"], | |||
| chunk="chunk", | |||
| is_final=False, | |||
| ) | |||
| with pytest.raises(ValueError, match="Stream node1.stream is already closed"): | |||
| registry.append_chunk(["node1", "stream"], event) | |||
| def test_thread_safety(self): | |||
| import threading | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| results = [] | |||
| def append_chunks(thread_id: int): | |||
| for i in range(100): | |||
| event = NodeRunStreamChunkEvent( | |||
| id=str(uuid4()), | |||
| node_id="test_node", | |||
| node_type=NodeType.LLM, | |||
| selector=["stream"], | |||
| chunk=f"thread{thread_id}_chunk{i}", | |||
| is_final=False, | |||
| ) | |||
| registry.append_chunk(["stream"], event) | |||
| # Start multiple threads | |||
| threads = [] | |||
| for i in range(5): | |||
| thread = threading.Thread(target=append_chunks, args=(i,)) | |||
| threads.append(thread) | |||
| thread.start() | |||
| # Wait for threads | |||
| for thread in threads: | |||
| thread.join() | |||
| # Verify all events are present | |||
| events = [] | |||
| while True: | |||
| event = registry.pop_chunk(["stream"]) | |||
| if event is None: | |||
| break | |||
| events.append(event) | |||
| assert len(events) == 500 # 5 threads * 100 events each | |||
| # Verify the events have the expected chunk content format | |||
| chunk_texts = [e.chunk for e in events] | |||
| for i in range(5): | |||
| for j in range(100): | |||
| assert f"thread{i}_chunk{j}" in chunk_texts | |||
| @@ -1,347 +0,0 @@ | |||
| """Test cases for ResponseStreamCoordinator.""" | |||
| from unittest.mock import Mock | |||
| from core.variables import StringSegment | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import NodeState, NodeType | |||
| from core.workflow.graph import Graph | |||
| from core.workflow.graph_engine.output_registry import OutputRegistry | |||
| from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator | |||
| from core.workflow.graph_engine.response_coordinator.session import ResponseSession | |||
| from core.workflow.nodes.answer.answer_node import AnswerNode | |||
| from core.workflow.nodes.base.node import Node | |||
| from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment | |||
| class TestResponseStreamCoordinator: | |||
| """Test cases for ResponseStreamCoordinator.""" | |||
| def test_skip_variable_segment_from_skipped_node(self): | |||
| """Test that VariableSegments from skipped nodes are properly skipped during try_flush.""" | |||
| # Create mock graph | |||
| graph = Mock(spec=Graph) | |||
| # Create mock nodes | |||
| skipped_node = Mock(spec=Node) | |||
| skipped_node.id = "skipped_node" | |||
| skipped_node.state = NodeState.SKIPPED | |||
| skipped_node.node_type = NodeType.LLM | |||
| active_node = Mock(spec=Node) | |||
| active_node.id = "active_node" | |||
| active_node.state = NodeState.TAKEN | |||
| active_node.node_type = NodeType.LLM | |||
| response_node = Mock(spec=AnswerNode) | |||
| response_node.id = "response_node" | |||
| response_node.node_type = NodeType.ANSWER | |||
| # Set up graph nodes dictionary | |||
| graph.nodes = {"skipped_node": skipped_node, "active_node": active_node, "response_node": response_node} | |||
| # Create output registry with variable pool | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Add some test data to registry for the active node | |||
| registry.set_scalar(("active_node", "output"), StringSegment(value="Active output")) | |||
| # Create RSC instance | |||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||
| # Create template with segments from both skipped and active nodes | |||
| template = Template( | |||
| segments=[ | |||
| VariableSegment(selector=["skipped_node", "output"]), | |||
| TextSegment(text=" - "), | |||
| VariableSegment(selector=["active_node", "output"]), | |||
| ] | |||
| ) | |||
| # Create and set active session | |||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||
| rsc.active_session = session | |||
| # Execute try_flush | |||
| events = rsc.try_flush() | |||
| # Verify that: | |||
| # 1. The skipped node's variable segment was skipped (index advanced) | |||
| # 2. The text segment was processed | |||
| # 3. The active node's variable segment was processed | |||
| assert len(events) == 2 # TextSegment + VariableSegment from active_node | |||
| # Check that the first event is the text segment | |||
| assert events[0].chunk == " - " | |||
| # Check that the second event is from the active node | |||
| assert events[1].chunk == "Active output" | |||
| assert events[1].selector == ["active_node", "output"] | |||
| # Session should be complete | |||
| assert session.is_complete() | |||
| def test_process_variable_segment_from_non_skipped_node(self): | |||
| """Test that VariableSegments from non-skipped nodes are processed normally.""" | |||
| # Create mock graph | |||
| graph = Mock(spec=Graph) | |||
| # Create mock nodes | |||
| active_node1 = Mock(spec=Node) | |||
| active_node1.id = "node1" | |||
| active_node1.state = NodeState.TAKEN | |||
| active_node1.node_type = NodeType.LLM | |||
| active_node2 = Mock(spec=Node) | |||
| active_node2.id = "node2" | |||
| active_node2.state = NodeState.TAKEN | |||
| active_node2.node_type = NodeType.LLM | |||
| response_node = Mock(spec=AnswerNode) | |||
| response_node.id = "response_node" | |||
| response_node.node_type = NodeType.ANSWER | |||
| # Set up graph nodes dictionary | |||
| graph.nodes = {"node1": active_node1, "node2": active_node2, "response_node": response_node} | |||
| # Create output registry with variable pool | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Add test data to registry | |||
| registry.set_scalar(("node1", "output"), StringSegment(value="Output 1")) | |||
| registry.set_scalar(("node2", "output"), StringSegment(value="Output 2")) | |||
| # Create RSC instance | |||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||
| # Create template with segments from active nodes | |||
| template = Template( | |||
| segments=[ | |||
| VariableSegment(selector=["node1", "output"]), | |||
| TextSegment(text=" | "), | |||
| VariableSegment(selector=["node2", "output"]), | |||
| ] | |||
| ) | |||
| # Create and set active session | |||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||
| rsc.active_session = session | |||
| # Execute try_flush | |||
| events = rsc.try_flush() | |||
| # Verify all segments were processed | |||
| assert len(events) == 3 | |||
| # Check events in order | |||
| assert events[0].chunk == "Output 1" | |||
| assert events[0].selector == ["node1", "output"] | |||
| assert events[1].chunk == " | " | |||
| assert events[2].chunk == "Output 2" | |||
| assert events[2].selector == ["node2", "output"] | |||
| # Session should be complete | |||
| assert session.is_complete() | |||
| def test_mixed_skipped_and_active_nodes(self): | |||
| """Test processing with a mix of skipped and active nodes.""" | |||
| # Create mock graph | |||
| graph = Mock(spec=Graph) | |||
| # Create mock nodes with various states | |||
| skipped_node1 = Mock(spec=Node) | |||
| skipped_node1.id = "skip1" | |||
| skipped_node1.state = NodeState.SKIPPED | |||
| skipped_node1.node_type = NodeType.LLM | |||
| active_node = Mock(spec=Node) | |||
| active_node.id = "active" | |||
| active_node.state = NodeState.TAKEN | |||
| active_node.node_type = NodeType.LLM | |||
| skipped_node2 = Mock(spec=Node) | |||
| skipped_node2.id = "skip2" | |||
| skipped_node2.state = NodeState.SKIPPED | |||
| skipped_node2.node_type = NodeType.LLM | |||
| response_node = Mock(spec=AnswerNode) | |||
| response_node.id = "response_node" | |||
| response_node.node_type = NodeType.ANSWER | |||
| # Set up graph nodes dictionary | |||
| graph.nodes = { | |||
| "skip1": skipped_node1, | |||
| "active": active_node, | |||
| "skip2": skipped_node2, | |||
| "response_node": response_node, | |||
| } | |||
| # Create output registry with variable pool | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Add data only for active node | |||
| registry.set_scalar(("active", "result"), StringSegment(value="Active Result")) | |||
| # Create RSC instance | |||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||
| # Create template with mixed segments | |||
| template = Template( | |||
| segments=[ | |||
| TextSegment(text="Start: "), | |||
| VariableSegment(selector=["skip1", "output"]), | |||
| VariableSegment(selector=["active", "result"]), | |||
| VariableSegment(selector=["skip2", "output"]), | |||
| TextSegment(text=" :End"), | |||
| ] | |||
| ) | |||
| # Create and set active session | |||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||
| rsc.active_session = session | |||
| # Execute try_flush | |||
| events = rsc.try_flush() | |||
| # Should have: "Start: ", "Active Result", " :End" | |||
| assert len(events) == 3 | |||
| assert events[0].chunk == "Start: " | |||
| assert events[1].chunk == "Active Result" | |||
| assert events[1].selector == ["active", "result"] | |||
| assert events[2].chunk == " :End" | |||
| # Session should be complete | |||
| assert session.is_complete() | |||
| def test_all_variable_segments_skipped(self): | |||
| """Test when all VariableSegments are from skipped nodes.""" | |||
| # Create mock graph | |||
| graph = Mock(spec=Graph) | |||
| # Create all skipped nodes | |||
| skipped_node1 = Mock(spec=Node) | |||
| skipped_node1.id = "skip1" | |||
| skipped_node1.state = NodeState.SKIPPED | |||
| skipped_node1.node_type = NodeType.LLM | |||
| skipped_node2 = Mock(spec=Node) | |||
| skipped_node2.id = "skip2" | |||
| skipped_node2.state = NodeState.SKIPPED | |||
| skipped_node2.node_type = NodeType.LLM | |||
| response_node = Mock(spec=AnswerNode) | |||
| response_node.id = "response_node" | |||
| response_node.node_type = NodeType.ANSWER | |||
| # Set up graph nodes dictionary | |||
| graph.nodes = {"skip1": skipped_node1, "skip2": skipped_node2, "response_node": response_node} | |||
| # Create output registry (empty since nodes are skipped) with variable pool | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| # Create RSC instance | |||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||
| # Create template with only skipped segments | |||
| template = Template( | |||
| segments=[ | |||
| VariableSegment(selector=["skip1", "output"]), | |||
| VariableSegment(selector=["skip2", "output"]), | |||
| TextSegment(text="Final text"), | |||
| ] | |||
| ) | |||
| # Create and set active session | |||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||
| rsc.active_session = session | |||
| # Execute try_flush | |||
| events = rsc.try_flush() | |||
| # Should only have the final text segment | |||
| assert len(events) == 1 | |||
| assert events[0].chunk == "Final text" | |||
| # Session should be complete | |||
| assert session.is_complete() | |||
| def test_special_prefix_selectors(self): | |||
| """Test that special prefix selectors (sys, env, conversation) are handled correctly.""" | |||
| # Create mock graph | |||
| graph = Mock(spec=Graph) | |||
| # Create response node | |||
| response_node = Mock(spec=AnswerNode) | |||
| response_node.id = "response_node" | |||
| response_node.node_type = NodeType.ANSWER | |||
| # Set up graph nodes dictionary (no sys, env, conversation nodes) | |||
| graph.nodes = {"response_node": response_node} | |||
| # Create output registry with special selector data and variable pool | |||
| variable_pool = VariablePool() | |||
| registry = OutputRegistry(variable_pool) | |||
| registry.set_scalar(("sys", "user_id"), StringSegment(value="user123")) | |||
| registry.set_scalar(("env", "api_key"), StringSegment(value="key456")) | |||
| registry.set_scalar(("conversation", "id"), StringSegment(value="conv789")) | |||
| # Create RSC instance | |||
| rsc = ResponseStreamCoordinator(registry=registry, graph=graph) | |||
| # Create template with special selectors | |||
| template = Template( | |||
| segments=[ | |||
| TextSegment(text="User: "), | |||
| VariableSegment(selector=["sys", "user_id"]), | |||
| TextSegment(text=", API: "), | |||
| VariableSegment(selector=["env", "api_key"]), | |||
| TextSegment(text=", Conv: "), | |||
| VariableSegment(selector=["conversation", "id"]), | |||
| ] | |||
| ) | |||
| # Create and set active session | |||
| session = ResponseSession(node_id="response_node", template=template, index=0) | |||
| rsc.active_session = session | |||
| # Execute try_flush | |||
| events = rsc.try_flush() | |||
| # Should have all segments processed | |||
| assert len(events) == 6 | |||
| # Check text segments | |||
| assert events[0].chunk == "User: " | |||
| assert events[0].node_id == "response_node" | |||
| # Check sys selector - should use response node's info | |||
| assert events[1].chunk == "user123" | |||
| assert events[1].selector == ["sys", "user_id"] | |||
| assert events[1].node_id == "response_node" | |||
| assert events[1].node_type == NodeType.ANSWER | |||
| assert events[2].chunk == ", API: " | |||
| # Check env selector - should use response node's info | |||
| assert events[3].chunk == "key456" | |||
| assert events[3].selector == ["env", "api_key"] | |||
| assert events[3].node_id == "response_node" | |||
| assert events[3].node_type == NodeType.ANSWER | |||
| assert events[4].chunk == ", Conv: " | |||
| # Check conversation selector - should use response node's info | |||
| assert events[5].chunk == "conv789" | |||
| assert events[5].selector == ["conversation", "id"] | |||
| assert events[5].node_id == "response_node" | |||
| assert events[5].node_type == NodeType.ANSWER | |||
| # Session should be complete | |||
| assert session.is_complete() | |||