Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.8.0
| # | # | ||||
| # If the selector length is more than 2, the remaining parts are the keys / indexes paths used | # If the selector length is more than 2, the remaining parts are the keys / indexes paths used | ||||
| # to extract part of the variable value. | # to extract part of the variable value. | ||||
| MIN_SELECTORS_LENGTH = 2 | |||||
| SELECTORS_LENGTH = 2 |
| from core.file import File, FileAttribute, file_manager | from core.file import File, FileAttribute, file_manager | ||||
| from core.variables import Segment, SegmentGroup, Variable | from core.variables import Segment, SegmentGroup, Variable | ||||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||||
| from core.variables.segments import FileSegment, NoneSegment | |||||
| from core.variables.consts import SELECTORS_LENGTH | |||||
| from core.variables.segments import FileSegment, ObjectSegment | |||||
| from core.variables.variables import VariableUnion | from core.variables.variables import VariableUnion | ||||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | ||||
| from core.workflow.system_variable import SystemVariable | from core.workflow.system_variable import SystemVariable | ||||
| # The first element of the selector is the node id, it's the first-level key in the dictionary. | # The first element of the selector is the node id, it's the first-level key in the dictionary. | ||||
| # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the | # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the | ||||
| # elements of the selector except the first one. | # elements of the selector except the first one. | ||||
| variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( | |||||
| variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( | |||||
| description="Variables mapping", | description="Variables mapping", | ||||
| default=defaultdict(dict), | default=defaultdict(dict), | ||||
| ) | ) | ||||
| ) | ) | ||||
| system_variables: SystemVariable = Field( | system_variables: SystemVariable = Field( | ||||
| description="System variables", | description="System variables", | ||||
| default_factory=SystemVariable.empty, | |||||
| ) | ) | ||||
| environment_variables: Sequence[VariableUnion] = Field( | environment_variables: Sequence[VariableUnion] = Field( | ||||
| description="Environment variables.", | description="Environment variables.", | ||||
| def add(self, selector: Sequence[str], value: Any, /) -> None: | def add(self, selector: Sequence[str], value: Any, /) -> None: | ||||
| """ | """ | ||||
| Adds a variable to the variable pool. | |||||
| Add a variable to the variable pool. | |||||
| NOTE: You should not add a non-Segment value to the variable pool | |||||
| even if it is allowed now. | |||||
| This method accepts a selector path and a value, converting the value | |||||
| to a Variable object if necessary before storing it in the pool. | |||||
| Args: | Args: | ||||
| selector (Sequence[str]): The selector for the variable. | |||||
| value (VariableValue): The value of the variable. | |||||
| selector: A two-element sequence containing [node_id, variable_name]. | |||||
| The selector must have exactly 2 elements to be valid. | |||||
| value: The value to store. Can be a Variable, Segment, or any value | |||||
| that can be converted to a Segment (str, int, float, dict, list, File). | |||||
| Raises: | Raises: | ||||
| ValueError: If the selector is invalid. | |||||
| ValueError: If selector length is not exactly 2 elements. | |||||
| Returns: | |||||
| None | |||||
| Note: | |||||
| While non-Segment values are currently accepted and automatically | |||||
| converted, it's recommended to pass Segment or Variable objects directly. | |||||
| """ | """ | ||||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||||
| raise ValueError("Invalid selector") | |||||
| if len(selector) != SELECTORS_LENGTH: | |||||
| raise ValueError( | |||||
| f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " | |||||
| f"got {len(selector)} elements" | |||||
| ) | |||||
| if isinstance(value, Variable): | if isinstance(value, Variable): | ||||
| variable = value | variable = value | ||||
| segment = variable_factory.build_segment(value) | segment = variable_factory.build_segment(value) | ||||
| variable = variable_factory.segment_to_variable(segment=segment, selector=selector) | variable = variable_factory.segment_to_variable(segment=segment, selector=selector) | ||||
| key, hash_key = self._selector_to_keys(selector) | |||||
| node_id, name = self._selector_to_keys(selector) | |||||
| # Based on the definition of `VariableUnion`, | # Based on the definition of `VariableUnion`, | ||||
| # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. | # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. | ||||
| self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) | |||||
| self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) | |||||
| @classmethod | @classmethod | ||||
| def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: | |||||
| return selector[0], hash(tuple(selector[1:])) | |||||
| def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: | |||||
| return selector[0], selector[1] | |||||
| def _has(self, selector: Sequence[str]) -> bool: | def _has(self, selector: Sequence[str]) -> bool: | ||||
| key, hash_key = self._selector_to_keys(selector) | |||||
| if key not in self.variable_dictionary: | |||||
| node_id, name = self._selector_to_keys(selector) | |||||
| if node_id not in self.variable_dictionary: | |||||
| return False | return False | ||||
| if hash_key not in self.variable_dictionary[key]: | |||||
| if name not in self.variable_dictionary[node_id]: | |||||
| return False | return False | ||||
| return True | return True | ||||
| def get(self, selector: Sequence[str], /) -> Segment | None: | def get(self, selector: Sequence[str], /) -> Segment | None: | ||||
| """ | """ | ||||
| Retrieves the value from the variable pool based on the given selector. | |||||
| Retrieve a variable's value from the pool as a Segment. | |||||
| This method supports both simple selectors [node_id, variable_name] and | |||||
| extended selectors that include attribute access for FileSegment and | |||||
| ObjectSegment types. | |||||
| Args: | Args: | ||||
| selector (Sequence[str]): The selector used to identify the variable. | |||||
| selector: A sequence with at least 2 elements: | |||||
| - [node_id, variable_name]: Returns the full segment | |||||
| - [node_id, variable_name, attr, ...]: Returns a nested value | |||||
| from FileSegment (e.g., 'url', 'name') or ObjectSegment | |||||
| Returns: | Returns: | ||||
| Any: The value associated with the given selector. | |||||
| The Segment associated with the selector, or None if not found. | |||||
| Returns None if selector has fewer than 2 elements. | |||||
| Raises: | Raises: | ||||
| ValueError: If the selector is invalid. | |||||
| ValueError: If attempting to access an invalid FileAttribute. | |||||
| """ | """ | ||||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||||
| if len(selector) < SELECTORS_LENGTH: | |||||
| return None | return None | ||||
| key, hash_key = self._selector_to_keys(selector) | |||||
| value: Segment | None = self.variable_dictionary[key].get(hash_key) | |||||
| node_id, name = self._selector_to_keys(selector) | |||||
| segment: Segment | None = self.variable_dictionary[node_id].get(name) | |||||
| if segment is None: | |||||
| return None | |||||
| if len(selector) == 2: | |||||
| return segment | |||||
| if value is None: | |||||
| selector, attr = selector[:-1], selector[-1] | |||||
| if isinstance(segment, FileSegment): | |||||
| attr = selector[2] | |||||
| # Python support `attr in FileAttribute` after 3.12 | # Python support `attr in FileAttribute` after 3.12 | ||||
| if attr not in {item.value for item in FileAttribute}: | if attr not in {item.value for item in FileAttribute}: | ||||
| return None | return None | ||||
| value = self.get(selector) | |||||
| if not isinstance(value, FileSegment | NoneSegment): | |||||
| attr = FileAttribute(attr) | |||||
| attr_value = file_manager.get_attr(file=segment.value, attr=attr) | |||||
| return variable_factory.build_segment(attr_value) | |||||
| # Navigate through nested attributes | |||||
| result: Any = segment | |||||
| for attr in selector[2:]: | |||||
| result = self._extract_value(result) | |||||
| result = self._get_nested_attribute(result, attr) | |||||
| if result is None: | |||||
| return None | return None | ||||
| if isinstance(value, FileSegment): | |||||
| attr = FileAttribute(attr) | |||||
| attr_value = file_manager.get_attr(file=value.value, attr=attr) | |||||
| return variable_factory.build_segment(attr_value) | |||||
| return value | |||||
| return value | |||||
| # Return result as Segment | |||||
| return result if isinstance(result, Segment) else variable_factory.build_segment(result) | |||||
| def _extract_value(self, obj: Any) -> Any: | |||||
| """Extract the actual value from an ObjectSegment.""" | |||||
| return obj.value if isinstance(obj, ObjectSegment) else obj | |||||
| def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any: | |||||
| """Get a nested attribute from a dictionary-like object.""" | |||||
| if not isinstance(obj, dict): | |||||
| return None | |||||
| return obj.get(attr) | |||||
| def remove(self, selector: Sequence[str], /): | def remove(self, selector: Sequence[str], /): | ||||
| """ | """ |
| from core.app.apps.exc import GenerateTaskStoppedError | from core.app.apps.exc import GenerateTaskStoppedError | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult | from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult | ||||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | ||||
| from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager | from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager | ||||
| from core.workflow.graph_engine.entities.event import ( | from core.workflow.graph_engine.entities.event import ( | ||||
| from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor | from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor | ||||
| from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle | ||||
| from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent | ||||
| from core.workflow.utils import variable_utils | |||||
| from libs.flask_utils import preserve_flask_contexts | from libs.flask_utils import preserve_flask_contexts | ||||
| from models.enums import UserFrom | from models.enums import UserFrom | ||||
| from models.workflow import WorkflowType | from models.workflow import WorkflowType | ||||
| route_node_state.status = RouteNodeState.Status.EXCEPTION | route_node_state.status = RouteNodeState.Status.EXCEPTION | ||||
| if run_result.outputs: | if run_result.outputs: | ||||
| for variable_key, variable_value in run_result.outputs.items(): | for variable_key, variable_value in run_result.outputs.items(): | ||||
| # append variables to variable pool recursively | |||||
| self._append_variables_recursively( | |||||
| node_id=node.node_id, | |||||
| variable_key_list=[variable_key], | |||||
| variable_value=variable_value, | |||||
| # Add variables to variable pool | |||||
| self.graph_runtime_state.variable_pool.add( | |||||
| [node.node_id, variable_key], variable_value | |||||
| ) | ) | ||||
| yield NodeRunExceptionEvent( | yield NodeRunExceptionEvent( | ||||
| error=run_result.error or "System Error", | error=run_result.error or "System Error", | ||||
| # append node output variables to variable pool | # append node output variables to variable pool | ||||
| if run_result.outputs: | if run_result.outputs: | ||||
| for variable_key, variable_value in run_result.outputs.items(): | for variable_key, variable_value in run_result.outputs.items(): | ||||
| # append variables to variable pool recursively | |||||
| self._append_variables_recursively( | |||||
| node_id=node.node_id, | |||||
| variable_key_list=[variable_key], | |||||
| variable_value=variable_value, | |||||
| # Add variables to variable pool | |||||
| self.graph_runtime_state.variable_pool.add( | |||||
| [node.node_id, variable_key], variable_value | |||||
| ) | ) | ||||
| # When setting metadata, convert to dict first | # When setting metadata, convert to dict first | ||||
| logger.exception("Node %s run failed", node.title) | logger.exception("Node %s run failed", node.title) | ||||
| raise e | raise e | ||||
| def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): | |||||
| """ | |||||
| Append variables recursively | |||||
| :param node_id: node id | |||||
| :param variable_key_list: variable key list | |||||
| :param variable_value: variable value | |||||
| :return: | |||||
| """ | |||||
| variable_utils.append_variables_recursively( | |||||
| self.graph_runtime_state.variable_pool, | |||||
| node_id, | |||||
| variable_key_list, | |||||
| variable_value, | |||||
| ) | |||||
| def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: | def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: | ||||
| """ | """ | ||||
| Check timeout | Check timeout |
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.variables import Segment | from core.variables import Segment | ||||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||||
| from core.variables.consts import SELECTORS_LENGTH | |||||
| from core.variables.types import SegmentType | from core.variables.types import SegmentType | ||||
| # Use double underscore (`__`) prefix for internal variables | # Use double underscore (`__`) prefix for internal variables | ||||
| def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: | def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: | ||||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||||
| if len(selector) < SELECTORS_LENGTH: | |||||
| raise Exception("selector too short") | raise Exception("selector too short") | ||||
| node_id, var_name = selector[:2] | node_id, var_name = selector[:2] | ||||
| return UpdatedVariable( | return UpdatedVariable( |
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.variables import SegmentType, Variable | from core.variables import SegmentType, Variable | ||||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||||
| from core.variables.consts import SELECTORS_LENGTH | |||||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID | from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID | ||||
| from core.workflow.conversation_variable_updater import ConversationVariableUpdater | from core.workflow.conversation_variable_updater import ConversationVariableUpdater | ||||
| from core.workflow.entities.node_entities import NodeRunResult | from core.workflow.entities.node_entities import NodeRunResult | ||||
| selector = item.value | selector = item.value | ||||
| if not isinstance(selector, list): | if not isinstance(selector, list): | ||||
| raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") | raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") | ||||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||||
| if len(selector) < SELECTORS_LENGTH: | |||||
| raise InvalidDataError(f"selector too short, {node_id=}, {item=}") | raise InvalidDataError(f"selector too short, {node_id=}, {item=}") | ||||
| selector_str = ".".join(selector) | selector_str = ".".join(selector) | ||||
| key = f"{node_id}.#{selector_str}#" | key = f"{node_id}.#{selector_str}#" |
| from core.variables.segments import ObjectSegment, Segment | |||||
| from core.workflow.entities.variable_pool import VariablePool, VariableValue | |||||
| def append_variables_recursively( | |||||
| pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment | |||||
| ): | |||||
| """ | |||||
| Append variables recursively | |||||
| :param pool: variable pool to append variables to | |||||
| :param node_id: node id | |||||
| :param variable_key_list: variable key list | |||||
| :param variable_value: variable value | |||||
| :return: | |||||
| """ | |||||
| pool.add([node_id] + variable_key_list, variable_value) | |||||
| # if variable_value is a dict, then recursively append variables | |||||
| if isinstance(variable_value, ObjectSegment): | |||||
| variable_dict = variable_value.value | |||||
| elif isinstance(variable_value, dict): | |||||
| variable_dict = variable_value | |||||
| else: | |||||
| return | |||||
| for key, value in variable_dict.items(): | |||||
| # construct new key list | |||||
| new_key_list = variable_key_list + [key] | |||||
| append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) |
| from typing import Any, Protocol | from typing import Any, Protocol | ||||
| from core.variables import Variable | from core.variables import Variable | ||||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||||
| from core.variables.consts import SELECTORS_LENGTH | |||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| from core.workflow.utils import variable_utils | |||||
| class VariableLoader(Protocol): | class VariableLoader(Protocol): | ||||
| variables_to_load.append(list(selector)) | variables_to_load.append(list(selector)) | ||||
| loaded = variable_loader.load_variables(variables_to_load) | loaded = variable_loader.load_variables(variables_to_load) | ||||
| for var in loaded: | for var in loaded: | ||||
| assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" | |||||
| variable_utils.append_variables_recursively( | |||||
| variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var | |||||
| ) | |||||
| assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" | |||||
| # Add variable directly to the pool | |||||
| # The variable pool expects 2-element selectors [node_id, variable_name] | |||||
| variable_pool.add([var.selector[0], var.selector[1]], var) |
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.file.models import File | from core.file.models import File | ||||
| from core.variables import Segment, StringSegment, Variable | from core.variables import Segment, StringSegment, Variable | ||||
| from core.variables.consts import MIN_SELECTORS_LENGTH | |||||
| from core.variables.consts import SELECTORS_LENGTH | |||||
| from core.variables.segments import ArrayFileSegment, FileSegment | from core.variables.segments import ArrayFileSegment, FileSegment | ||||
| from core.variables.types import SegmentType | from core.variables.types import SegmentType | ||||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | ||||
| ) -> list[WorkflowDraftVariable]: | ) -> list[WorkflowDraftVariable]: | ||||
| ors = [] | ors = [] | ||||
| for selector in selectors: | for selector in selectors: | ||||
| assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" | |||||
| assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}" | |||||
| node_id, name = selector[:2] | node_id, name = selector[:2] | ||||
| ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) | ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) | ||||
| for item in updated_variables: | for item in updated_variables: | ||||
| selector = item.selector | selector = item.selector | ||||
| if len(selector) < MIN_SELECTORS_LENGTH: | |||||
| if len(selector) < SELECTORS_LENGTH: | |||||
| raise Exception("selector too short") | raise Exception("selector too short") | ||||
| # NOTE(QuantumGhost): only the following two kinds of variable could be updated by | # NOTE(QuantumGhost): only the following two kinds of variable could be updated by | ||||
| # VariableAssigner: ConversationVariable and iteration variable. | # VariableAssigner: ConversationVariable and iteration variable. |
| def test_use_long_selector(pool): | def test_use_long_selector(pool): | ||||
| pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value")) | |||||
| # The add method now only accepts 2-element selectors (node_id, variable_name) | |||||
| # Store nested data as an ObjectSegment instead | |||||
| nested_data = {"part_2": "test_value"} | |||||
| pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data)) | |||||
| # The get method supports longer selectors for nested access | |||||
| result = pool.get(("node_1", "part_1", "part_2")) | result = pool.get(("node_1", "part_1", "part_2")) | ||||
| assert result is not None | assert result is not None | ||||
| assert result.value == "test_value" | assert result.value == "test_value" | ||||
| pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) | pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) | ||||
| pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) | pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) | ||||
| # Add nested variables | |||||
| pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value")) | |||||
| # Add nested variables as ObjectSegment | |||||
| # The add method only accepts 2-element selectors | |||||
| nested_obj = {"deep": {"var": "deep_value"}} | |||||
| pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj)) | |||||
| def test_system_variables(self): | def test_system_variables(self): | ||||
| sys_vars = SystemVariable( | sys_vars = SystemVariable( |
| from typing import Any | |||||
| from core.variables.segments import ObjectSegment, StringSegment | |||||
| from core.workflow.entities.variable_pool import VariablePool | |||||
| from core.workflow.utils.variable_utils import append_variables_recursively | |||||
| class TestAppendVariablesRecursively: | |||||
| """Test cases for append_variables_recursively function""" | |||||
| def test_append_simple_dict_value(self): | |||||
| """Test appending a simple dictionary value""" | |||||
| pool = VariablePool.empty() | |||||
| node_id = "test_node" | |||||
| variable_key_list = ["output"] | |||||
| variable_value = {"name": "John", "age": 30} | |||||
| append_variables_recursively(pool, node_id, variable_key_list, variable_value) | |||||
| # Check that the main variable is added | |||||
| main_var = pool.get([node_id] + variable_key_list) | |||||
| assert main_var is not None | |||||
| assert main_var.value == variable_value | |||||
| # Check that nested variables are added recursively | |||||
| name_var = pool.get([node_id] + variable_key_list + ["name"]) | |||||
| assert name_var is not None | |||||
| assert name_var.value == "John" | |||||
| age_var = pool.get([node_id] + variable_key_list + ["age"]) | |||||
| assert age_var is not None | |||||
| assert age_var.value == 30 | |||||
| def test_append_object_segment_value(self): | |||||
| """Test appending an ObjectSegment value""" | |||||
| pool = VariablePool.empty() | |||||
| node_id = "test_node" | |||||
| variable_key_list = ["result"] | |||||
| # Create an ObjectSegment | |||||
| obj_data = {"status": "success", "code": 200} | |||||
| variable_value = ObjectSegment(value=obj_data) | |||||
| append_variables_recursively(pool, node_id, variable_key_list, variable_value) | |||||
| # Check that the main variable is added | |||||
| main_var = pool.get([node_id] + variable_key_list) | |||||
| assert main_var is not None | |||||
| assert isinstance(main_var, ObjectSegment) | |||||
| assert main_var.value == obj_data | |||||
| # Check that nested variables are added recursively | |||||
| status_var = pool.get([node_id] + variable_key_list + ["status"]) | |||||
| assert status_var is not None | |||||
| assert status_var.value == "success" | |||||
| code_var = pool.get([node_id] + variable_key_list + ["code"]) | |||||
| assert code_var is not None | |||||
| assert code_var.value == 200 | |||||
| def test_append_nested_dict_value(self): | |||||
| """Test appending a nested dictionary value""" | |||||
| pool = VariablePool.empty() | |||||
| node_id = "test_node" | |||||
| variable_key_list = ["data"] | |||||
| variable_value = { | |||||
| "user": { | |||||
| "profile": {"name": "Alice", "email": "alice@example.com"}, | |||||
| "settings": {"theme": "dark", "notifications": True}, | |||||
| }, | |||||
| "metadata": {"version": "1.0", "timestamp": 1234567890}, | |||||
| } | |||||
| append_variables_recursively(pool, node_id, variable_key_list, variable_value) | |||||
| # Check deeply nested variables | |||||
| name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) | |||||
| assert name_var is not None | |||||
| assert name_var.value == "Alice" | |||||
| email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) | |||||
| assert email_var is not None | |||||
| assert email_var.value == "alice@example.com" | |||||
| theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) | |||||
| assert theme_var is not None | |||||
| assert theme_var.value == "dark" | |||||
| notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) | |||||
| assert notifications_var is not None | |||||
| assert notifications_var.value == 1 # Boolean True is converted to integer 1 | |||||
| version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) | |||||
| assert version_var is not None | |||||
| assert version_var.value == "1.0" | |||||
| def test_append_non_dict_value(self): | |||||
| """Test appending a non-dictionary value (should not recurse)""" | |||||
| pool = VariablePool.empty() | |||||
| node_id = "test_node" | |||||
| variable_key_list = ["simple"] | |||||
| variable_value = "simple_string" | |||||
| append_variables_recursively(pool, node_id, variable_key_list, variable_value) | |||||
| # Check that only the main variable is added | |||||
| main_var = pool.get([node_id] + variable_key_list) | |||||
| assert main_var is not None | |||||
| assert main_var.value == variable_value | |||||
| # Ensure no additional variables are created | |||||
| assert len(pool.variable_dictionary[node_id]) == 1 | |||||
| def test_append_segment_non_object_value(self): | |||||
| """Test appending a Segment that is not ObjectSegment (should not recurse)""" | |||||
| pool = VariablePool.empty() | |||||
| node_id = "test_node" | |||||
| variable_key_list = ["text"] | |||||
| variable_value = StringSegment(value="Hello World") | |||||
| append_variables_recursively(pool, node_id, variable_key_list, variable_value) | |||||
| # Check that only the main variable is added | |||||
| main_var = pool.get([node_id] + variable_key_list) | |||||
| assert main_var is not None | |||||
| assert isinstance(main_var, StringSegment) | |||||
| assert main_var.value == "Hello World" | |||||
| # Ensure no additional variables are created | |||||
| assert len(pool.variable_dictionary[node_id]) == 1 | |||||
| def test_append_empty_dict_value(self): | |||||
| """Test appending an empty dictionary value""" | |||||
| pool = VariablePool.empty() | |||||
| node_id = "test_node" | |||||
| variable_key_list = ["empty"] | |||||
| variable_value: dict[str, Any] = {} | |||||
| append_variables_recursively(pool, node_id, variable_key_list, variable_value) | |||||
| # Check that the main variable is added | |||||
| main_var = pool.get([node_id] + variable_key_list) | |||||
| assert main_var is not None | |||||
| assert main_var.value == {} | |||||
| # Ensure only the main variable is created (no recursion for empty dict) | |||||
| assert len(pool.variable_dictionary[node_id]) == 1 |