Du kannst nicht mehr als 25 Themen auswählen. Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

variable_factory.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayBooleanSegment,
  10. ArrayFileSegment,
  11. ArrayNumberSegment,
  12. ArrayObjectSegment,
  13. ArraySegment,
  14. ArrayStringSegment,
  15. BooleanSegment,
  16. FileSegment,
  17. FloatSegment,
  18. IntegerSegment,
  19. NoneSegment,
  20. ObjectSegment,
  21. Segment,
  22. StringSegment,
  23. )
  24. from core.variables.types import SegmentType
  25. from core.variables.variables import (
  26. ArrayAnyVariable,
  27. ArrayBooleanVariable,
  28. ArrayFileVariable,
  29. ArrayNumberVariable,
  30. ArrayObjectVariable,
  31. ArrayStringVariable,
  32. BooleanVariable,
  33. FileVariable,
  34. FloatVariable,
  35. IntegerVariable,
  36. NoneVariable,
  37. ObjectVariable,
  38. SecretVariable,
  39. StringVariable,
  40. Variable,
  41. )
  42. from core.workflow.constants import (
  43. CONVERSATION_VARIABLE_NODE_ID,
  44. ENVIRONMENT_VARIABLE_NODE_ID,
  45. )
  46. class UnsupportedSegmentTypeError(Exception):
  47. pass
  48. class TypeMismatchError(Exception):
  49. pass
# Lookup table from a concrete segment class to the variable class that
# `segment_to_variable` instantiates for it. The lookup there is keyed on the
# exact `type(segment)`, so a segment class that is not listed here is
# rejected with `UnsupportedSegmentTypeError`.
SEGMENT_TO_VARIABLE_MAP = {
    ArrayAnySegment: ArrayAnyVariable,
    ArrayBooleanSegment: ArrayBooleanVariable,
    ArrayFileSegment: ArrayFileVariable,
    ArrayNumberSegment: ArrayNumberVariable,
    ArrayObjectSegment: ArrayObjectVariable,
    ArrayStringSegment: ArrayStringVariable,
    BooleanSegment: BooleanVariable,
    FileSegment: FileVariable,
    FloatSegment: FloatVariable,
    IntegerSegment: IntegerVariable,
    NoneSegment: NoneVariable,
    ObjectSegment: ObjectVariable,
    StringSegment: StringVariable,
}
  66. def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  67. if not mapping.get("name"):
  68. raise VariableError("missing name")
  69. return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
  70. def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  71. if not mapping.get("name"):
  72. raise VariableError("missing name")
  73. return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
  74. def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  75. if not mapping.get("variable"):
  76. raise VariableError("missing variable")
  77. return mapping["variable"]
  78. def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
  79. """
  80. This factory function is used to create the environment variable or the conversation variable,
  81. not support the File type.
  82. """
  83. if (value_type := mapping.get("value_type")) is None:
  84. raise VariableError("missing value type")
  85. if (value := mapping.get("value")) is None:
  86. raise VariableError("missing value")
  87. result: Variable
  88. match value_type:
  89. case SegmentType.STRING:
  90. result = StringVariable.model_validate(mapping)
  91. case SegmentType.SECRET:
  92. result = SecretVariable.model_validate(mapping)
  93. case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
  94. mapping = dict(mapping)
  95. mapping["value_type"] = SegmentType.INTEGER
  96. result = IntegerVariable.model_validate(mapping)
  97. case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
  98. mapping = dict(mapping)
  99. mapping["value_type"] = SegmentType.FLOAT
  100. result = FloatVariable.model_validate(mapping)
  101. case SegmentType.BOOLEAN:
  102. result = BooleanVariable.model_validate(mapping)
  103. case SegmentType.NUMBER if not isinstance(value, float | int):
  104. raise VariableError(f"invalid number value {value}")
  105. case SegmentType.OBJECT if isinstance(value, dict):
  106. result = ObjectVariable.model_validate(mapping)
  107. case SegmentType.ARRAY_STRING if isinstance(value, list):
  108. result = ArrayStringVariable.model_validate(mapping)
  109. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  110. result = ArrayNumberVariable.model_validate(mapping)
  111. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  112. result = ArrayObjectVariable.model_validate(mapping)
  113. case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
  114. result = ArrayBooleanVariable.model_validate(mapping)
  115. case _:
  116. raise VariableError(f"not supported value type {value_type}")
  117. if result.size > dify_config.MAX_VARIABLE_SIZE:
  118. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  119. if not result.selector:
  120. result = result.model_copy(update={"selector": selector})
  121. return cast(Variable, result)
  122. def build_segment(value: Any, /) -> Segment:
  123. # NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
  124. # below
  125. if value is None:
  126. return NoneSegment()
  127. if isinstance(value, Segment):
  128. return value
  129. if isinstance(value, str):
  130. return StringSegment(value=value)
  131. if isinstance(value, bool):
  132. return BooleanSegment(value=value)
  133. if isinstance(value, int):
  134. return IntegerSegment(value=value)
  135. if isinstance(value, float):
  136. return FloatSegment(value=value)
  137. if isinstance(value, dict):
  138. return ObjectSegment(value=value)
  139. if isinstance(value, File):
  140. return FileSegment(value=value)
  141. if isinstance(value, list):
  142. items = [build_segment(item) for item in value]
  143. types = {item.value_type for item in items}
  144. if all(isinstance(item, ArraySegment) for item in items):
  145. return ArrayAnySegment(value=value)
  146. elif len(types) != 1:
  147. if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
  148. return ArrayNumberSegment(value=value)
  149. return ArrayAnySegment(value=value)
  150. match types.pop():
  151. case SegmentType.STRING:
  152. return ArrayStringSegment(value=value)
  153. case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
  154. return ArrayNumberSegment(value=value)
  155. case SegmentType.BOOLEAN:
  156. return ArrayBooleanSegment(value=value)
  157. case SegmentType.OBJECT:
  158. return ArrayObjectSegment(value=value)
  159. case SegmentType.FILE:
  160. return ArrayFileSegment(value=value)
  161. case SegmentType.NONE:
  162. return ArrayAnySegment(value=value)
  163. case _:
  164. # This should be unreachable.
  165. raise ValueError(f"not supported value {value}")
  166. raise ValueError(f"not supported value {value}")
