選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

variable_truncator.py 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. import dataclasses
  2. from collections.abc import Mapping
  3. from typing import Any, Generic, TypeAlias, TypeVar, overload
  4. from configs import dify_config
  5. from core.file.models import File
  6. from core.variables.segments import (
  7. ArrayFileSegment,
  8. ArraySegment,
  9. BooleanSegment,
  10. FileSegment,
  11. FloatSegment,
  12. IntegerSegment,
  13. NoneSegment,
  14. ObjectSegment,
  15. Segment,
  16. StringSegment,
  17. )
  18. from core.variables.utils import dumps_with_segments
# Maximum nesting depth `VariableTruncator.calculate_json_size` will descend into
# before raising `MaxDepthExceededError`.
_MAX_DEPTH = 100
  20. class _QAKeys:
  21. """dict keys for _QAStructure"""
  22. QA_CHUNKS = "qa_chunks"
  23. QUESTION = "question"
  24. ANSWER = "answer"
  25. class _PCKeys:
  26. """dict keys for _ParentChildStructure"""
  27. PARENT_MODE = "parent_mode"
  28. PARENT_CHILD_CHUNKS = "parent_child_chunks"
  29. PARENT_CONTENT = "parent_content"
  30. CHILD_CONTENTS = "child_contents"
  31. _T = TypeVar("_T")
  32. @dataclasses.dataclass(frozen=True)
  33. class _PartResult(Generic[_T]):
  34. value: _T
  35. value_size: int
  36. truncated: bool
  37. class MaxDepthExceededError(Exception):
  38. pass
  39. class UnknownTypeError(Exception):
  40. pass
