- Replace direct field access with private attributes and property decorators - Implement deep copy protection for mutable objects (dict, LLMUsage) - Add helper methods: set_output(), get_output(), update_outputs() - Add increment_node_run_steps() and add_tokens() convenience methods - Update loop_node and event_handlers to use new accessor methods - Add comprehensive unit tests for immutability and validation - Ensure backward compatibility with existing property access patternstags/2.0.0-beta.1
| @@ -1,6 +1,7 @@ | |||
| from copy import deepcopy | |||
| from typing import Any | |||
| from pydantic import BaseModel, Field | |||
| from pydantic import BaseModel, PrivateAttr | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| @@ -8,21 +9,132 @@ from .variable_pool import VariablePool | |||
| class GraphRuntimeState(BaseModel): | |||
| variable_pool: VariablePool = Field(..., description="variable pool") | |||
| """variable pool""" | |||
| start_at: float = Field(..., description="start time") | |||
| """start time""" | |||
| total_tokens: int = 0 | |||
| """total tokens""" | |||
| llm_usage: LLMUsage = LLMUsage.empty_usage() | |||
| """llm usage info""" | |||
| # The `outputs` field stores the final output values generated by executing workflows or chatflows. | |||
| # | |||
| # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent | |||
| # after a serialization and deserialization round trip. | |||
| outputs: dict[str, Any] = Field(default_factory=dict) | |||
| node_run_steps: int = 0 | |||
| """node run steps""" | |||
| # Private attributes to prevent direct modification | |||
| _variable_pool: VariablePool = PrivateAttr() | |||
| _start_at: float = PrivateAttr() | |||
| _total_tokens: int = PrivateAttr(default=0) | |||
| _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) | |||
| _outputs: dict[str, Any] = PrivateAttr(default_factory=dict) | |||
| _node_run_steps: int = PrivateAttr(default=0) | |||
| def __init__( | |||
| self, | |||
| variable_pool: VariablePool, | |||
| start_at: float, | |||
| total_tokens: int = 0, | |||
| llm_usage: LLMUsage | None = None, | |||
| outputs: dict[str, Any] | None = None, | |||
| node_run_steps: int = 0, | |||
| **kwargs, | |||
| ): | |||
| """Initialize the GraphRuntimeState with validation.""" | |||
| super().__init__(**kwargs) | |||
| # Initialize private attributes with validation | |||
| self._variable_pool = variable_pool | |||
| self._start_at = start_at | |||
| if total_tokens < 0: | |||
| raise ValueError("total_tokens must be non-negative") | |||
| self._total_tokens = total_tokens | |||
| if llm_usage is None: | |||
| llm_usage = LLMUsage.empty_usage() | |||
| self._llm_usage = llm_usage | |||
| if outputs is None: | |||
| outputs = {} | |||
| self._outputs = deepcopy(outputs) | |||
| if node_run_steps < 0: | |||
| raise ValueError("node_run_steps must be non-negative") | |||
| self._node_run_steps = node_run_steps | |||
| @property | |||
| def variable_pool(self) -> VariablePool: | |||
| """Get the variable pool.""" | |||
| return self._variable_pool | |||
| @variable_pool.setter | |||
| def variable_pool(self, value: VariablePool) -> None: | |||
| """Set the variable pool.""" | |||
| self._variable_pool = value | |||
| @property | |||
| def start_at(self) -> float: | |||
| """Get the start time.""" | |||
| return self._start_at | |||
| @start_at.setter | |||
| def start_at(self, value: float) -> None: | |||
| """Set the start time.""" | |||
| self._start_at = value | |||
| @property | |||
| def total_tokens(self) -> int: | |||
| """Get the total tokens count.""" | |||
| return self._total_tokens | |||
| @total_tokens.setter | |||
| def total_tokens(self, value: int): | |||
| """Set the total tokens count.""" | |||
| if value < 0: | |||
| raise ValueError("total_tokens must be non-negative") | |||
| self._total_tokens = value | |||
| @property | |||
| def llm_usage(self) -> LLMUsage: | |||
| """Get the LLM usage info.""" | |||
| # Return a copy to prevent external modification | |||
| return self._llm_usage.model_copy() | |||
| @llm_usage.setter | |||
| def llm_usage(self, value: LLMUsage): | |||
| """Set the LLM usage info.""" | |||
| self._llm_usage = value.model_copy() | |||
| @property | |||
| def outputs(self) -> dict[str, Any]: | |||
| """Get a copy of the outputs dictionary.""" | |||
| return deepcopy(self._outputs) | |||
| @outputs.setter | |||
| def outputs(self, value: dict[str, Any]) -> None: | |||
| """Set the outputs dictionary.""" | |||
| self._outputs = deepcopy(value) | |||
| def set_output(self, key: str, value: Any) -> None: | |||
| """Set a single output value.""" | |||
| self._outputs[key] = deepcopy(value) | |||
| def get_output(self, key: str, default: Any = None) -> Any: | |||
| """Get a single output value.""" | |||
| return deepcopy(self._outputs.get(key, default)) | |||
| def update_outputs(self, updates: dict[str, Any]) -> None: | |||
| """Update multiple output values.""" | |||
| for key, value in updates.items(): | |||
| self._outputs[key] = deepcopy(value) | |||
| @property | |||
| def node_run_steps(self) -> int: | |||
| """Get the node run steps count.""" | |||
| return self._node_run_steps | |||
| @node_run_steps.setter | |||
| def node_run_steps(self, value: int) -> None: | |||
| """Set the node run steps count.""" | |||
| if value < 0: | |||
| raise ValueError("node_run_steps must be non-negative") | |||
| self._node_run_steps = value | |||
| def increment_node_run_steps(self) -> None: | |||
| """Increment the node run steps by 1.""" | |||
| self._node_run_steps += 1 | |||
| def add_tokens(self, tokens: int) -> None: | |||
| """Add tokens to the total count.""" | |||
| if tokens < 0: | |||
| raise ValueError("tokens must be non-negative") | |||
| self._total_tokens += tokens | |||
| @@ -267,10 +267,10 @@ class EventHandler: | |||
| # in runtime state, rather than allowing nodes to directly access runtime state. | |||
| for key, value in event.node_run_result.outputs.items(): | |||
| if key == "answer": | |||
| existing = self._graph_runtime_state.outputs.get("answer", "") | |||
| existing = self._graph_runtime_state.get_output("answer", "") | |||
| if existing: | |||
| self._graph_runtime_state.outputs["answer"] = f"{existing}{value}" | |||
| self._graph_runtime_state.set_output("answer", f"{existing}{value}") | |||
| else: | |||
| self._graph_runtime_state.outputs["answer"] = value | |||
| self._graph_runtime_state.set_output("answer", value) | |||
| else: | |||
| self._graph_runtime_state.outputs[key] = value | |||
| self._graph_runtime_state.set_output(key, value) | |||
| @@ -147,14 +147,14 @@ class LoopNode(Node): | |||
| for key, value in graph_engine.graph_runtime_state.outputs.items(): | |||
| if key == "answer": | |||
| # Concatenate answer outputs with newline | |||
| existing_answer = self.graph_runtime_state.outputs.get("answer", "") | |||
| existing_answer = self.graph_runtime_state.get_output("answer", "") | |||
| if existing_answer: | |||
| self.graph_runtime_state.outputs["answer"] = f"{existing_answer}{value}" | |||
| self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") | |||
| else: | |||
| self.graph_runtime_state.outputs["answer"] = value | |||
| self.graph_runtime_state.set_output("answer", value) | |||
| else: | |||
| # For other outputs, just update | |||
| self.graph_runtime_state.outputs[key] = value | |||
| self.graph_runtime_state.set_output(key, value) | |||
| # Update the total tokens from this iteration | |||
| cost_tokens += graph_engine.graph_runtime_state.total_tokens | |||
| @@ -0,0 +1,114 @@ | |||
| from time import time | |||
| import pytest | |||
| from core.workflow.entities.graph_runtime_state import GraphRuntimeState | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| class TestGraphRuntimeState: | |||
| def test_property_getters_and_setters(self): | |||
| # FIXME(-LAN-): Mock VariablePool if needed | |||
| variable_pool = VariablePool() | |||
| start_time = time() | |||
| state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time) | |||
| # Test variable_pool property | |||
| assert state.variable_pool == variable_pool | |||
| new_pool = VariablePool() | |||
| state.variable_pool = new_pool | |||
| assert state.variable_pool == new_pool | |||
| # Test start_at property | |||
| assert state.start_at == start_time | |||
| new_time = time() + 100 | |||
| state.start_at = new_time | |||
| assert state.start_at == new_time | |||
| # Test total_tokens property | |||
| assert state.total_tokens == 0 | |||
| state.total_tokens = 100 | |||
| assert state.total_tokens == 100 | |||
| # Test node_run_steps property | |||
| assert state.node_run_steps == 0 | |||
| state.node_run_steps = 5 | |||
| assert state.node_run_steps == 5 | |||
| def test_outputs_immutability(self): | |||
| variable_pool = VariablePool() | |||
| state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) | |||
| # Test that getting outputs returns a copy | |||
| outputs1 = state.outputs | |||
| outputs2 = state.outputs | |||
| assert outputs1 == outputs2 | |||
| assert outputs1 is not outputs2 # Different objects | |||
| # Test that modifying retrieved outputs doesn't affect internal state | |||
| outputs = state.outputs | |||
| outputs["test"] = "value" | |||
| assert "test" not in state.outputs | |||
| # Test set_output method | |||
| state.set_output("key1", "value1") | |||
| assert state.get_output("key1") == "value1" | |||
| # Test update_outputs method | |||
| state.update_outputs({"key2": "value2", "key3": "value3"}) | |||
| assert state.get_output("key2") == "value2" | |||
| assert state.get_output("key3") == "value3" | |||
| def test_llm_usage_immutability(self): | |||
| variable_pool = VariablePool() | |||
| state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) | |||
| # Test that getting llm_usage returns a copy | |||
| usage1 = state.llm_usage | |||
| usage2 = state.llm_usage | |||
| assert usage1 is not usage2 # Different objects | |||
| def test_type_validation(self): | |||
| variable_pool = VariablePool() | |||
| state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) | |||
| # Test total_tokens validation | |||
| with pytest.raises(ValueError): | |||
| state.total_tokens = -1 | |||
| # Test node_run_steps validation | |||
| with pytest.raises(ValueError): | |||
| state.node_run_steps = -1 | |||
| def test_helper_methods(self): | |||
| variable_pool = VariablePool() | |||
| state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) | |||
| # Test increment_node_run_steps | |||
| initial_steps = state.node_run_steps | |||
| state.increment_node_run_steps() | |||
| assert state.node_run_steps == initial_steps + 1 | |||
| # Test add_tokens | |||
| initial_tokens = state.total_tokens | |||
| state.add_tokens(50) | |||
| assert state.total_tokens == initial_tokens + 50 | |||
| # Test add_tokens validation | |||
| with pytest.raises(ValueError): | |||
| state.add_tokens(-1) | |||
| def test_deep_copy_for_nested_objects(self): | |||
| variable_pool = VariablePool() | |||
| state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) | |||
| # Test deep copy for nested dict | |||
| nested_data = {"level1": {"level2": {"value": "test"}}} | |||
| state.set_output("nested", nested_data) | |||
| retrieved = state.get_output("nested") | |||
| retrieved["level1"]["level2"]["value"] = "modified" | |||
| # Original should remain unchanged | |||
| assert state.get_output("nested")["level1"]["level2"]["value"] == "test" | |||