Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

variable_factory.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayFileSegment,
  10. ArrayNumberSegment,
  11. ArrayObjectSegment,
  12. ArraySegment,
  13. ArrayStringSegment,
  14. FileSegment,
  15. FloatSegment,
  16. IntegerSegment,
  17. NoneSegment,
  18. ObjectSegment,
  19. Segment,
  20. StringSegment,
  21. )
  22. from core.variables.types import SegmentType
  23. from core.variables.variables import (
  24. ArrayAnyVariable,
  25. ArrayFileVariable,
  26. ArrayNumberVariable,
  27. ArrayObjectVariable,
  28. ArrayStringVariable,
  29. FileVariable,
  30. FloatVariable,
  31. IntegerVariable,
  32. NoneVariable,
  33. ObjectVariable,
  34. SecretVariable,
  35. StringVariable,
  36. Variable,
  37. )
  38. from core.workflow.constants import (
  39. CONVERSATION_VARIABLE_NODE_ID,
  40. ENVIRONMENT_VARIABLE_NODE_ID,
  41. )
  42. class UnsupportedSegmentTypeError(Exception):
  43. pass
  44. class TypeMismatchError(Exception):
  45. pass
  46. # Define the constant
  47. SEGMENT_TO_VARIABLE_MAP = {
  48. StringSegment: StringVariable,
  49. IntegerSegment: IntegerVariable,
  50. FloatSegment: FloatVariable,
  51. ObjectSegment: ObjectVariable,
  52. FileSegment: FileVariable,
  53. ArrayStringSegment: ArrayStringVariable,
  54. ArrayNumberSegment: ArrayNumberVariable,
  55. ArrayObjectSegment: ArrayObjectVariable,
  56. ArrayFileSegment: ArrayFileVariable,
  57. ArrayAnySegment: ArrayAnyVariable,
  58. NoneSegment: NoneVariable,
  59. }
  60. def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  61. if not mapping.get("name"):
  62. raise VariableError("missing name")
  63. return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
  64. def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  65. if not mapping.get("name"):
  66. raise VariableError("missing name")
  67. return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
  68. def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  69. if not mapping.get("variable"):
  70. raise VariableError("missing variable")
  71. return mapping["variable"]
  72. def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
  73. """
  74. This factory function is used to create the environment variable or the conversation variable,
  75. not support the File type.
  76. """
  77. if (value_type := mapping.get("value_type")) is None:
  78. raise VariableError("missing value type")
  79. if (value := mapping.get("value")) is None:
  80. raise VariableError("missing value")
  81. result: Variable
  82. match value_type:
  83. case SegmentType.STRING:
  84. result = StringVariable.model_validate(mapping)
  85. case SegmentType.SECRET:
  86. result = SecretVariable.model_validate(mapping)
  87. case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
  88. mapping = dict(mapping)
  89. mapping["value_type"] = SegmentType.INTEGER
  90. result = IntegerVariable.model_validate(mapping)
  91. case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
  92. mapping = dict(mapping)
  93. mapping["value_type"] = SegmentType.FLOAT
  94. result = FloatVariable.model_validate(mapping)
  95. case SegmentType.NUMBER if not isinstance(value, float | int):
  96. raise VariableError(f"invalid number value {value}")
  97. case SegmentType.OBJECT if isinstance(value, dict):
  98. result = ObjectVariable.model_validate(mapping)
  99. case SegmentType.ARRAY_STRING if isinstance(value, list):
  100. result = ArrayStringVariable.model_validate(mapping)
  101. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  102. result = ArrayNumberVariable.model_validate(mapping)
  103. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  104. result = ArrayObjectVariable.model_validate(mapping)
  105. case _:
  106. raise VariableError(f"not supported value type {value_type}")
  107. if result.size > dify_config.MAX_VARIABLE_SIZE:
  108. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  109. if not result.selector:
  110. result = result.model_copy(update={"selector": selector})
  111. return cast(Variable, result)
  112. def infer_segment_type_from_value(value: Any, /) -> SegmentType:
  113. return build_segment(value).value_type
  114. def build_segment(value: Any, /) -> Segment:
  115. # NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
  116. # below
  117. if value is None:
  118. return NoneSegment()
  119. if isinstance(value, str):
  120. return StringSegment(value=value)
  121. if isinstance(value, int):
  122. return IntegerSegment(value=value)
  123. if isinstance(value, float):
  124. return FloatSegment(value=value)
  125. if isinstance(value, dict):
  126. return ObjectSegment(value=value)
  127. if isinstance(value, File):
  128. return FileSegment(value=value)
  129. if isinstance(value, list):
  130. items = [build_segment(item) for item in value]
  131. types = {item.value_type for item in items}
  132. if all(isinstance(item, ArraySegment) for item in items):
  133. return ArrayAnySegment(value=value)
  134. elif len(types) != 1:
  135. if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
  136. return ArrayNumberSegment(value=value)
  137. return ArrayAnySegment(value=value)
  138. match types.pop():
  139. case SegmentType.STRING:
  140. return ArrayStringSegment(value=value)
  141. case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
  142. return ArrayNumberSegment(value=value)
  143. case SegmentType.OBJECT:
  144. return ArrayObjectSegment(value=value)
  145. case SegmentType.FILE:
  146. return ArrayFileSegment(value=value)
  147. case SegmentType.NONE:
  148. return ArrayAnySegment(value=value)
  149. case _:
  150. # This should be unreachable.
  151. raise ValueError(f"not supported value {value}")
  152. raise ValueError(f"not supported value {value}")
  153. _segment_factory: Mapping[SegmentType, type[Segment]] = {
  154. SegmentType.NONE: NoneSegment,
  155. SegmentType.STRING: StringSegment,
  156. SegmentType.INTEGER: IntegerSegment,
  157. SegmentType.FLOAT: FloatSegment,
  158. SegmentType.FILE: FileSegment,
  159. SegmentType.OBJECT: ObjectSegment,
  160. # Array types
  161. SegmentType.ARRAY_ANY: ArrayAnySegment,
  162. SegmentType.ARRAY_STRING: ArrayStringSegment,
  163. SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
  164. SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
  165. SegmentType.ARRAY_FILE: ArrayFileSegment,
  166. }
  167. def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
  168. """
  169. Build a segment with explicit type checking.
  170. This function creates a segment from a value while enforcing type compatibility
  171. with the specified segment_type. It provides stricter type validation compared
  172. to the standard build_segment function.
  173. Args:
  174. segment_type: The expected SegmentType for the resulting segment
  175. value: The value to be converted into a segment
  176. Returns:
  177. Segment: A segment instance of the appropriate type
  178. Raises:
  179. TypeMismatchError: If the value type doesn't match the expected segment_type
  180. Special Cases:
  181. - For empty list [] values, if segment_type is array[*], returns the corresponding array type
  182. - Type validation is performed before segment creation
  183. Examples:
  184. >>> build_segment_with_type(SegmentType.STRING, "hello")
  185. StringSegment(value="hello")
  186. >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
  187. ArrayStringSegment(value=[])
  188. >>> build_segment_with_type(SegmentType.STRING, 123)
  189. # Raises TypeMismatchError
  190. """
  191. # Handle None values
  192. if value is None:
  193. if segment_type == SegmentType.NONE:
  194. return NoneSegment()
  195. else:
  196. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
  197. # Handle empty list special case for array types
  198. if isinstance(value, list) and len(value) == 0:
  199. if segment_type == SegmentType.ARRAY_ANY:
  200. return ArrayAnySegment(value=value)
  201. elif segment_type == SegmentType.ARRAY_STRING:
  202. return ArrayStringSegment(value=value)
  203. elif segment_type == SegmentType.ARRAY_NUMBER:
  204. return ArrayNumberSegment(value=value)
  205. elif segment_type == SegmentType.ARRAY_OBJECT:
  206. return ArrayObjectSegment(value=value)
  207. elif segment_type == SegmentType.ARRAY_FILE:
  208. return ArrayFileSegment(value=value)
  209. else:
  210. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
  211. inferred_type = SegmentType.infer_segment_type(value)
  212. # Type compatibility checking
  213. if inferred_type is None:
  214. raise TypeMismatchError(
  215. f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
  216. )
  217. if inferred_type == segment_type:
  218. segment_class = _segment_factory[segment_type]
  219. return segment_class(value_type=segment_type, value=value)
  220. elif segment_type == SegmentType.NUMBER and inferred_type in (
  221. SegmentType.INTEGER,
  222. SegmentType.FLOAT,
  223. ):
  224. segment_class = _segment_factory[inferred_type]
  225. return segment_class(value_type=inferred_type, value=value)
  226. else:
  227. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
  228. def segment_to_variable(
  229. *,
  230. segment: Segment,
  231. selector: Sequence[str],
  232. id: str | None = None,
  233. name: str | None = None,
  234. description: str = "",
  235. ) -> Variable:
  236. if isinstance(segment, Variable):
  237. return segment
  238. name = name or selector[-1]
  239. id = id or str(uuid4())
  240. segment_type = type(segment)
  241. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  242. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  243. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  244. return cast(
  245. Variable,
  246. variable_class(
  247. id=id,
  248. name=name,
  249. description=description,
  250. value=segment.value,
  251. selector=list(selector),
  252. ),
  253. )