# Union of the plain Python types that correspond to JSON values
# (number, string, array, object, null, boolean).
JSONTypes: TypeAlias = int | float | str | list | dict | None | bool
  42. @dataclasses.dataclass(frozen=True)
  43. class TruncationResult:
  44. result: Segment
  45. truncated: bool
  46. class VariableTruncator:
  47. """
  48. Handles variable truncation with structure-preserving strategies.
  49. This class implements intelligent truncation that prioritizes maintaining data structure
  50. integrity while ensuring the final size doesn't exceed specified limits.
  51. Uses recursive size calculation to avoid repeated JSON serialization.
  52. """
  53. def __init__(
  54. self,
  55. string_length_limit=5000,
  56. array_element_limit: int = 20,
  57. max_size_bytes: int = 1024_000, # 100KB
  58. ):
  59. if string_length_limit <= 3:
  60. raise ValueError("string_length_limit should be greater than 3.")
  61. self._string_length_limit = string_length_limit
  62. if array_element_limit <= 0:
  63. raise ValueError("array_element_limit should be greater than 0.")
  64. self._array_element_limit = array_element_limit
  65. if max_size_bytes <= 0:
  66. raise ValueError("max_size_bytes should be greater than 0.")
  67. self._max_size_bytes = max_size_bytes
  68. @classmethod
  69. def default(cls) -> "VariableTruncator":
  70. return VariableTruncator(
  71. max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
  72. array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
  73. string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
  74. )
  75. def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
  76. """
  77. `truncate_variable_mapping` is responsible for truncating variable mappings
  78. generated during workflow execution, such as `inputs`, `process_data`, or `outputs`
  79. of a WorkflowNodeExecution record. This ensures the mappings remain within the
  80. specified size limits while preserving their structure.
  81. """
  82. budget = self._max_size_bytes
  83. is_truncated = False
  84. truncated_mapping: dict[str, Any] = {}
  85. length = len(v.items())
  86. used_size = 0
  87. for key, value in v.items():
  88. used_size += self.calculate_json_size(key)
  89. if used_size > budget:
  90. truncated_mapping[key] = "..."
  91. continue
  92. value_budget = (budget - used_size) // (length - len(truncated_mapping))
  93. if isinstance(value, Segment):
  94. part_result = self._truncate_segment(value, value_budget)
  95. else:
  96. part_result = self._truncate_json_primitives(value, value_budget)
  97. is_truncated = is_truncated or part_result.truncated
  98. truncated_mapping[key] = part_result.value
  99. used_size += part_result.value_size
  100. return truncated_mapping, is_truncated
  101. @staticmethod
  102. def _segment_need_truncation(segment: Segment) -> bool:
  103. if isinstance(
  104. segment,
  105. (NoneSegment, FloatSegment, IntegerSegment, FileSegment, BooleanSegment, ArrayFileSegment),
  106. ):
  107. return False
  108. return True
  109. @staticmethod
  110. def _json_value_needs_truncation(value: Any) -> bool:
  111. if value is None:
  112. return False
  113. if isinstance(value, (bool, int, float)):
  114. return False
  115. return True
  116. def truncate(self, segment: Segment) -> TruncationResult:
  117. if isinstance(segment, StringSegment):
  118. result = self._truncate_segment(segment, self._string_length_limit)
  119. else:
  120. result = self._truncate_segment(segment, self._max_size_bytes)
  121. if result.value_size > self._max_size_bytes:
  122. if isinstance(result.value, str):
  123. result = self._truncate_string(result.value, self._max_size_bytes)
  124. return TruncationResult(StringSegment(value=result.value), True)
  125. # Apply final fallback - convert to JSON string and truncate
  126. json_str = dumps_with_segments(result.value, ensure_ascii=False)
  127. if len(json_str) > self._max_size_bytes:
  128. json_str = json_str[: self._max_size_bytes] + "..."
  129. return TruncationResult(result=StringSegment(value=json_str), truncated=True)
  130. return TruncationResult(
  131. result=segment.model_copy(update={"value": result.value.value}), truncated=result.truncated
  132. )
  133. def _truncate_segment(self, segment: Segment, target_size: int) -> _PartResult[Segment]:
  134. """
  135. Apply smart truncation to a variable value.
  136. Args:
  137. value: The value to truncate (can be Segment or raw value)
  138. Returns:
  139. TruncationResult with truncated data and truncation status
  140. """
  141. if not VariableTruncator._segment_need_truncation(segment):
  142. return _PartResult(segment, self.calculate_json_size(segment.value), False)
  143. result: _PartResult[Any]
  144. # Apply type-specific truncation with target size
  145. if isinstance(segment, ArraySegment):
  146. result = self._truncate_array(segment.value, target_size)
  147. elif isinstance(segment, StringSegment):
  148. result = self._truncate_string(segment.value, target_size)
  149. elif isinstance(segment, ObjectSegment):
  150. result = self._truncate_object(segment.value, target_size)
  151. else:
  152. raise AssertionError("this should be unreachable.")
  153. return _PartResult(
  154. value=segment.model_copy(update={"value": result.value}),
  155. value_size=result.value_size,
  156. truncated=result.truncated,
  157. )
  158. @staticmethod
  159. def calculate_json_size(value: Any, depth=0) -> int:
  160. """Recursively calculate JSON size without serialization."""
  161. if isinstance(value, Segment):
  162. return VariableTruncator.calculate_json_size(value.value)
  163. if depth > _MAX_DEPTH:
  164. raise MaxDepthExceededError()
  165. if isinstance(value, str):
  166. # Ideally, the size of strings should be calculated based on their utf-8 encoded length.
  167. # However, this adds complexity as we would need to compute encoded sizes consistently
  168. # throughout the code. Therefore, we approximate the size using the string's length.
  169. # Rough estimate: number of characters, plus 2 for quotes
  170. return len(value) + 2
  171. elif isinstance(value, (int, float)):
  172. return len(str(value))
  173. elif isinstance(value, bool):
  174. return 4 if value else 5 # "true" or "false"
  175. elif value is None:
  176. return 4 # "null"
  177. elif isinstance(value, list):
  178. # Size = sum of elements + separators + brackets
  179. total = 2 # "[]"
  180. for i, item in enumerate(value):
  181. if i > 0:
  182. total += 1 # ","
  183. total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
  184. return total
  185. elif isinstance(value, dict):
  186. # Size = sum of keys + values + separators + brackets
  187. total = 2 # "{}"
  188. for index, key in enumerate(value.keys()):
  189. if index > 0:
  190. total += 1 # ","
  191. total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1) # Key as string
  192. total += 1 # ":"
  193. total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
  194. return total
  195. elif isinstance(value, File):
  196. return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
  197. else:
  198. raise UnknownTypeError(f"got unknown type {type(value)}")
  199. def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
  200. if (size := self.calculate_json_size(value)) < target_size:
  201. return _PartResult(value, size, False)
  202. if target_size < 5:
  203. return _PartResult("...", 5, True)
  204. truncated_size = min(self._string_length_limit, target_size - 5)
  205. truncated_value = value[:truncated_size] + "..."
  206. return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True)
  207. def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]:
  208. """
  209. Truncate array with correct strategy:
  210. 1. First limit to 20 items
  211. 2. If still too large, truncate individual items
  212. """
  213. truncated_value: list[Any] = []
  214. truncated = False
  215. used_size = self.calculate_json_size([])
  216. target_length = self._array_element_limit
  217. for i, item in enumerate(value):
  218. # Dirty fix:
  219. # The output of `Start` node may contain list of `File` elements,
  220. # causing `AssertionError` while invoking `_truncate_json_primitives`.
  221. #
  222. # This check ensures that `list[File]` are handled separately
  223. if isinstance(item, File):
  224. truncated_value.append(item)
  225. continue
  226. if i >= target_length:
  227. return _PartResult(truncated_value, used_size, True)
  228. if i > 0:
  229. used_size += 1 # Account for comma
  230. if used_size > target_size:
  231. break
  232. part_result = self._truncate_json_primitives(item, target_size - used_size)
  233. truncated_value.append(part_result.value)
  234. used_size += part_result.value_size
  235. truncated = part_result.truncated
  236. return _PartResult(truncated_value, used_size, truncated)
  237. @classmethod
  238. def _maybe_qa_structure(cls, m: Mapping[str, Any]) -> bool:
  239. qa_chunks = m.get(_QAKeys.QA_CHUNKS)
  240. if qa_chunks is None:
  241. return False
  242. if not isinstance(qa_chunks, list):
  243. return False
  244. return True
  245. @classmethod
  246. def _maybe_parent_child_structure(cls, m: Mapping[str, Any]) -> bool:
  247. parent_mode = m.get(_PCKeys.PARENT_MODE)
  248. if parent_mode is None:
  249. return False
  250. if not isinstance(parent_mode, str):
  251. return False
  252. parent_child_chunks = m.get(_PCKeys.PARENT_CHILD_CHUNKS)
  253. if parent_child_chunks is None:
  254. return False
  255. if not isinstance(parent_child_chunks, list):
  256. return False
  257. return True
  258. def _truncate_object(self, mapping: Mapping[str, Any], target_size: int) -> _PartResult[Mapping[str, Any]]:
  259. """
  260. Truncate object with key preservation priority.
  261. Strategy:
  262. 1. Keep all keys, truncate values to fit within budget
  263. 2. If still too large, drop keys starting from the end
  264. """
  265. if not mapping:
  266. return _PartResult(mapping, self.calculate_json_size(mapping), False)
  267. truncated_obj = {}
  268. truncated = False
  269. used_size = self.calculate_json_size({})
  270. # Sort keys to ensure deterministic behavior
  271. sorted_keys = sorted(mapping.keys())
  272. for i, key in enumerate(sorted_keys):
  273. if used_size > target_size:
  274. # No more room for additional key-value pairs
  275. truncated = True
  276. break
  277. pair_size = 0
  278. if i > 0:
  279. pair_size += 1 # Account for comma
  280. # Calculate budget for this key-value pair
  281. # do not try to truncate keys, as we want to keep the structure of
  282. # object.
  283. key_size = self.calculate_json_size(key) + 1 # +1 for ":"
  284. pair_size += key_size
  285. remaining_pairs = len(sorted_keys) - i
  286. value_budget = max(0, (target_size - pair_size - used_size) // remaining_pairs)
  287. if value_budget <= 0:
  288. truncated = True
  289. break
  290. # Truncate the value to fit within budget
  291. value = mapping[key]
  292. if isinstance(value, Segment):
  293. value_result = self._truncate_segment(value, value_budget)
  294. else:
  295. value_result = self._truncate_json_primitives(mapping[key], value_budget)
  296. truncated_obj[key] = value_result.value
  297. pair_size += value_result.value_size
  298. used_size += pair_size
  299. if value_result.truncated:
  300. truncated = True
  301. return _PartResult(truncated_obj, used_size, truncated)
  302. @overload
  303. def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...
  304. @overload
  305. def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ...
  306. @overload
  307. def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...
  308. @overload
  309. def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ... # type: ignore
  310. @overload
  311. def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...
  312. @overload
  313. def _truncate_json_primitives(self, val: float, target_size: int) -> _PartResult[float]: ...
  314. @overload
  315. def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
  316. def _truncate_json_primitives(
  317. self, val: str | list | dict | bool | int | float | None, target_size: int
  318. ) -> _PartResult[Any]:
  319. """Truncate a value within an object to fit within budget."""
  320. if isinstance(val, str):
  321. return self._truncate_string(val, target_size)
  322. elif isinstance(val, list):
  323. return self._truncate_array(val, target_size)
  324. elif isinstance(val, dict):
  325. return self._truncate_object(val, target_size)
  326. elif val is None or isinstance(val, (bool, int, float)):
  327. return _PartResult(val, self.calculate_json_size(val), False)
  328. else:
  329. raise AssertionError("this statement should be unreachable.")