Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

variable_truncator.py 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. import dataclasses
  2. from collections.abc import Mapping
  3. from enum import StrEnum
  4. from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, overload
  5. from configs import dify_config
  6. from core.file.models import File
  7. from core.variables.segments import (
  8. ArrayFileSegment,
  9. ArraySegment,
  10. BooleanSegment,
  11. FileSegment,
  12. FloatSegment,
  13. IntegerSegment,
  14. NoneSegment,
  15. ObjectSegment,
  16. Segment,
  17. StringSegment,
  18. )
  19. from core.variables.utils import dumps_with_segments
  20. _MAX_DEPTH = 100
  21. class _QAKeys:
  22. """dict keys for _QAStructure"""
  23. QA_CHUNKS = "qa_chunks"
  24. QUESTION = "question"
  25. ANSWER = "answer"
  26. class _PCKeys:
  27. """dict keys for _ParentChildStructure"""
  28. PARENT_MODE = "parent_mode"
  29. PARENT_CHILD_CHUNKS = "parent_child_chunks"
  30. PARENT_CONTENT = "parent_content"
  31. CHILD_CONTENTS = "child_contents"
  32. class _QAStructureItem(TypedDict):
  33. question: str
  34. answer: str
  35. class _QAStructure(TypedDict):
  36. qa_chunks: list[_QAStructureItem]
  37. class _ParentChildChunkItem(TypedDict):
  38. parent_content: str
  39. child_contents: list[str]
  40. class _ParentChildStructure(TypedDict):
  41. parent_mode: str
  42. parent_child_chunks: list[_ParentChildChunkItem]
  43. class _SpecialChunkType(StrEnum):
  44. parent_child = "parent_child"
  45. qa = "qa"
  46. _T = TypeVar("_T")
  47. @dataclasses.dataclass(frozen=True)
  48. class _PartResult(Generic[_T]):
  49. value: _T
  50. value_size: int
  51. truncated: bool
  52. class MaxDepthExceededError(Exception):
  53. pass
  54. class UnknownTypeError(Exception):
  55. pass
  56. JSONTypes: TypeAlias = int | float | str | list | dict | None | bool
  57. @dataclasses.dataclass(frozen=True)
  58. class TruncationResult:
  59. result: Segment
  60. truncated: bool
  61. class VariableTruncator:
  62. """
  63. Handles variable truncation with structure-preserving strategies.
  64. This class implements intelligent truncation that prioritizes maintaining data structure
  65. integrity while ensuring the final size doesn't exceed specified limits.
  66. Uses recursive size calculation to avoid repeated JSON serialization.
  67. """
  68. def __init__(
  69. self,
  70. string_length_limit=5000,
  71. array_element_limit: int = 20,
  72. max_size_bytes: int = 1024_000, # 100KB
  73. ):
  74. if string_length_limit <= 3:
  75. raise ValueError("string_length_limit should be greater than 3.")
  76. self._string_length_limit = string_length_limit
  77. if array_element_limit <= 0:
  78. raise ValueError("array_element_limit should be greater than 0.")
  79. self._array_element_limit = array_element_limit
  80. if max_size_bytes <= 0:
  81. raise ValueError("max_size_bytes should be greater than 0.")
  82. self._max_size_bytes = max_size_bytes
  83. @classmethod
  84. def default(cls) -> "VariableTruncator":
  85. return VariableTruncator(
  86. max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
  87. array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
  88. string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
  89. )
  90. def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
  91. """
  92. `truncate_variable_mapping` is responsible for truncating variable mappings
  93. generated during workflow execution, such as `inputs`, `process_data`, or `outputs`
  94. of a WorkflowNodeExecution record. This ensures the mappings remain within the
  95. specified size limits while preserving their structure.
  96. """
  97. budget = self._max_size_bytes
  98. is_truncated = False
  99. truncated_mapping: dict[str, Any] = {}
  100. length = len(v.items())
  101. used_size = 0
  102. for key, value in v.items():
  103. used_size += self.calculate_json_size(key)
  104. if used_size > budget:
  105. truncated_mapping[key] = "..."
  106. continue
  107. value_budget = (budget - used_size) // (length - len(truncated_mapping))
  108. if isinstance(value, Segment):
  109. part_result = self._truncate_segment(value, value_budget)
  110. else:
  111. part_result = self._truncate_json_primitives(value, value_budget)
  112. is_truncated = is_truncated or part_result.truncated
  113. truncated_mapping[key] = part_result.value
  114. used_size += part_result.value_size
  115. return truncated_mapping, is_truncated
  116. @staticmethod
  117. def _segment_need_truncation(segment: Segment) -> bool:
  118. if isinstance(
  119. segment,
  120. (NoneSegment, FloatSegment, IntegerSegment, FileSegment, BooleanSegment, ArrayFileSegment),
  121. ):
  122. return False
  123. return True
  124. @staticmethod
  125. def _json_value_needs_truncation(value: Any) -> bool:
  126. if value is None:
  127. return False
  128. if isinstance(value, (bool, int, float)):
  129. return False
  130. return True
  131. def truncate(self, segment: Segment) -> TruncationResult:
  132. if isinstance(segment, StringSegment):
  133. result = self._truncate_segment(segment, self._string_length_limit)
  134. else:
  135. result = self._truncate_segment(segment, self._max_size_bytes)
  136. if result.value_size > self._max_size_bytes:
  137. if isinstance(result.value, str):
  138. result = self._truncate_string(result.value, self._max_size_bytes)
  139. return TruncationResult(StringSegment(value=result.value), True)
  140. # Apply final fallback - convert to JSON string and truncate
  141. json_str = dumps_with_segments(result.value, ensure_ascii=False)
  142. if len(json_str) > self._max_size_bytes:
  143. json_str = json_str[: self._max_size_bytes] + "..."
  144. return TruncationResult(result=StringSegment(value=json_str), truncated=True)
  145. return TruncationResult(
  146. result=segment.model_copy(update={"value": result.value.value}), truncated=result.truncated
  147. )
  148. def _truncate_segment(self, segment: Segment, target_size: int) -> _PartResult[Segment]:
  149. """
  150. Apply smart truncation to a variable value.
  151. Args:
  152. value: The value to truncate (can be Segment or raw value)
  153. Returns:
  154. TruncationResult with truncated data and truncation status
  155. """
  156. if not VariableTruncator._segment_need_truncation(segment):
  157. return _PartResult(segment, self.calculate_json_size(segment.value), False)
  158. result: _PartResult[Any]
  159. # Apply type-specific truncation with target size
  160. if isinstance(segment, ArraySegment):
  161. result = self._truncate_array(segment.value, target_size)
  162. elif isinstance(segment, StringSegment):
  163. result = self._truncate_string(segment.value, target_size)
  164. elif isinstance(segment, ObjectSegment):
  165. result = self._truncate_object(segment.value, target_size)
  166. else:
  167. raise AssertionError("this should be unreachable.")
  168. return _PartResult(
  169. value=segment.model_copy(update={"value": result.value}),
  170. value_size=result.value_size,
  171. truncated=result.truncated,
  172. )
  173. @staticmethod
  174. def calculate_json_size(value: Any, depth=0) -> int:
  175. """Recursively calculate JSON size without serialization."""
  176. if isinstance(value, Segment):
  177. return VariableTruncator.calculate_json_size(value.value)
  178. if depth > _MAX_DEPTH:
  179. raise MaxDepthExceededError()
  180. if isinstance(value, str):
  181. # Ideally, the size of strings should be calculated based on their utf-8 encoded length.
  182. # However, this adds complexity as we would need to compute encoded sizes consistently
  183. # throughout the code. Therefore, we approximate the size using the string's length.
  184. # Rough estimate: number of characters, plus 2 for quotes
  185. return len(value) + 2
  186. elif isinstance(value, (int, float)):
  187. return len(str(value))
  188. elif isinstance(value, bool):
  189. return 4 if value else 5 # "true" or "false"
  190. elif value is None:
  191. return 4 # "null"
  192. elif isinstance(value, list):
  193. # Size = sum of elements + separators + brackets
  194. total = 2 # "[]"
  195. for i, item in enumerate(value):
  196. if i > 0:
  197. total += 1 # ","
  198. total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
  199. return total
  200. elif isinstance(value, dict):
  201. # Size = sum of keys + values + separators + brackets
  202. total = 2 # "{}"
  203. for index, key in enumerate(value.keys()):
  204. if index > 0:
  205. total += 1 # ","
  206. total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1) # Key as string
  207. total += 1 # ":"
  208. total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
  209. return total
  210. elif isinstance(value, File):
  211. return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
  212. else:
  213. raise UnknownTypeError(f"got unknown type {type(value)}")
  214. def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
  215. if (size := self.calculate_json_size(value)) < target_size:
  216. return _PartResult(value, size, False)
  217. if target_size < 5:
  218. return _PartResult("...", 5, True)
  219. truncated_size = min(self._string_length_limit, target_size - 5)
  220. truncated_value = value[:truncated_size] + "..."
  221. return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True)
  222. def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]:
  223. """
  224. Truncate array with correct strategy:
  225. 1. First limit to 20 items
  226. 2. If still too large, truncate individual items
  227. """
  228. truncated_value: list[Any] = []
  229. truncated = False
  230. used_size = self.calculate_json_size([])
  231. target_length = self._array_element_limit
  232. for i, item in enumerate(value):
  233. if i >= target_length:
  234. return _PartResult(truncated_value, used_size, True)
  235. if i > 0:
  236. used_size += 1 # Account for comma
  237. if used_size > target_size:
  238. break
  239. part_result = self._truncate_json_primitives(item, target_size - used_size)
  240. truncated_value.append(part_result.value)
  241. used_size += part_result.value_size
  242. truncated = part_result.truncated
  243. return _PartResult(truncated_value, used_size, truncated)
  244. @classmethod
  245. def _maybe_qa_structure(cls, m: Mapping[str, Any]) -> bool:
  246. qa_chunks = m.get(_QAKeys.QA_CHUNKS)
  247. if qa_chunks is None:
  248. return False
  249. if not isinstance(qa_chunks, list):
  250. return False
  251. return True
  252. @classmethod
  253. def _maybe_parent_child_structure(cls, m: Mapping[str, Any]) -> bool:
  254. parent_mode = m.get(_PCKeys.PARENT_MODE)
  255. if parent_mode is None:
  256. return False
  257. if not isinstance(parent_mode, str):
  258. return False
  259. parent_child_chunks = m.get(_PCKeys.PARENT_CHILD_CHUNKS)
  260. if parent_child_chunks is None:
  261. return False
  262. if not isinstance(parent_child_chunks, list):
  263. return False
  264. return True
  265. def _truncate_object(self, mapping: Mapping[str, Any], target_size: int) -> _PartResult[Mapping[str, Any]]:
  266. """
  267. Truncate object with key preservation priority.
  268. Strategy:
  269. 1. Keep all keys, truncate values to fit within budget
  270. 2. If still too large, drop keys starting from the end
  271. """
  272. if not mapping:
  273. return _PartResult(mapping, self.calculate_json_size(mapping), False)
  274. truncated_obj = {}
  275. truncated = False
  276. used_size = self.calculate_json_size({})
  277. # Sort keys to ensure deterministic behavior
  278. sorted_keys = sorted(mapping.keys())
  279. for i, key in enumerate(sorted_keys):
  280. if used_size > target_size:
  281. # No more room for additional key-value pairs
  282. truncated = True
  283. break
  284. pair_size = 0
  285. if i > 0:
  286. pair_size += 1 # Account for comma
  287. # Calculate budget for this key-value pair
  288. # do not try to truncate keys, as we want to keep the structure of
  289. # object.
  290. key_size = self.calculate_json_size(key) + 1 # +1 for ":"
  291. pair_size += key_size
  292. remaining_pairs = len(sorted_keys) - i
  293. value_budget = max(0, (target_size - pair_size - used_size) // remaining_pairs)
  294. if value_budget <= 0:
  295. truncated = True
  296. break
  297. # Truncate the value to fit within budget
  298. value = mapping[key]
  299. if isinstance(value, Segment):
  300. value_result = self._truncate_segment(value, value_budget)
  301. else:
  302. value_result = self._truncate_json_primitives(mapping[key], value_budget)
  303. truncated_obj[key] = value_result.value
  304. pair_size += value_result.value_size
  305. used_size += pair_size
  306. if value_result.truncated:
  307. truncated = True
  308. return _PartResult(truncated_obj, used_size, truncated)
  309. @overload
  310. def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...
  311. @overload
  312. def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ...
  313. @overload
  314. def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...
  315. @overload
  316. def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...
  317. @overload
  318. def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...
  319. @overload
  320. def _truncate_json_primitives(self, val: float, target_size: int) -> _PartResult[float]: ...
  321. @overload
  322. def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
  323. def _truncate_json_primitives(
  324. self, val: str | list | dict | bool | int | float | None, target_size: int
  325. ) -> _PartResult[Any]:
  326. """Truncate a value within an object to fit within budget."""
  327. if isinstance(val, str):
  328. return self._truncate_string(val, target_size)
  329. elif isinstance(val, list):
  330. return self._truncate_array(val, target_size)
  331. elif isinstance(val, dict):
  332. return self._truncate_object(val, target_size)
  333. elif val is None or isinstance(val, (bool, int, float)):
  334. return _PartResult(val, self.calculate_json_size(val), False)
  335. else:
  336. raise AssertionError("this statement should be unreachable.")