| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
- import dataclasses
- from collections.abc import Mapping
- from typing import Any, Generic, TypeAlias, TypeVar, overload
-
- from configs import dify_config
- from core.file.models import File
- from core.variables.segments import (
- ArrayFileSegment,
- ArraySegment,
- BooleanSegment,
- FileSegment,
- FloatSegment,
- IntegerSegment,
- NoneSegment,
- ObjectSegment,
- Segment,
- StringSegment,
- )
- from core.variables.utils import dumps_with_segments
-
# Deepest level of container nesting tolerated while estimating the JSON size
# of a value; going past it raises MaxDepthExceededError.
_MAX_DEPTH: int = 100
-
-
class _QAKeys:
    """dict keys for _QAStructure"""

    # Key holding the list of question/answer chunks.
    QA_CHUNKS = "qa_chunks"
    # Keys used inside a single chunk.
    QUESTION = "question"
    ANSWER = "answer"
-
-
class _PCKeys:
    """dict keys for _ParentChildStructure"""

    # Top-level keys of the structure.
    PARENT_MODE = "parent_mode"
    PARENT_CHILD_CHUNKS = "parent_child_chunks"
    # Keys used inside a single parent/child chunk.
    PARENT_CONTENT = "parent_content"
    CHILD_CONTENTS = "child_contents"
-
-
_T = TypeVar("_T")


@dataclasses.dataclass(frozen=True)
class _PartResult(Generic[_T]):
    """Immutable outcome of truncating one value (or one part of a larger value)."""

    # The value after truncation; the very same object when nothing was cut.
    value: _T
    # Estimated JSON size of `value`.
    value_size: int
    # Whether anything was shortened or dropped to produce `value`.
    truncated: bool
-
-
class MaxDepthExceededError(Exception):
    """Raised when a value is nested more deeply than the size estimator allows."""
-
-
class UnknownTypeError(Exception):
    """Raised when the size estimator meets a value of a type it cannot handle."""
-
-
- JSONTypes: TypeAlias = int | float | str | list | dict | None | bool
-
-
@dataclasses.dataclass(frozen=True)
class TruncationResult:
    """Immutable outcome of truncating a whole segment."""

    # The segment to use in place of the input.
    result: Segment
    # Whether `result` carries less data than the input segment did.
    truncated: bool
