
refactor(graph_engine): inline output_registry into response_coordinator

Signed-off-by: -LAN- <laipz8200@outlook.com>
Commit a5cb9d2b73 · tags/2.0.0-beta.1 · -LAN- · 2 months ago

api/.importlinter (+0 −1)

@@ -37,7 +37,6 @@ type = layers
 layers =
     graph_engine
     response_coordinator
-    output_registry
 containers =
     core.workflow.graph_engine
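
After this change, the layers contract reads as follows (resulting file content around line 37, reconstructed from the hunk above). Import-linter layers are ordered top-down, so graph_engine may import response_coordinator but not the reverse; the output_registry layer simply no longer exists:

type = layers
layers =
    graph_engine
    response_coordinator
containers =
    core.workflow.graph_engine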


api/core/workflow/graph_engine/graph_engine.py (+3 −3)

@@ -35,7 +35,6 @@ from .event_management import EventCollector, EventEmitter, EventHandlerRegistry
 from .graph_traversal import BranchHandler, EdgeProcessor, NodeReadinessChecker, SkipPropagator
 from .layers.base import Layer
 from .orchestration import Dispatcher, ExecutionCoordinator
-from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
 from .state_management import UnifiedStateManager
@@ -122,8 +121,9 @@ class GraphEngine:
         self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
 
         # Response coordination
-        self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
-        self.response_coordinator = ResponseStreamCoordinator(registry=self.output_registry, graph=self.graph)
+        self.response_coordinator = ResponseStreamCoordinator(
+            variable_pool=self.graph_runtime_state.variable_pool, graph=self.graph
+        )
 
         # Event management
         self.event_collector = EventCollector()
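
In caller terms, the wiring collapses from two objects into one. A minimal before/after sketch using names from the hunks above (graph_runtime_state and graph stand in for the engine attributes; an illustration, not code from the commit):

# Before: a registry wrapped the variable pool, and the coordinator took the registry.
#   output_registry = OutputRegistry(graph_runtime_state.variable_pool)
#   response_coordinator = ResponseStreamCoordinator(registry=output_registry, graph=graph)

# After: the coordinator takes the pool directly and buffers streams itself.
response_coordinator = ResponseStreamCoordinator(
    variable_pool=graph_runtime_state.variable_pool, graph=graph
)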

api/core/workflow/graph_engine/output_registry/__init__.py (+0 −10, file deleted)

@@ -1,10 +0,0 @@
-"""
-OutputRegistry - Thread-safe storage for node outputs (streams and scalars)
-
-This component provides thread-safe storage and retrieval of node outputs,
-supporting both scalar values and streaming chunks with proper state management.
-"""
-
-from .registry import OutputRegistry
-
-__all__ = ["OutputRegistry"]

api/core/workflow/graph_engine/output_registry/registry.py (+0 −148, file deleted)

@@ -1,148 +0,0 @@
-"""
-Main OutputRegistry implementation.
-
-This module contains the public OutputRegistry class that provides
-thread-safe storage for node outputs.
-"""
-
-from collections.abc import Sequence
-from threading import RLock
-from typing import TYPE_CHECKING, Any, Union, final
-
-from core.variables import Segment
-from core.workflow.entities.variable_pool import VariablePool
-
-from .stream import Stream
-
-if TYPE_CHECKING:
-    from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-@final
-class OutputRegistry:
-    """
-    Thread-safe registry for storing and retrieving node outputs.
-
-    Supports both scalar values and streaming chunks with proper state management.
-    All operations are thread-safe using internal locking.
-    """
-
-    def __init__(self, variable_pool: VariablePool) -> None:
-        """Initialize empty registry with thread-safe storage."""
-        self._lock = RLock()
-        self._scalars = variable_pool
-        self._streams: dict[tuple[str, ...], Stream] = {}
-
-    def _selector_to_key(self, selector: Sequence[str]) -> tuple[str, ...]:
-        """Convert selector list to tuple key for internal storage."""
-        return tuple(selector)
-
-    def set_scalar(
-        self, selector: Sequence[str], value: Union[str, int, float, bool, dict[str, Any], list[Any]]
-    ) -> None:
-        """
-        Set a scalar value for the given selector.
-
-        Args:
-            selector: List of strings identifying the output location
-            value: The scalar value to store
-        """
-        with self._lock:
-            self._scalars.add(selector, value)
-
-    def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
-        """
-        Get a scalar value for the given selector.
-
-        Args:
-            selector: List of strings identifying the output location
-
-        Returns:
-            The stored Variable object, or None if not found
-        """
-        with self._lock:
-            return self._scalars.get(selector)
-
-    def append_chunk(self, selector: Sequence[str], event: "NodeRunStreamChunkEvent") -> None:
-        """
-        Append a NodeRunStreamChunkEvent to the stream for the given selector.
-
-        Args:
-            selector: List of strings identifying the stream location
-            event: The NodeRunStreamChunkEvent to append
-
-        Raises:
-            ValueError: If the stream is already closed
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                self._streams[key] = Stream()
-
-            try:
-                self._streams[key].append(event)
-            except ValueError:
-                raise ValueError(f"Stream {'.'.join(selector)} is already closed")
-
-    def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
-        """
-        Pop the next unread NodeRunStreamChunkEvent from the stream.
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            The next event, or None if no unread events available
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return None
-
-            return self._streams[key].pop_next()
-
-    def has_unread(self, selector: Sequence[str]) -> bool:
-        """
-        Check if the stream has unread events.
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            True if there are unread events, False otherwise
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return False
-
-            return self._streams[key].has_unread()
-
-    def close_stream(self, selector: Sequence[str]) -> None:
-        """
-        Mark a stream as closed (no more chunks can be appended).
-
-        Args:
-            selector: List of strings identifying the stream location
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                self._streams[key] = Stream()
-            self._streams[key].close()
-
-    def stream_closed(self, selector: Sequence[str]) -> bool:
-        """
-        Check if a stream is closed.
-
-        Args:
-            selector: List of strings identifying the stream location
-
-        Returns:
-            True if the stream is closed, False otherwise
-        """
-        key = self._selector_to_key(selector)
-        with self._lock:
-            if key not in self._streams:
-                return False
-            return self._streams[key].is_closed
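
For reference, the registry's stream API was a read-once cursor over buffered chunks. A short usage sketch of the class as it existed before this commit (event fields follow the deleted unit tests further down this page):

from uuid import uuid4

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeType
from core.workflow.graph_engine.output_registry import OutputRegistry  # pre-commit import path
from core.workflow.graph_events import NodeRunStreamChunkEvent

registry = OutputRegistry(VariablePool())
event = NodeRunStreamChunkEvent(
    id=str(uuid4()), node_id="node1", node_type=NodeType.LLM,
    selector=["node1", "text"], chunk="hello", is_final=False,
)
registry.append_chunk(["node1", "text"], event)
assert registry.has_unread(["node1", "text"])
assert registry.pop_chunk(["node1", "text"]) is event  # advances the read cursor
assert registry.pop_chunk(["node1", "text"]) is None   # read-once: nothing left
registry.close_stream(["node1", "text"])
assert registry.stream_closed(["node1", "text"])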

api/core/workflow/graph_engine/output_registry/stream.py (+0 −70, file deleted)

@@ -1,70 +0,0 @@
-"""
-Internal stream implementation for OutputRegistry.
-
-This module contains the private Stream class used internally by OutputRegistry
-to manage streaming data chunks.
-"""
-
-from typing import TYPE_CHECKING, final
-
-if TYPE_CHECKING:
-    from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-@final
-class Stream:
-    """
-    A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
-
-    This class encapsulates stream-specific data and operations,
-    including event storage, read position tracking, and closed state.
-
-    Note: This is an internal class not exposed in the public API.
-    """
-
-    def __init__(self) -> None:
-        """Initialize an empty stream."""
-        self.events: list[NodeRunStreamChunkEvent] = []
-        self.read_position: int = 0
-        self.is_closed: bool = False
-
-    def append(self, event: "NodeRunStreamChunkEvent") -> None:
-        """
-        Append a NodeRunStreamChunkEvent to the stream.
-
-        Args:
-            event: The NodeRunStreamChunkEvent to append
-
-        Raises:
-            ValueError: If the stream is already closed
-        """
-        if self.is_closed:
-            raise ValueError("Cannot append to a closed stream")
-        self.events.append(event)
-
-    def pop_next(self) -> "NodeRunStreamChunkEvent | None":
-        """
-        Pop the next unread NodeRunStreamChunkEvent from the stream.
-
-        Returns:
-            The next event, or None if no unread events available
-        """
-        if self.read_position >= len(self.events):
-            return None
-
-        event = self.events[self.read_position]
-        self.read_position += 1
-        return event
-
-    def has_unread(self) -> bool:
-        """
-        Check if the stream has unread events.
-
-        Returns:
-            True if there are unread events, False otherwise
-        """
-        return self.read_position < len(self.events)
-
-    def close(self) -> None:
-        """Mark the stream as closed (no more chunks can be appended)."""
-        self.is_closed = True
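
After this commit, the per-Stream state above is flattened into three selector-keyed structures on the coordinator (see the new methods in coordinator.py below). A standalone sketch of the equivalence, with chunks simplified to strings:

from collections.abc import Sequence

# Stream.events        -> buffers[key]            (list of chunks)
# Stream.read_position -> positions[key]          (int cursor)
# Stream.is_closed     -> key in closed           (set membership)
# where key = tuple(selector)
buffers: dict[tuple[str, ...], list[str]] = {}
positions: dict[tuple[str, ...], int] = {}
closed: set[tuple[str, ...]] = set()

def append(selector: Sequence[str], chunk: str) -> None:
    key = tuple(selector)
    if key in closed:
        raise ValueError("stream closed")
    buffers.setdefault(key, [])
    positions.setdefault(key, 0)
    buffers[key].append(chunk)

def pop(selector: Sequence[str]) -> str | None:
    key = tuple(selector)
    pos = positions.get(key, 0)
    if key not in buffers or pos >= len(buffers[key]):
        return None
    positions[key] = pos + 1
    return buffers[key][pos]

append(["node1", "text"], "hello")
assert pop(["node1", "text"]) == "hello"
assert pop(["node1", "text"]) is None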

api/core/workflow/graph_engine/response_coordinator/coordinator.py (+107 −12)

@@ -12,12 +12,12 @@ from threading import RLock
 from typing import TypeAlias, final
 from uuid import uuid4
 
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import NodeExecutionType, NodeState
 from core.workflow.graph import Graph
 from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
 from core.workflow.nodes.base.template import TextSegment, VariableSegment
 
-from ..output_registry import OutputRegistry
 from .path import Path
 from .session import ResponseSession
 
@@ -36,20 +36,25 @@ class ResponseStreamCoordinator:
     Ensures ordered streaming of responses based on upstream node outputs and constants.
     """
 
-    def __init__(self, registry: OutputRegistry, graph: "Graph") -> None:
+    def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
         """
-        Initialize coordinator with output registry.
+        Initialize coordinator with variable pool.
 
         Args:
-            registry: OutputRegistry instance for accessing node outputs
+            variable_pool: VariablePool instance for accessing node variables
             graph: Graph instance for looking up node information
         """
-        self.registry = registry
+        self.variable_pool = variable_pool
         self.graph = graph
         self.active_session: ResponseSession | None = None
         self.waiting_sessions: deque[ResponseSession] = deque()
         self.lock = RLock()
 
+        # Internal stream management (replacing OutputRegistry)
+        self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
+        self._stream_positions: dict[tuple[str, ...], int] = {}
+        self._closed_streams: set[tuple[str, ...]] = set()
+
         # Track response nodes
         self._response_nodes: set[NodeID] = set()

@@ -256,15 +261,15 @@
     ) -> Sequence[NodeRunStreamChunkEvent]:
         with self.lock:
             if isinstance(event, NodeRunStreamChunkEvent):
-                self.registry.append_chunk(event.selector, event)
+                self._append_stream_chunk(event.selector, event)
                 if event.is_final:
-                    self.registry.close_stream(event.selector)
+                    self._close_stream(event.selector)
                 return self.try_flush()
             else:
                 # Skip cause we share the same variable pool.
                 #
                 # for variable_name, variable_value in event.node_run_result.outputs.items():
-                #     self.registry.set_scalar((event.node_id, variable_name), variable_value)
+                #     self.variable_pool.add((event.node_id, variable_name), variable_value)
                 return self.try_flush()
         return []

@@ -327,8 +332,8 @@
             execution_id = self._get_or_create_execution_id(output_node_id)
 
             # Stream all available chunks
-            while self.registry.has_unread(segment.selector):
-                if event := self.registry.pop_chunk(segment.selector):
+            while self._has_unread_stream(segment.selector):
+                if event := self._pop_stream_chunk(segment.selector):
                     # For special selectors, we need to update the event to use
                     # the active response node's information
                     if self.active_session and source_selector_prefix not in self.graph.nodes:
@@ -349,12 +354,12 @@
                     events.append(event)
 
                     # Check if this is the last chunk by looking ahead
-                    stream_closed = self.registry.stream_closed(segment.selector)
+                    stream_closed = self._is_stream_closed(segment.selector)
                     # Check if stream is closed to determine if segment is complete
                     if stream_closed:
                         is_complete = True
 
-        elif value := self.registry.get_scalar(segment.selector):
+        elif value := self.variable_pool.get(segment.selector):
             # Process scalar value
             is_last_segment = bool(
                 self.active_session and self.active_session.index == len(self.active_session.template.segments) - 1
@@ -464,3 +469,93 @@
         events = self.try_flush()
 
         return events
+
+    # ============= Internal Stream Management Methods =============
+
+    def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
+        """
+        Append a stream chunk to the internal buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+            event: The NodeRunStreamChunkEvent to append
+
+        Raises:
+            ValueError: If the stream is already closed
+        """
+        key = tuple(selector)
+
+        if key in self._closed_streams:
+            raise ValueError(f"Stream {'.'.join(selector)} is already closed")
+
+        if key not in self._stream_buffers:
+            self._stream_buffers[key] = []
+            self._stream_positions[key] = 0
+
+        self._stream_buffers[key].append(event)
+
+    def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
+        """
+        Pop the next unread stream chunk from the buffer.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            The next event, or None if no unread events available
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return None
+
+        position = self._stream_positions.get(key, 0)
+        buffer = self._stream_buffers[key]
+
+        if position >= len(buffer):
+            return None
+
+        event = buffer[position]
+        self._stream_positions[key] = position + 1
+        return event
+
+    def _has_unread_stream(self, selector: Sequence[str]) -> bool:
+        """
+        Check if the stream has unread events.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if there are unread events, False otherwise
+        """
+        key = tuple(selector)
+
+        if key not in self._stream_buffers:
+            return False
+
+        position = self._stream_positions.get(key, 0)
+        return position < len(self._stream_buffers[key])
+
+    def _close_stream(self, selector: Sequence[str]) -> None:
+        """
+        Mark a stream as closed (no more chunks can be appended).
+
+        Args:
+            selector: List of strings identifying the stream location
+        """
+        key = tuple(selector)
+        self._closed_streams.add(key)
+
+    def _is_stream_closed(self, selector: Sequence[str]) -> bool:
+        """
+        Check if a stream is closed.
+
+        Args:
+            selector: List of strings identifying the stream location
+
+        Returns:
+            True if the stream is closed, False otherwise
+        """
+        key = tuple(selector)
+        return key in self._closed_streams
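
One behavioral nuance of the inlining: OutputRegistry guarded every operation with its own RLock, while these private helpers are unsynchronized and rely on the caller already holding self.lock, as the intake hunk above does. Any new call site would need the same discipline; the test sketch after the deleted test_output_registry.py below shows the pattern.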

api/tests/unit_tests/core/workflow/graph_engine/test_output_registry.py (+0 −135, file deleted)

@@ -1,135 +0,0 @@
-from uuid import uuid4
-
-import pytest
-
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import NodeType
-from core.workflow.graph_engine.output_registry import OutputRegistry
-from core.workflow.graph_events import NodeRunStreamChunkEvent
-
-
-class TestOutputRegistry:
-    def test_scalar_operations(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Test setting and getting scalar
-        registry.set_scalar(["node1", "output"], "test_value")
-
-        segment = registry.get_scalar(["node1", "output"])
-        assert segment
-        assert segment.text == "test_value"
-
-        # Test getting non-existent scalar
-        assert registry.get_scalar(["non_existent"]) is None
-
-    def test_stream_operations(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Create test events
-        event1 = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk1",
-            is_final=False,
-        )
-        event2 = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk2",
-            is_final=True,
-        )
-
-        # Test appending events
-        registry.append_chunk(["node1", "stream"], event1)
-        registry.append_chunk(["node1", "stream"], event2)
-
-        # Test has_unread
-        assert registry.has_unread(["node1", "stream"]) is True
-
-        # Test popping events
-        popped_event1 = registry.pop_chunk(["node1", "stream"])
-        assert popped_event1 == event1
-        assert popped_event1.chunk == "chunk1"
-
-        popped_event2 = registry.pop_chunk(["node1", "stream"])
-        assert popped_event2 == event2
-        assert popped_event2.chunk == "chunk2"
-
-        assert registry.pop_chunk(["node1", "stream"]) is None
-
-        # Test has_unread after popping all
-        assert registry.has_unread(["node1", "stream"]) is False
-
-    def test_stream_closing(self):
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Test stream is not closed initially
-        assert registry.stream_closed(["node1", "stream"]) is False
-
-        # Test closing stream
-        registry.close_stream(["node1", "stream"])
-        assert registry.stream_closed(["node1", "stream"]) is True
-
-        # Test appending to closed stream raises error
-        event = NodeRunStreamChunkEvent(
-            id=str(uuid4()),
-            node_id="node1",
-            node_type=NodeType.LLM,
-            selector=["node1", "stream"],
-            chunk="chunk",
-            is_final=False,
-        )
-        with pytest.raises(ValueError, match="Stream node1.stream is already closed"):
-            registry.append_chunk(["node1", "stream"], event)
-
-    def test_thread_safety(self):
-        import threading
-
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-        results = []
-
-        def append_chunks(thread_id: int):
-            for i in range(100):
-                event = NodeRunStreamChunkEvent(
-                    id=str(uuid4()),
-                    node_id="test_node",
-                    node_type=NodeType.LLM,
-                    selector=["stream"],
-                    chunk=f"thread{thread_id}_chunk{i}",
-                    is_final=False,
-                )
-                registry.append_chunk(["stream"], event)
-
-        # Start multiple threads
-        threads = []
-        for i in range(5):
-            thread = threading.Thread(target=append_chunks, args=(i,))
-            threads.append(thread)
-            thread.start()
-
-        # Wait for threads
-        for thread in threads:
-            thread.join()
-
-        # Verify all events are present
-        events = []
-        while True:
-            event = registry.pop_chunk(["stream"])
-            if event is None:
-                break
-            events.append(event)
-
-        assert len(events) == 500  # 5 threads * 100 events each
-        # Verify the events have the expected chunk content format
-        chunk_texts = [e.chunk for e in events]
-        for i in range(5):
-            for j in range(100):
-                assert f"thread{i}_chunk{j}" in chunk_texts
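
The commit removes this file without a direct replacement. A hedged pytest sketch of how the read-once and closed-stream cases could be re-covered against the coordinator's new private helpers (constructor per this commit; Mock(spec=Graph) mirrors the deleted coordinator tests below):

import pytest
from unittest.mock import Mock
from uuid import uuid4

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import NodeRunStreamChunkEvent


def _chunk(text: str, is_final: bool = False) -> NodeRunStreamChunkEvent:
    # Event fields follow the deleted tests above.
    return NodeRunStreamChunkEvent(
        id=str(uuid4()), node_id="node1", node_type=NodeType.LLM,
        selector=["node1", "stream"], chunk=text, is_final=is_final,
    )


def test_closed_stream_rejects_appends():
    rsc = ResponseStreamCoordinator(variable_pool=VariablePool(), graph=Mock(spec=Graph))
    with rsc.lock:  # the helpers no longer lock internally; hold the coordinator's lock
        rsc._append_stream_chunk(["node1", "stream"], _chunk("a"))
        rsc._close_stream(["node1", "stream"])
    with rsc.lock, pytest.raises(ValueError, match="already closed"):
        rsc._append_stream_chunk(["node1", "stream"], _chunk("b"))
    # Read-once cursor: the buffered chunk drains exactly once.
    assert rsc._pop_stream_chunk(["node1", "stream"]).chunk == "a"
    assert rsc._pop_stream_chunk(["node1", "stream"]) is None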

api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py (+0 −347, file deleted)

@@ -1,347 +0,0 @@
-"""Test cases for ResponseStreamCoordinator."""
-
-from unittest.mock import Mock
-
-from core.variables import StringSegment
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import NodeState, NodeType
-from core.workflow.graph import Graph
-from core.workflow.graph_engine.output_registry import OutputRegistry
-from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
-from core.workflow.graph_engine.response_coordinator.session import ResponseSession
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base.node import Node
-from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
-
-
-class TestResponseStreamCoordinator:
-    """Test cases for ResponseStreamCoordinator."""
-
-    def test_skip_variable_segment_from_skipped_node(self):
-        """Test that VariableSegments from skipped nodes are properly skipped during try_flush."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes
-        skipped_node = Mock(spec=Node)
-        skipped_node.id = "skipped_node"
-        skipped_node.state = NodeState.SKIPPED
-        skipped_node.node_type = NodeType.LLM
-
-        active_node = Mock(spec=Node)
-        active_node.id = "active_node"
-        active_node.state = NodeState.TAKEN
-        active_node.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"skipped_node": skipped_node, "active_node": active_node, "response_node": response_node}
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add some test data to registry for the active node
-        registry.set_scalar(("active_node", "output"), StringSegment(value="Active output"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with segments from both skipped and active nodes
-        template = Template(
-            segments=[
-                VariableSegment(selector=["skipped_node", "output"]),
-                TextSegment(text=" - "),
-                VariableSegment(selector=["active_node", "output"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Verify that:
-        # 1. The skipped node's variable segment was skipped (index advanced)
-        # 2. The text segment was processed
-        # 3. The active node's variable segment was processed
-        assert len(events) == 2  # TextSegment + VariableSegment from active_node
-
-        # Check that the first event is the text segment
-        assert events[0].chunk == " - "
-
-        # Check that the second event is from the active node
-        assert events[1].chunk == "Active output"
-        assert events[1].selector == ["active_node", "output"]
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_process_variable_segment_from_non_skipped_node(self):
-        """Test that VariableSegments from non-skipped nodes are processed normally."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes
-        active_node1 = Mock(spec=Node)
-        active_node1.id = "node1"
-        active_node1.state = NodeState.TAKEN
-        active_node1.node_type = NodeType.LLM
-
-        active_node2 = Mock(spec=Node)
-        active_node2.id = "node2"
-        active_node2.state = NodeState.TAKEN
-        active_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"node1": active_node1, "node2": active_node2, "response_node": response_node}
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add test data to registry
-        registry.set_scalar(("node1", "output"), StringSegment(value="Output 1"))
-        registry.set_scalar(("node2", "output"), StringSegment(value="Output 2"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with segments from active nodes
-        template = Template(
-            segments=[
-                VariableSegment(selector=["node1", "output"]),
-                TextSegment(text=" | "),
-                VariableSegment(selector=["node2", "output"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Verify all segments were processed
-        assert len(events) == 3
-
-        # Check events in order
-        assert events[0].chunk == "Output 1"
-        assert events[0].selector == ["node1", "output"]
-
-        assert events[1].chunk == " | "
-
-        assert events[2].chunk == "Output 2"
-        assert events[2].selector == ["node2", "output"]
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_mixed_skipped_and_active_nodes(self):
-        """Test processing with a mix of skipped and active nodes."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create mock nodes with various states
-        skipped_node1 = Mock(spec=Node)
-        skipped_node1.id = "skip1"
-        skipped_node1.state = NodeState.SKIPPED
-        skipped_node1.node_type = NodeType.LLM
-
-        active_node = Mock(spec=Node)
-        active_node.id = "active"
-        active_node.state = NodeState.TAKEN
-        active_node.node_type = NodeType.LLM
-
-        skipped_node2 = Mock(spec=Node)
-        skipped_node2.id = "skip2"
-        skipped_node2.state = NodeState.SKIPPED
-        skipped_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {
-            "skip1": skipped_node1,
-            "active": active_node,
-            "skip2": skipped_node2,
-            "response_node": response_node,
-        }
-
-        # Create output registry with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Add data only for active node
-        registry.set_scalar(("active", "result"), StringSegment(value="Active Result"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with mixed segments
-        template = Template(
-            segments=[
-                TextSegment(text="Start: "),
-                VariableSegment(selector=["skip1", "output"]),
-                VariableSegment(selector=["active", "result"]),
-                VariableSegment(selector=["skip2", "output"]),
-                TextSegment(text=" :End"),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should have: "Start: ", "Active Result", " :End"
-        assert len(events) == 3
-
-        assert events[0].chunk == "Start: "
-        assert events[1].chunk == "Active Result"
-        assert events[1].selector == ["active", "result"]
-        assert events[2].chunk == " :End"
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_all_variable_segments_skipped(self):
-        """Test when all VariableSegments are from skipped nodes."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create all skipped nodes
-        skipped_node1 = Mock(spec=Node)
-        skipped_node1.id = "skip1"
-        skipped_node1.state = NodeState.SKIPPED
-        skipped_node1.node_type = NodeType.LLM
-
-        skipped_node2 = Mock(spec=Node)
-        skipped_node2.id = "skip2"
-        skipped_node2.state = NodeState.SKIPPED
-        skipped_node2.node_type = NodeType.LLM
-
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary
-        graph.nodes = {"skip1": skipped_node1, "skip2": skipped_node2, "response_node": response_node}
-
-        # Create output registry (empty since nodes are skipped) with variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with only skipped segments
-        template = Template(
-            segments=[
-                VariableSegment(selector=["skip1", "output"]),
-                VariableSegment(selector=["skip2", "output"]),
-                TextSegment(text="Final text"),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should only have the final text segment
-        assert len(events) == 1
-        assert events[0].chunk == "Final text"
-
-        # Session should be complete
-        assert session.is_complete()
-
-    def test_special_prefix_selectors(self):
-        """Test that special prefix selectors (sys, env, conversation) are handled correctly."""
-        # Create mock graph
-        graph = Mock(spec=Graph)
-
-        # Create response node
-        response_node = Mock(spec=AnswerNode)
-        response_node.id = "response_node"
-        response_node.node_type = NodeType.ANSWER
-
-        # Set up graph nodes dictionary (no sys, env, conversation nodes)
-        graph.nodes = {"response_node": response_node}
-
-        # Create output registry with special selector data and variable pool
-        variable_pool = VariablePool()
-        registry = OutputRegistry(variable_pool)
-        registry.set_scalar(("sys", "user_id"), StringSegment(value="user123"))
-        registry.set_scalar(("env", "api_key"), StringSegment(value="key456"))
-        registry.set_scalar(("conversation", "id"), StringSegment(value="conv789"))
-
-        # Create RSC instance
-        rsc = ResponseStreamCoordinator(registry=registry, graph=graph)
-
-        # Create template with special selectors
-        template = Template(
-            segments=[
-                TextSegment(text="User: "),
-                VariableSegment(selector=["sys", "user_id"]),
-                TextSegment(text=", API: "),
-                VariableSegment(selector=["env", "api_key"]),
-                TextSegment(text=", Conv: "),
-                VariableSegment(selector=["conversation", "id"]),
-            ]
-        )
-
-        # Create and set active session
-        session = ResponseSession(node_id="response_node", template=template, index=0)
-        rsc.active_session = session
-
-        # Execute try_flush
-        events = rsc.try_flush()
-
-        # Should have all segments processed
-        assert len(events) == 6
-
-        # Check text segments
-        assert events[0].chunk == "User: "
-        assert events[0].node_id == "response_node"
-
-        # Check sys selector - should use response node's info
-        assert events[1].chunk == "user123"
-        assert events[1].selector == ["sys", "user_id"]
-        assert events[1].node_id == "response_node"
-        assert events[1].node_type == NodeType.ANSWER
-
-        assert events[2].chunk == ", API: "
-
-        # Check env selector - should use response node's info
-        assert events[3].chunk == "key456"
-        assert events[3].selector == ["env", "api_key"]
-        assert events[3].node_id == "response_node"
-        assert events[3].node_type == NodeType.ANSWER
-
-        assert events[4].chunk == ", Conv: "
-
-        # Check conversation selector - should use response node's info
-        assert events[5].chunk == "conv789"
-        assert events[5].selector == ["conversation", "id"]
-        assert events[5].node_id == "response_node"
-        assert events[5].node_type == NodeType.ANSWER
-
-        # Session should be complete
-        assert session.is_complete()
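
These coordinator tests also disappear in this commit. Under the new constructor, scalar fixtures would be seeded straight into the shared variable pool rather than through a registry; a hedged migration sketch of the setup portion (graph is the same Mock graph as in the deleted tests; VariablePool.add takes the selector/value pair that registry.set_scalar used to forward to it):

# Before (deleted tests above):
#   registry = OutputRegistry(variable_pool)
#   registry.set_scalar(("active_node", "output"), StringSegment(value="Active output"))
#   rsc = ResponseStreamCoordinator(registry=registry, graph=graph)

# After (sketch):
variable_pool = VariablePool()
variable_pool.add(("active_node", "output"), StringSegment(value="Active output"))
rsc = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)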
