View source

refactor(graph_engine): inline output_registry into response_coordinator

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/2.0.0-beta.1
-LAN- 2 months ago
parent
commit
a5cb9d2b73

+ 0
- 1
api/.importlinter View file

 layers =
     graph_engine
     response_coordinator
-    output_registry
 containers =
     core.workflow.graph_engine


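With the output_registry package deleted, the layered contract above simply loses one layer: graph_engine may still import response_coordinator, but not the other way around. For context, a minimal sketch of how such a contract is declared in an import-linter config; the contract header and name below are assumptions, since the diff only shows the layers and containers keys:

[importlinter:contract:graph-engine-layers]
name = Graph engine layering
type = layers
layers =
    graph_engine
    response_coordinator
containers =
    core.workflow.graph_engine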

+ 3
- 3
api/core/workflow/graph_engine/graph_engine.py View file

 from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
 from .layers.base import Layer
 from .orchestration import Dispatcher, ExecutionCoordinator
-from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
 from .state_management import UnifiedStateManager

         self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)

         # Response coordination
-        self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
-        self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph)
+        self.response_coordinator = ResponseStreamCoordinator(
+            variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
+        )

         # Event management
         self.event_collector = EventCollector()
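The call-site effect of the hunk above: the engine no longer builds an intermediate OutputRegistry wrapper around the variable pool. A minimal before/after sketch, assuming the repository's module paths; the bare graph and variable_pool names stand in for the engine's real attributes:

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator

# Before: scalar outputs went through a registry that wrapped the pool.
#   registry = OutputRegistry(variable_pool)
#   coordinator = ResponseStreamCoordinator(registry=registry, graph=graph)

# After: the coordinator reads scalars from the shared pool directly
# and keeps its own stream buffers internally.
coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)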

+ 0
- 10
api/core/workflow/graph_engine/output_registry/__init__.py View file

"""
OutputRegistry - Thread-safe storage for node outputs (streams and scalars)

This component provides thread-safe storage and retrieval of node outputs,
supporting both scalar values and streaming chunks with proper state management.
"""

from .registry import OutputRegistry

__all__ = ["OutputRegistry"]

+ 0
- 148
api/core/workflow/graph_engine/output_registry/registry.py View file

"""
Main OutputRegistry implementation.

This module contains the public OutputRegistry class that provides
thread-safe storage for node outputs.
"""

from collections.abc import Sequence
from threading import RLock
from typing import TYPE_CHECKING, Any, Union, final

from core.variables import Segment
from core.workflow.entities.variable_pool import VariablePool

from .stream import Stream

if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent


@final
class OutputRegistry:
"""
Thread-safe registry for storing and retrieving node outputs.

Supports both scalar values and streaming chunks with proper state management.
All operations are thread-safe using internal locking.
"""

def __init__(self, variable_pool: VariablePool) -> None:
"""Initialize empty registry with thread-safe storage."""
self._lock = RLock()
self._scalars = variable_pool
self._streams: dict[tuple[str, ...], Stream] = {}

def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
"""Convert selector list to tuple key for internal storage."""
return tuple(selector)

def set_scalar(
self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
) -> None:
"""
Set a scalar value for the given selector.

Args:
selector: List of strings identifying the output location
value: The scalar value to store
"""
with self._lock:
self._scalars.add(selector, value)

def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
"""
Get a scalar value for the given selector.

Args:
selector: List of strings identifying the output location

Returns:
The stored Variable object, or None if not found
"""
with self._lock:
return self._scalars.get(selector)

def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
"""
Append a NodeRunStreamChunkEvent to the stream for the given selector.

Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append

Raises:
ValueError: If the stream is already closed
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
self._streams[key] = Stream()

try:
self._streams[key].append(event)
except ValueError:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")

def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.

Args:
selector: List of strings identifying the stream location

Returns:
The next event, or None if no unread events available
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return None

return self._streams[key].pop_next()

def has_unread(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.

Args:
selector: List of strings identifying the stream location

Returns:
True if there are unread events, False otherwise
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return False

return self._streams[key].has_unread()

def close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).

Args:
selector: List of strings identifying the stream location
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
self._streams[key] = Stream()
self._streams[key].close()

def stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.

Args:
selector: List of strings identifying the stream location

Returns:
True if the stream is closed, False otherwise
"""
key = self._selector_to_key(selector)
with self._lock:
if key not in self._streams:
return False
return self._streams[key].is_closed
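The deleted class above is, at its core, a per-selector append-only buffer with a read cursor, guarded by a single lock. A self-contained sketch of those semantics; MiniRegistry is a hypothetical stand-in, using plain strings instead of NodeRunStreamChunkEvent:

from threading import RLock


class MiniRegistry:  # hypothetical stand-in for the removed OutputRegistry
    def __init__(self) -> None:
        self._lock = RLock()
        self._chunks: dict[tuple[str, ...], list[str]] = {}
        self._pos: dict[tuple[str, ...], int] = {}
        self._closed: set[tuple[str, ...]] = set()

    def append(self, selector: list[str], chunk: str) -> None:
        key = tuple(selector)
        with self._lock:
            if key in self._closed:
                raise ValueError(f"Stream {'.'.join(selector)} is already closed")
            self._chunks.setdefault(key, []).append(chunk)

    def pop(self, selector: list[str]) -> str | None:
        key = tuple(selector)
        with self._lock:
            buf = self._chunks.get(key, [])
            pos = self._pos.get(key, 0)
            if pos >= len(buf):
                return None  # nothing unread
            self._pos[key] = pos + 1
            return buf[pos]


reg = MiniRegistry()
reg.append(["node1", "stream"], "chunk1")
assert reg.pop(["node1", "stream"]) == "chunk1"
assert reg.pop(["node1", "stream"]) is None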

+ 0
- 70
api/core/workflow/graph_engine/output_registry/stream.py View file

"""
Internal stream implementation for OutputRegistry.

This module contains the private Stream class used internally by OutputRegistry
to manage streaming data chunks.
"""

from typing import TYPE_CHECKING, final

if TYPE_CHECKING:
from core.workflow.graph_events import NodeRunStreamChunkEvent


@final
class Stream:
"""
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.

This class encapsulates stream-specific data and operations,
including event storage, read position tracking, and closed state.

Note: This is an internal class not exposed in the public API.
"""

def __init__(self) -> None:
"""Initialize an empty stream."""
self.events: list[NodeRunStreamChunkEvent] = []
self.read_position: int = 0
self.is_closed: bool = False

def append(self, event: "NodeRunStreamChunkEvent") -> None:
"""
Append a NodeRunStreamChunkEvent to the stream.

Args:
event: The NodeRunStreamChunkEvent to append

Raises:
ValueError: If the stream is already closed
"""
if self.is_closed:
raise ValueError("Cannot append to a closed stream")
self.events.append(event)

def pop_next(self) -> "NodeRunStreamChunkEvent | None":
"""
Pop the next unread NodeRunStreamChunkEvent from the stream.

Returns:
The next event, or None if no unread events available
"""
if self.read_position >= len(self.events):
return None

event = self.events[self.read_position]
self.read_position += 1
return event

def has_unread(self) -> bool:
"""
Check if the stream has unread events.

Returns:
True if there are unread events, False otherwise
"""
return self.read_position < len(self.events)

def close(self) -> None:
"""Mark the stream as closed (no more chunks can be appended)."""
self.is_closed = True
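One design choice worth noting in the class above: pop_next never discards events, it only advances read_position, so every chunk stays resident for the stream's lifetime. A tiny standalone sketch of that cursor behavior:

# Popping advances a cursor; the backing list never shrinks.
events = ["a", "b"]
read_position = 0


def pop_next() -> str | None:
    global read_position
    if read_position >= len(events):
        return None  # nothing unread
    event = events[read_position]
    read_position += 1
    return event


assert pop_next() == "a"
assert pop_next() == "b"
assert pop_next() is None
assert len(events) == 2  # chunks are retained; only the cursor moved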

+ 107
- 12
api/core/workflow/graph_engine/response_coordinator/coordinator.py View file

 from typing import TypeAlias, final
 from uuid import uuid4

+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import NodeExecutionType, NodeState
 from core.workflow.graph import Graph
 from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
 from core.workflow.nodes.base.template import TextSegment, VariableSegment

-from ..output_registry import OutputRegistry
 from .path import Path
 from .session import ResponseSession

     Ensures ordered streaming of responses based on upstream node outputs and constants.
     """

-    def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
+    def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
         """
-        Initialize coordinator with output registry.
+        Initialize coordinator with variable pool.

         Args:
-            registry: OutputRegistry instance for accessing node outputs
+            variable_pool: VariablePool instance for accessing node variables
             graph: Graph instance for looking up node information
         """
-        self.registry = registry
+        self.variable_pool = variable_pool
         self.graph = graph
         self.active_session: ResponseSession | None = None
         self.waiting_sessions: deque[ResponseSession] = deque()
         self.lock = RLock()

+        # Internal stream management (replacing OutputRegistry)
+        self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
+        self._stream_positions: dict[tuple[str, ...], int] = {}
+        self._closed_streams: set[tuple[str, ...]] = set()
+
         # Track response nodes
         self._response_nodes: set[NodeID] = set()

     ) -> Sequence[NodeRunStreamChunkEvent]:
         with self.lock:
             if isinstance(event, NodeRunStreamChunkEvent):
-                self.registry.append_chunk(event.selector, event)
+                self._append_stream_chunk(event.selector, event)
                 if event.is_final:
-                    self.registry.close_stream(event.selector)
+                    self._close_stream(event.selector)
                 return self.try_flush()
             else:
                 # Skip cause we share the same variable pool.
                 #
                 # for variable_name, variable_value in event.node_run_result.outputs.items():
-                #     self.registry.set_scalar((event.node_id, variable_name), variable_value)
+                #     self.variable_pool.add((event.node_id, variable_name), variable_value)
                 return self.try_flush()
         return []

         execution_id = self._get_or_create_execution_id(output_node_id)

         # Stream all available chunks
-        while self.registry.has_unread(segment.selector):
-            if event := self.registry.pop_chunk(segment.selector):
+        while self._has_unread_stream(segment.selector):
+            if event := self._pop_stream_chunk(segment.selector):
                 # For special selectors, we need to update the event to use
                 # the active response node's information
                 if self.active_session and source_selector_prefix not in self.graph.nodes:
                     events.append(event)

             # Check if this is the last chunk by looking ahead
-            stream_closed = self.registry.stream_closed(segment.selector)
+            stream_closed = self._is_stream_closed(segment.selector)
         # Check if stream is closed to determine if segment is complete
         if stream_closed:
             is_complete = True

-        elif value := self.registry.get_scalar(segment.selector):
+        elif value := self.variable_pool.get(segment.selector):
             # Process scalar value
             is_last_segment = bool(
                 self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1

         events = self.try_flush()

         return events

+    # ============= Internal Stream Management Methods =============
+
+    def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
+        """
+        Append a stream chunk to the internal buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+            event: The NodeRunStreamChunkEvent to append
+
+        Raises:
+            ValueError: If the stream is already closed
+        """
+        key = tuple(selector)
+
+        if key in self._closed_streams:
+            raise ValueError(f"Stream {'.'.join(selector)} is already closed")
+
+        if key not in self._stream_buffers:
+            self._stream_buffers[key] = []
+            self._stream_positions[key] = 0
+
+        self._stream_buffers[key].append(event)
+
+    def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
+        """
+        Pop the next unread stream chunk from the buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            The next event, or None if no unread events available
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return None
+
+        position = self._stream_positions.get(key, 0)
+        buffer = self._stream_buffers[key]
+
+        if position >= len(buffer):
+            return None
+
+        event = buffer[position]
+        self._stream_positions[key] = position + 1
+        return event
+
+    def _has_unread_stream(self, selector: Sequence[str]) -> bool:
+        """
+        Check if the stream has unread events.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if there are unread events, False otherwise
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return False
+
+        position = self._stream_positions.get(key, 0)
+        return position < len(self._stream_buffers[key])
+
+    def _close_stream(self, selector: Sequence[str]) -> None:
+        """
+        Mark a stream as closed (no more chunks can be appended).
+
+        Args:
+            selector: List of strings identifying the stream location
+        """
+        key = tuple(selector)
+        self._closed_streams.add(key)
+
+    def _is_stream_closed(self, selector: Sequence[str]) -> bool:
+        """
+        Check if a stream is closed.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if the stream is closed, False otherwise
+        """
+        key = tuple(selector)
+        return key in self._closed_streams
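The inlined version above replaces the per-stream Stream object with three flat containers keyed by the selector tuple, and relies on the coordinator's existing self.lock rather than a dedicated registry lock. A standalone sketch of the equivalence, with hypothetical names and plain strings in place of events:

# One entry per selector in each container replaces one Stream object.
buffers: dict[tuple[str, ...], list[str]] = {}  # was Stream.events
positions: dict[tuple[str, ...], int] = {}      # was Stream.read_position
closed: set[tuple[str, ...]] = set()            # was Stream.is_closed

key = ("node1", "text")

# _append_stream_chunk: create buffer lazily, refuse closed streams
if key not in closed:
    buffers.setdefault(key, [])
    positions.setdefault(key, 0)
    buffers[key].append("hello")

# _has_unread_stream / _pop_stream_chunk: cursor-based read
if positions[key] < len(buffers[key]):
    chunk = buffers[key][positions[key]]
    positions[key] += 1

# _close_stream / _is_stream_closed: membership in a set
closed.add(key)
assert key in closed
assert chunk == "hello"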

+ 0
- 135
api/tests/unit_tests/core/workflow/graph_engine/test_output_registry.py View file

from uuid import uuid4

import pytest

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeType
from core.workflow.graph_engine.output_registry import OutputRegistry
from core.workflow.graph_events import NodeRunStreamChunkEvent


class TestOutputRegistry:
    def test_scalar_operations(self):
        variable_pool = VariablePool()
        registry = OutputRegistry(variable_pool)

        # Test setting and getting scalar
        registry.set_scalar(["node1", "output"], "test_value")

        segment = registry.get_scalar(["node1", "output"])
        assert segment
        assert segment.text == "test_value"

        # Test getting non-existent scalar
        assert registry.get_scalar(["non_existent"]) is None

    def test_stream_operations(self):
        variable_pool = VariablePool()
        registry = OutputRegistry(variable_pool)

        # Create test events
        event1 = NodeRunStreamChunkEvent(
            id=str(uuid4()),
            node_id="node1",
            node_type=NodeType.LLM,
            selector=["node1", "stream"],
            chunk="chunk1",
            is_final=False,
        )
        event2 = NodeRunStreamChunkEvent(
            id=str(uuid4()),
            node_id="node1",
            node_type=NodeType.LLM,
            selector=["node1", "stream"],
            chunk="chunk2",
            is_final=True,
        )

        # Test appending events
        registry.append_chunk(["node1", "stream"], event1)
        registry.append_chunk(["node1", "stream"], event2)

        # Test has_unread
        assert registry.has_unread(["node1", "stream"]) is True

        # Test popping events
        popped_event1 = registry.pop_chunk(["node1", "stream"])
        assert popped_event1 == event1
        assert popped_event1.chunk == "chunk1"

        popped_event2 = registry.pop_chunk(["node1", "stream"])
        assert popped_event2 == event2
        assert popped_event2.chunk == "chunk2"

        assert registry.pop_chunk(["node1", "stream"]) is None

        # Test has_unread after popping all
        assert registry.has_unread(["node1", "stream"]) is False

    def test_stream_closing(self):
        variable_pool = VariablePool()
        registry = OutputRegistry(variable_pool)

        # Test stream is not closed initially
        assert registry.stream_closed(["node1", "stream"]) is False

        # Test closing stream
        registry.close_stream(["node1", "stream"])
        assert registry.stream_closed(["node1", "stream"]) is True

        # Test appending to closed stream raises error
        event = NodeRunStreamChunkEvent(
            id=str(uuid4()),
            node_id="node1",
            node_type=NodeType.LLM,
            selector=["node1", "stream"],
            chunk="chunk",
            is_final=False,
        )
        with pytest.raises(ValueError, match="Stream node1.stream is already closed"):
            registry.append_chunk(["node1", "stream"], event)

    def test_thread_safety(self):
        import threading

        variable_pool = VariablePool()
        registry = OutputRegistry(variable_pool)

        def append_chunks(thread_id: int):
            for i in range(100):
                event = NodeRunStreamChunkEvent(
                    id=str(uuid4()),
                    node_id="test_node",
                    node_type=NodeType.LLM,
                    selector=["stream"],
                    chunk=f"thread{thread_id}_chunk{i}",
                    is_final=False,
                )
                registry.append_chunk(["stream"], event)

        # Start multiple threads
        threads = []
        for i in range(5):
            thread = threading.Thread(target=append_chunks, args=(i,))
            threads.append(thread)
            thread.start()

        # Wait for threads
        for thread in threads:
            thread.join()

        # Verify all events are present
        events = []
        while True:
            event = registry.pop_chunk(["stream"])
            if event is None:
                break
            events.append(event)

        assert len(events) == 500  # 5 threads * 100 events each
        # Verify the events have the expected chunk content format
        chunk_texts = [e.chunk for e in events]
        for i in range(5):
            for j in range(100):
                assert f"thread{i}_chunk{j}" in chunk_texts

+ 0
- 347
api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py View file

"""Test cases for ResponseStreamCoordinator."""

from unittest.mock import Mock

from core.variables import StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeState, NodeType
from core.workflow.graph import Graph
from core.workflow.graph_engine.output_registry import OutputRegistry
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment


class TestResponseStreamCoordinator:
"""Test cases for ResponseStreamCoordinator."""

def test_skip_variable_segment_from_skipped_node(self):
"""Test that VariableSegments from skipped nodes are properly skipped during try_flush."""
# Create mock graph
graph = Mock(spec=Graph)

# Create mock nodes
skipped_node = Mock(spec=Node)
skipped_node.id = "skipped_node"
skipped_node.state = NodeState.SKIPPED
skipped_node.node_type = NodeType.LLM

active_node = Mock(spec=Node)
active_node.id = "active_node"
active_node.state = NodeState.TAKEN
active_node.node_type = NodeType.LLM

response_node = Mock(spec=AnswerNode)
response_node.id = "response_node"
response_node.node_type = NodeType.ANSWER

# Set up graph nodes dictionary
graph.nodes = {"skipped_node": skipped_node, "active_node": active_node, "response_node": response_node}

# Create output registry with variable pool
variable_pool = VariablePool()
registry = OutputRegistry(variable_pool)

# Add some test data to registry for the active node
registry.set_scalar(("active_node", "output"), StringSegment(value="Active output"))

# Create RSC instance
rsc = ResponseStreamCoordinator(registry=registry, graph=graph)

# Create template with segments from both skipped and active nodes
template = Template(
segments=[
VariableSegment(selector=["skipped_node", "output"]),
TextSegment(text=" - "),
VariableSegment(selector=["active_node", "output"]),
]
)

# Create and set active session
session = ResponseSession(node_id="response_node", template=template, index=0)
rsc.active_session = session

# Execute try_flush
events = rsc.try_flush()

# Verify that:
# 1. The skipped node's variable segment was skipped (index advanced)
# 2. The text segment was processed
# 3. The active node's variable segment was processed
assert len(events) == 2 # TextSegment + VariableSegment from active_node

# Check that the first event is the text segment
assert events[0].chunk == " - "

# Check that the second event is from the active node
assert events[1].chunk == "Active output"
assert events[1].selector == ["active_node", "output"]

# Session should be complete
assert session.is_complete()

def test_process_variable_segment_from_non_skipped_node(self):
"""Test that VariableSegments from non-skipped nodes are processed normally."""
# Create mock graph
graph = Mock(spec=Graph)

# Create mock nodes
active_node1 = Mock(spec=Node)
active_node1.id = "node1"
active_node1.state = NodeState.TAKEN
active_node1.node_type = NodeType.LLM

active_node2 = Mock(spec=Node)
active_node2.id = "node2"
active_node2.state = NodeState.TAKEN
active_node2.node_type = NodeType.LLM

response_node = Mock(spec=AnswerNode)
response_node.id = "response_node"
response_node.node_type = NodeType.ANSWER

# Set up graph nodes dictionary
graph.nodes = {"node1": active_node1, "node2": active_node2, "response_node": response_node}

# Create output registry with variable pool
variable_pool = VariablePool()
registry = OutputRegistry(variable_pool)

# Add test data to registry
registry.set_scalar(("node1", "output"), StringSegment(value="Output 1"))
registry.set_scalar(("node2", "output"), StringSegment(value="Output 2"))

# Create RSC instance
rsc = ResponseStreamCoordinator(registry=registry, graph=graph)

# Create template with segments from active nodes
template = Template(
segments=[
VariableSegment(selector=["node1", "output"]),
TextSegment(text=" | "),
VariableSegment(selector=["node2", "output"]),
]
)

# Create and set active session
session = ResponseSession(node_id="response_node", template=template, index=0)
rsc.active_session = session

# Execute try_flush
events = rsc.try_flush()

# Verify all segments were processed
assert len(events) == 3

# Check events in order
assert events[0].chunk == "Output 1"
assert events[0].selector == ["node1", "output"]

assert events[1].chunk == " | "

assert events[2].chunk == "Output 2"
assert events[2].selector == ["node2", "output"]

# Session should be complete
assert session.is_complete()

def test_mixed_skipped_and_active_nodes(self):
"""Test processing with a mix of skipped and active nodes."""
# Create mock graph
graph = Mock(spec=Graph)

# Create mock nodes with various states
skipped_node1 = Mock(spec=Node)
skipped_node1.id = "skip1"
skipped_node1.state = NodeState.SKIPPED
skipped_node1.node_type = NodeType.LLM

active_node = Mock(spec=Node)
active_node.id = "active"
active_node.state = NodeState.TAKEN
active_node.node_type = NodeType.LLM

skipped_node2 = Mock(spec=Node)
skipped_node2.id = "skip2"
skipped_node2.state = NodeState.SKIPPED
skipped_node2.node_type = NodeType.LLM

response_node = Mock(spec=AnswerNode)
response_node.id = "response_node"
response_node.node_type = NodeType.ANSWER

# Set up graph nodes dictionary
graph.nodes = {
"skip1": skipped_node1,
"active": active_node,
"skip2": skipped_node2,
"response_node": response_node,
}

# Create output registry with variable pool
variable_pool = VariablePool()
registry = OutputRegistry(variable_pool)

# Add data only for active node
registry.set_scalar(("active", "result"), StringSegment(value="Active Result"))

# Create RSC instance
rsc = ResponseStreamCoordinator(registry=registry, graph=graph)

# Create template with mixed segments
template = Template(
segments=[
TextSegment(text="Start: "),
VariableSegment(selector=["skip1", "output"]),
VariableSegment(selector=["active", "result"]),
VariableSegment(selector=["skip2", "output"]),
TextSegment(text=" :End"),
]
)

# Create and set active session
session = ResponseSession(node_id="response_node", template=template, index=0)
rsc.active_session = session

# Execute try_flush
events = rsc.try_flush()

# Should have: "Start: ", "Active Result", " :End"
assert len(events) == 3

assert events[0].chunk == "Start: "
assert events[1].chunk == "Active Result"
assert events[1].selector == ["active", "result"]
assert events[2].chunk == " :End"

# Session should be complete
assert session.is_complete()

def test_all_variable_segments_skipped(self):
"""Test when all VariableSegments are from skipped nodes."""
# Create mock graph
graph = Mock(spec=Graph)

# Create all skipped nodes
skipped_node1 = Mock(spec=Node)
skipped_node1.id = "skip1"
skipped_node1.state = NodeState.SKIPPED
skipped_node1.node_type = NodeType.LLM

skipped_node2 = Mock(spec=Node)
skipped_node2.id = "skip2"
skipped_node2.state = NodeState.SKIPPED
skipped_node2.node_type = NodeType.LLM

response_node = Mock(spec=AnswerNode)
response_node.id = "response_node"
response_node.node_type = NodeType.ANSWER

# Set up graph nodes dictionary
graph.nodes = {"skip1": skipped_node1, "skip2": skipped_node2, "response_node": response_node}

# Create output registry (empty since nodes are skipped) with variable pool
variable_pool = VariablePool()
registry = OutputRegistry(variable_pool)

# Create RSC instance
rsc = ResponseStreamCoordinator(registry=registry, graph=graph)

# Create template with only skipped segments
template = Template(
segments=[
VariableSegment(selector=["skip1", "output"]),
VariableSegment(selector=["skip2", "output"]),
TextSegment(text="Final text"),
]
)

# Create and set active session
session = ResponseSession(node_id="response_node", template=template, index=0)
rsc.active_session = session

# Execute try_flush
events = rsc.try_flush()

# Should only have the final text segment
assert len(events) == 1
assert events[0].chunk == "Final text"

# Session should be complete
assert session.is_complete()

def test_special_prefix_selectors(self):
"""Test that special prefix selectors (sys, env, conversation) are handled correctly."""
# Create mock graph
graph = Mock(spec=Graph)

# Create response node
response_node = Mock(spec=AnswerNode)
response_node.id = "response_node"
response_node.node_type = NodeType.ANSWER

# Set up graph nodes dictionary (no sys, env, conversation nodes)
graph.nodes = {"response_node": response_node}

# Create output registry with special selector data and variable pool
variable_pool = VariablePool()
registry = OutputRegistry(variable_pool)
registry.set_scalar(("sys", "user_id"), StringSegment(value="user123"))
registry.set_scalar(("env", "api_key"), StringSegment(value="key456"))
registry.set_scalar(("conversation", "id"), StringSegment(value="conv789"))

# Create RSC instance
rsc = ResponseStreamCoordinator(registry=registry, graph=graph)

# Create template with special selectors
template = Template(
segments=[
TextSegment(text="User: "),
VariableSegment(selector=["sys", "user_id"]),
TextSegment(text=", API: "),
VariableSegment(selector=["env", "api_key"]),
TextSegment(text=", Conv: "),
VariableSegment(selector=["conversation", "id"]),
]
)

# Create and set active session
session = ResponseSession(node_id="response_node", template=template, index=0)
rsc.active_session = session

# Execute try_flush
events = rsc.try_flush()

# Should have all segments processed
assert len(events) == 6

# Check text segments
assert events[0].chunk == "User: "
assert events[0].node_id == "response_node"

# Check sys selector - should use response node's info
assert events[1].chunk == "user123"
assert events[1].selector == ["sys", "user_id"]
assert events[1].node_id == "response_node"
assert events[1].node_type == NodeType.ANSWER

assert events[2].chunk == ", API: "

# Check env selector - should use response node's info
assert events[3].chunk == "key456"
assert events[3].selector == ["env", "api_key"]
assert events[3].node_id == "response_node"
assert events[3].node_type == NodeType.ANSWER

assert events[4].chunk == ", Conv: "

# Check conversation selector - should use response node's info
assert events[5].chunk == "conv789"
assert events[5].selector == ["conversation", "id"]
assert events[5].node_id == "response_node"
assert events[5].node_type == NodeType.ANSWER

# Session should be complete
assert session.is_complete()