-
-
class VariableTruncator:
    """
    Handles variable truncation with structure-preserving strategies.

    This class implements truncation that prioritizes maintaining data structure
    integrity (keys of objects, order of arrays) while trying to keep the estimated
    size of the result within the configured limits.

    Sizes are estimated recursively (see `calculate_json_size`) to avoid repeated
    JSON serialization. They are character-based approximations, not exact byte counts.
    """

    def __init__(
        self,
        string_length_limit: int = 5000,
        array_element_limit: int = 20,
        max_size_bytes: int = 1024_000,  # 1,024,000, i.e. roughly 1000KB
    ) -> None:
        """
        Args:
            string_length_limit: Maximum number of characters kept from a truncated
                string. Must be greater than 3.
            array_element_limit: Maximum number of elements kept from an array.
                Must be greater than 0.
            max_size_bytes: Overall size budget for a truncated value. Must be
                greater than 0.

        Raises:
            ValueError: If any of the limits is out of range.
        """
        if string_length_limit <= 3:
            raise ValueError("string_length_limit should be greater than 3.")
        self._string_length_limit = string_length_limit

        if array_element_limit <= 0:
            raise ValueError("array_element_limit should be greater than 0.")
        self._array_element_limit = array_element_limit

        if max_size_bytes <= 0:
            raise ValueError("max_size_bytes should be greater than 0.")
        self._max_size_bytes = max_size_bytes

    @classmethod
    def default(cls) -> "VariableTruncator":
        """Build a truncator from the limits configured in `dify_config`."""
        # Use `cls` so that subclasses get an instance of their own type.
        return cls(
            max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
            array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
            string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
        )

    def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
        """
        `truncate_variable_mapping` is responsible for truncating variable mappings
        generated during workflow execution, such as `inputs`, `process_data`, or `outputs`
        of a WorkflowNodeExecution record. This ensures the mappings remain within the
        specified size limits while preserving their structure.

        Every key of `v` is kept. The remaining budget is split evenly among the
        entries that have not been processed yet.

        Returns:
            A tuple of the truncated mapping and a flag telling whether any value
            was shortened or replaced.
        """
        budget = self._max_size_bytes
        is_truncated = False
        truncated_mapping: dict[str, Any] = {}
        length = len(v)
        used_size = 0
        for key, value in v.items():
            used_size += self.calculate_json_size(key)
            if used_size > budget:
                # The keys alone exhausted the budget: keep the key so the shape of the
                # mapping survives, but replace the value with a placeholder. The
                # original value is lost, so the result must be reported as truncated.
                truncated_mapping[key] = "..."
                is_truncated = True
                continue
            # `length - len(truncated_mapping)` counts the current entry plus all
            # entries still to come, so it is always at least 1.
            value_budget = (budget - used_size) // (length - len(truncated_mapping))
            if isinstance(value, Segment):
                part_result = self._truncate_segment(value, value_budget)
            else:
                part_result = self._truncate_json_primitives(value, value_budget)
            is_truncated = is_truncated or part_result.truncated
            truncated_mapping[key] = part_result.value
            used_size += part_result.value_size
        return truncated_mapping, is_truncated

    @staticmethod
    def _segment_need_truncation(segment: Segment) -> bool:
        """Return False for segment types that are always kept as-is, True otherwise."""
        return not isinstance(
            segment,
            (NoneSegment, FloatSegment, IntegerSegment, FileSegment, BooleanSegment, ArrayFileSegment),
        )

    @staticmethod
    def _json_value_needs_truncation(value: Any) -> bool:
        """Return False for `None`, booleans and numbers, True for everything else."""
        return not (value is None or isinstance(value, (bool, int, float)))

    def truncate(self, segment: Segment) -> TruncationResult:
        """
        Truncate a single segment.

        String segments are first limited to `string_length_limit`; every other segment
        is truncated against `max_size_bytes`. If the result is still estimated to be
        larger than `max_size_bytes`, a string fallback is returned instead.

        Returns:
            TruncationResult holding the segment to use and the truncation flag.
        """
        if isinstance(segment, StringSegment):
            result = self._truncate_segment(segment, self._string_length_limit)
        else:
            result = self._truncate_segment(segment, self._max_size_bytes)

        if result.value_size > self._max_size_bytes:
            truncated_segment = result.value
            # `_truncate_segment` always returns a Segment, so the string case has to be
            # detected on the segment type and the raw text unwrapped before cutting it.
            if isinstance(truncated_segment, StringSegment):
                string_result = self._truncate_string(truncated_segment.value, self._max_size_bytes)
                return TruncationResult(result=StringSegment(value=string_result.value), truncated=True)

            # Apply final fallback - convert to JSON string and truncate
            json_str = dumps_with_segments(truncated_segment, ensure_ascii=False)
            if len(json_str) > self._max_size_bytes:
                json_str = json_str[: self._max_size_bytes] + "..."
            return TruncationResult(result=StringSegment(value=json_str), truncated=True)

        return TruncationResult(
            result=segment.model_copy(update={"value": result.value.value}), truncated=result.truncated
        )

    def _truncate_segment(self, segment: Segment, target_size: int) -> _PartResult[Segment]:
        """
        Apply type-specific truncation to a segment.

        Args:
            segment: The segment to truncate.
            target_size: Size budget for the segment's value.

        Returns:
            _PartResult whose value is a segment: the input itself when its type is
            never truncated, otherwise a copy carrying the truncated value.
        """
        if not VariableTruncator._segment_need_truncation(segment):
            return _PartResult(segment, self.calculate_json_size(segment.value), False)

        result: _PartResult[Any]
        # Apply type-specific truncation with target size
        if isinstance(segment, ArraySegment):
            result = self._truncate_array(segment.value, target_size)
        elif isinstance(segment, StringSegment):
            result = self._truncate_string(segment.value, target_size)
        elif isinstance(segment, ObjectSegment):
            result = self._truncate_object(segment.value, target_size)
        else:
            raise AssertionError("this should be unreachable.")

        return _PartResult(
            value=segment.model_copy(update={"value": result.value}),
            value_size=result.value_size,
            truncated=result.truncated,
        )

    @staticmethod
    def calculate_json_size(value: Any, depth: int = 0) -> int:
        """
        Recursively calculate JSON size without serialization.

        Args:
            value: A segment, a `File`, or a JSON-like Python value.
            depth: Current nesting level; callers normally leave it at 0.

        Returns:
            The estimated size of `value` once rendered as compact JSON.

        Raises:
            MaxDepthExceededError: If the nesting is deeper than `_MAX_DEPTH`.
            UnknownTypeError: If a value of an unsupported type is encountered.
        """
        if isinstance(value, Segment):
            # Unwrapping a segment adds no JSON nesting, but the depth reached so far
            # must be carried along so the depth guard keeps working below it.
            return VariableTruncator.calculate_json_size(value.value, depth=depth)
        if depth > _MAX_DEPTH:
            raise MaxDepthExceededError()
        if isinstance(value, str):
            # Ideally, the size of strings should be calculated based on their utf-8 encoded length.
            # However, this adds complexity as we would need to compute encoded sizes consistently
            # throughout the code. Therefore, we approximate the size using the string's length.
            # Rough estimate: number of characters, plus 2 for quotes
            return len(value) + 2
        if isinstance(value, bool):
            # Must be tested before int: bool is a subclass of int.
            return 4 if value else 5  # "true" or "false"
        if isinstance(value, (int, float)):
            return len(str(value))
        if value is None:
            return 4  # "null"
        if isinstance(value, list):
            # Size = sum of elements + separators + brackets
            total = 2  # "[]"
            for index, item in enumerate(value):
                if index > 0:
                    total += 1  # ","
                total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
            return total
        if isinstance(value, dict):
            # Size = sum of keys + values + separators + brackets
            total = 2  # "{}"
            for index, (key, item) in enumerate(value.items()):
                if index > 0:
                    total += 1  # ","
                total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1)  # Key as string
                total += 1  # ":"
                total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
            return total
        if isinstance(value, File):
            return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
        raise UnknownTypeError(f"got unknown type {type(value)}")

    def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
        """
        Shorten a string so that its estimated size fits `target_size`.

        A string whose estimated size is below `target_size` is returned unchanged.
        Otherwise a prefix followed by "..." is returned; the prefix is never longer
        than `string_length_limit`. Budgets below 5 yield the bare "..." marker.
        """
        if (size := self.calculate_json_size(value)) < target_size:
            return _PartResult(value, size, False)
        if target_size < 5:
            # 5 = two quotes + the three dots of the marker.
            return _PartResult("...", 5, True)
        truncated_size = min(self._string_length_limit, target_size - 5)
        truncated_value = value[:truncated_size] + "..."
        return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True)

    def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]:
        """
        Truncate array with the following strategy:
        1. Keep at most `array_element_limit` items
        2. Truncate individual items so they fit the remaining budget
        3. Drop the remaining items once the budget is used up

        The result is flagged as truncated if any item was shortened or dropped.
        """
        truncated_value: list[Any] = []
        truncated = False
        used_size = self.calculate_json_size([])

        for index, item in enumerate(value):
            if index >= self._array_element_limit:
                return _PartResult(truncated_value, used_size, True)

            separator_size = 1 if index > 0 else 0  # Account for comma
            if used_size + separator_size > target_size:
                # No room left: this item and all following ones are dropped. The comma
                # is not counted because the element it would precede is not emitted.
                truncated = True
                break
            used_size += separator_size

            part_result = self._truncate_json_primitives(item, target_size - used_size)
            truncated_value.append(part_result.value)
            used_size += part_result.value_size
            # Accumulate: a truncated earlier item must not be forgotten when a later
            # item fits untouched.
            truncated = truncated or part_result.truncated
        return _PartResult(truncated_value, used_size, truncated)

    @classmethod
    def _maybe_qa_structure(cls, m: Mapping[str, Any]) -> bool:
        """Return True if `m` has a list under the `qa_chunks` key."""
        return isinstance(m.get(_QAKeys.QA_CHUNKS), list)

    @classmethod
    def _maybe_parent_child_structure(cls, m: Mapping[str, Any]) -> bool:
        """Return True if `m` has a string `parent_mode` and a list `parent_child_chunks`."""
        if not isinstance(m.get(_PCKeys.PARENT_MODE), str):
            return False
        return isinstance(m.get(_PCKeys.PARENT_CHILD_CHUNKS), list)

    def _truncate_object(self, mapping: Mapping[str, Any], target_size: int) -> _PartResult[Mapping[str, Any]]:
        """
        Truncate object with key preservation priority.

        Strategy:
        1. Visit keys in sorted order and truncate each value to an even share of
           the remaining budget; keys themselves are never shortened
        2. Once the budget is used up, drop the remaining keys
        """
        if not mapping:
            return _PartResult(mapping, self.calculate_json_size(mapping), False)

        truncated_obj: dict[str, Any] = {}
        truncated = False
        used_size = self.calculate_json_size({})

        # Sort keys to ensure deterministic behavior
        sorted_keys = sorted(mapping.keys())

        for index, key in enumerate(sorted_keys):
            if used_size > target_size:
                # No more room for additional key-value pairs
                truncated = True
                break

            pair_size = 0

            if index > 0:
                pair_size += 1  # Account for comma

            # Calculate budget for this key-value pair
            # do not try to truncate keys, as we want to keep the structure of
            # object.
            key_size = self.calculate_json_size(key) + 1  # +1 for ":"
            pair_size += key_size
            remaining_pairs = len(sorted_keys) - index
            value_budget = max(0, (target_size - pair_size - used_size) // remaining_pairs)

            if value_budget <= 0:
                truncated = True
                break

            # Truncate the value to fit within budget
            value = mapping[key]
            if isinstance(value, Segment):
                value_result = self._truncate_segment(value, value_budget)
            else:
                value_result = self._truncate_json_primitives(value, value_budget)

            truncated_obj[key] = value_result.value
            pair_size += value_result.value_size
            used_size += pair_size

            if value_result.truncated:
                truncated = True

        return _PartResult(truncated_obj, used_size, truncated)

    @overload
    def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...

    @overload
    def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ...

    @overload
    def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...

    @overload
    def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...  # type: ignore

    @overload
    def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...

    @overload
    def _truncate_json_primitives(self, val: float, target_size: int) -> _PartResult[float]: ...

    @overload
    def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...

    @overload
    def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ...

    def _truncate_json_primitives(
        self, val: File | str | list | dict | bool | int | float | None, target_size: int
    ) -> _PartResult[Any]:
        """
        Truncate a value within an object to fit within budget.

        Strings, lists and dicts are delegated to their dedicated truncation methods.
        `None`, booleans, numbers and `File` objects are returned unchanged together
        with their estimated size.
        """
        if isinstance(val, str):
            return self._truncate_string(val, target_size)
        elif isinstance(val, list):
            return self._truncate_array(val, target_size)
        elif isinstance(val, dict):
            return self._truncate_object(val, target_size)
        elif isinstance(val, File):
            # `calculate_json_size` knows how to size a File, so one nested in a list
            # or dict must not end up in the "unreachable" branch; keep it as-is.
            return _PartResult(val, self.calculate_json_size(val), False)
        elif val is None or isinstance(val, (bool, int, float)):
            return _PartResult(val, self.calculate_json_size(val), False)
        else:
            raise AssertionError("this statement should be unreachable.")
|