# Segment class to instantiate for each SegmentType. `build_segment_with_type`
# looks the class up here after its type check has passed.
# SegmentType.NUMBER and SegmentType.SECRET have no entry in this table.
_segment_factory: Mapping[SegmentType, type[Segment]] = {
    SegmentType.NONE: NoneSegment,
    SegmentType.STRING: StringSegment,
    SegmentType.INTEGER: IntegerSegment,
    SegmentType.FLOAT: FloatSegment,
    SegmentType.FILE: FileSegment,
    SegmentType.BOOLEAN: BooleanSegment,
    SegmentType.OBJECT: ObjectSegment,
    # Array types
    SegmentType.ARRAY_ANY: ArrayAnySegment,
    SegmentType.ARRAY_STRING: ArrayStringSegment,
    SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
    SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
    SegmentType.ARRAY_FILE: ArrayFileSegment,
    SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
}
  183. def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
  184. """
  185. Build a segment with explicit type checking.
  186. This function creates a segment from a value while enforcing type compatibility
  187. with the specified segment_type. It provides stricter type validation compared
  188. to the standard build_segment function.
  189. Args:
  190. segment_type: The expected SegmentType for the resulting segment
  191. value: The value to be converted into a segment
  192. Returns:
  193. Segment: A segment instance of the appropriate type
  194. Raises:
  195. TypeMismatchError: If the value type doesn't match the expected segment_type
  196. Special Cases:
  197. - For empty list [] values, if segment_type is array[*], returns the corresponding array type
  198. - Type validation is performed before segment creation
  199. Examples:
  200. >>> build_segment_with_type(SegmentType.STRING, "hello")
  201. StringSegment(value="hello")
  202. >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
  203. ArrayStringSegment(value=[])
  204. >>> build_segment_with_type(SegmentType.STRING, 123)
  205. # Raises TypeMismatchError
  206. """
  207. # Handle None values
  208. if value is None:
  209. if segment_type == SegmentType.NONE:
  210. return NoneSegment()
  211. else:
  212. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
  213. # Handle empty list special case for array types
  214. if isinstance(value, list) and len(value) == 0:
  215. if segment_type == SegmentType.ARRAY_ANY:
  216. return ArrayAnySegment(value=value)
  217. elif segment_type == SegmentType.ARRAY_STRING:
  218. return ArrayStringSegment(value=value)
  219. elif segment_type == SegmentType.ARRAY_BOOLEAN:
  220. return ArrayBooleanSegment(value=value)
  221. elif segment_type == SegmentType.ARRAY_NUMBER:
  222. return ArrayNumberSegment(value=value)
  223. elif segment_type == SegmentType.ARRAY_OBJECT:
  224. return ArrayObjectSegment(value=value)
  225. elif segment_type == SegmentType.ARRAY_FILE:
  226. return ArrayFileSegment(value=value)
  227. else:
  228. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
  229. inferred_type = SegmentType.infer_segment_type(value)
  230. # Type compatibility checking
  231. if inferred_type is None:
  232. raise TypeMismatchError(
  233. f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
  234. )
  235. if inferred_type == segment_type:
  236. segment_class = _segment_factory[segment_type]
  237. return segment_class(value_type=segment_type, value=value)
  238. elif segment_type == SegmentType.NUMBER and inferred_type in (
  239. SegmentType.INTEGER,
  240. SegmentType.FLOAT,
  241. ):
  242. segment_class = _segment_factory[inferred_type]
  243. return segment_class(value_type=inferred_type, value=value)
  244. else:
  245. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
  246. def segment_to_variable(
  247. *,
  248. segment: Segment,
  249. selector: Sequence[str],
  250. id: str | None = None,
  251. name: str | None = None,
  252. description: str = "",
  253. ) -> Variable:
  254. if isinstance(segment, Variable):
  255. return segment
  256. name = name or selector[-1]
  257. id = id or str(uuid4())
  258. segment_type = type(segment)
  259. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  260. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  261. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  262. return cast(
  263. Variable,
  264. variable_class(
  265. id=id,
  266. name=name,
  267. description=description,
  268. value=segment.value,
  269. selector=list(selector),
  270. ),
  271. )