You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

variable_factory.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayBooleanSegment,
  10. ArrayFileSegment,
  11. ArrayNumberSegment,
  12. ArrayObjectSegment,
  13. ArraySegment,
  14. ArrayStringSegment,
  15. BooleanSegment,
  16. FileSegment,
  17. FloatSegment,
  18. IntegerSegment,
  19. NoneSegment,
  20. ObjectSegment,
  21. Segment,
  22. StringSegment,
  23. )
  24. from core.variables.types import SegmentType
  25. from core.variables.variables import (
  26. ArrayAnyVariable,
  27. ArrayBooleanVariable,
  28. ArrayFileVariable,
  29. ArrayNumberVariable,
  30. ArrayObjectVariable,
  31. ArrayStringVariable,
  32. BooleanVariable,
  33. FileVariable,
  34. FloatVariable,
  35. IntegerVariable,
  36. NoneVariable,
  37. ObjectVariable,
  38. SecretVariable,
  39. StringVariable,
  40. Variable,
  41. )
  42. from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
  43. class UnsupportedSegmentTypeError(Exception):
  44. pass
  45. class TypeMismatchError(Exception):
  46. pass
  47. # Define the constant
  48. SEGMENT_TO_VARIABLE_MAP = {
  49. ArrayAnySegment: ArrayAnyVariable,
  50. ArrayBooleanSegment: ArrayBooleanVariable,
  51. ArrayFileSegment: ArrayFileVariable,
  52. ArrayNumberSegment: ArrayNumberVariable,
  53. ArrayObjectSegment: ArrayObjectVariable,
  54. ArrayStringSegment: ArrayStringVariable,
  55. BooleanSegment: BooleanVariable,
  56. FileSegment: FileVariable,
  57. FloatSegment: FloatVariable,
  58. IntegerSegment: IntegerVariable,
  59. NoneSegment: NoneVariable,
  60. ObjectSegment: ObjectVariable,
  61. StringSegment: StringVariable,
  62. }
  63. def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  64. if not mapping.get("name"):
  65. raise VariableError("missing name")
  66. return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
  67. def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  68. if not mapping.get("name"):
  69. raise VariableError("missing name")
  70. return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
  71. def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
  72. """
  73. This factory function is used to create the environment variable or the conversation variable,
  74. not support the File type.
  75. """
  76. if (value_type := mapping.get("value_type")) is None:
  77. raise VariableError("missing value type")
  78. if (value := mapping.get("value")) is None:
  79. raise VariableError("missing value")
  80. result: Variable
  81. match value_type:
  82. case SegmentType.STRING:
  83. result = StringVariable.model_validate(mapping)
  84. case SegmentType.SECRET:
  85. result = SecretVariable.model_validate(mapping)
  86. case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
  87. mapping = dict(mapping)
  88. mapping["value_type"] = SegmentType.INTEGER
  89. result = IntegerVariable.model_validate(mapping)
  90. case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
  91. mapping = dict(mapping)
  92. mapping["value_type"] = SegmentType.FLOAT
  93. result = FloatVariable.model_validate(mapping)
  94. case SegmentType.BOOLEAN:
  95. result = BooleanVariable.model_validate(mapping)
  96. case SegmentType.NUMBER if not isinstance(value, float | int):
  97. raise VariableError(f"invalid number value {value}")
  98. case SegmentType.OBJECT if isinstance(value, dict):
  99. result = ObjectVariable.model_validate(mapping)
  100. case SegmentType.ARRAY_STRING if isinstance(value, list):
  101. result = ArrayStringVariable.model_validate(mapping)
  102. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  103. result = ArrayNumberVariable.model_validate(mapping)
  104. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  105. result = ArrayObjectVariable.model_validate(mapping)
  106. case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
  107. result = ArrayBooleanVariable.model_validate(mapping)
  108. case _:
  109. raise VariableError(f"not supported value type {value_type}")
  110. if result.size > dify_config.MAX_VARIABLE_SIZE:
  111. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  112. if not result.selector:
  113. result = result.model_copy(update={"selector": selector})
  114. return cast(Variable, result)
  115. def infer_segment_type_from_value(value: Any, /) -> SegmentType:
  116. return build_segment(value).value_type
  117. def build_segment(value: Any, /) -> Segment:
  118. # NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
  119. # below
  120. if value is None:
  121. return NoneSegment()
  122. if isinstance(value, str):
  123. return StringSegment(value=value)
  124. if isinstance(value, bool):
  125. return BooleanSegment(value=value)
  126. if isinstance(value, int):
  127. return IntegerSegment(value=value)
  128. if isinstance(value, float):
  129. return FloatSegment(value=value)
  130. if isinstance(value, dict):
  131. return ObjectSegment(value=value)
  132. if isinstance(value, File):
  133. return FileSegment(value=value)
  134. if isinstance(value, list):
  135. items = [build_segment(item) for item in value]
  136. types = {item.value_type for item in items}
  137. if all(isinstance(item, ArraySegment) for item in items):
  138. return ArrayAnySegment(value=value)
  139. elif len(types) != 1:
  140. if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
  141. return ArrayNumberSegment(value=value)
  142. return ArrayAnySegment(value=value)
  143. match types.pop():
  144. case SegmentType.STRING:
  145. return ArrayStringSegment(value=value)
  146. case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
  147. return ArrayNumberSegment(value=value)
  148. case SegmentType.BOOLEAN:
  149. return ArrayBooleanSegment(value=value)
  150. case SegmentType.OBJECT:
  151. return ArrayObjectSegment(value=value)
  152. case SegmentType.FILE:
  153. return ArrayFileSegment(value=value)
  154. case SegmentType.NONE:
  155. return ArrayAnySegment(value=value)
  156. case _:
  157. # This should be unreachable.
  158. raise ValueError(f"not supported value {value}")
  159. raise ValueError(f"not supported value {value}")
  160. _segment_factory: Mapping[SegmentType, type[Segment]] = {
  161. SegmentType.NONE: NoneSegment,
  162. SegmentType.STRING: StringSegment,
  163. SegmentType.INTEGER: IntegerSegment,
  164. SegmentType.FLOAT: FloatSegment,
  165. SegmentType.FILE: FileSegment,
  166. SegmentType.BOOLEAN: BooleanSegment,
  167. SegmentType.OBJECT: ObjectSegment,
  168. # Array types
  169. SegmentType.ARRAY_ANY: ArrayAnySegment,
  170. SegmentType.ARRAY_STRING: ArrayStringSegment,
  171. SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
  172. SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
  173. SegmentType.ARRAY_FILE: ArrayFileSegment,
  174. SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
  175. }
  176. def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
  177. """
  178. Build a segment with explicit type checking.
  179. This function creates a segment from a value while enforcing type compatibility
  180. with the specified segment_type. It provides stricter type validation compared
  181. to the standard build_segment function.
  182. Args:
  183. segment_type: The expected SegmentType for the resulting segment
  184. value: The value to be converted into a segment
  185. Returns:
  186. Segment: A segment instance of the appropriate type
  187. Raises:
  188. TypeMismatchError: If the value type doesn't match the expected segment_type
  189. Special Cases:
  190. - For empty list [] values, if segment_type is array[*], returns the corresponding array type
  191. - Type validation is performed before segment creation
  192. Examples:
  193. >>> build_segment_with_type(SegmentType.STRING, "hello")
  194. StringSegment(value="hello")
  195. >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
  196. ArrayStringSegment(value=[])
  197. >>> build_segment_with_type(SegmentType.STRING, 123)
  198. # Raises TypeMismatchError
  199. """
  200. # Handle None values
  201. if value is None:
  202. if segment_type == SegmentType.NONE:
  203. return NoneSegment()
  204. else:
  205. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
  206. # Handle empty list special case for array types
  207. if isinstance(value, list) and len(value) == 0:
  208. if segment_type == SegmentType.ARRAY_ANY:
  209. return ArrayAnySegment(value=value)
  210. elif segment_type == SegmentType.ARRAY_STRING:
  211. return ArrayStringSegment(value=value)
  212. elif segment_type == SegmentType.ARRAY_BOOLEAN:
  213. return ArrayBooleanSegment(value=value)
  214. elif segment_type == SegmentType.ARRAY_NUMBER:
  215. return ArrayNumberSegment(value=value)
  216. elif segment_type == SegmentType.ARRAY_OBJECT:
  217. return ArrayObjectSegment(value=value)
  218. elif segment_type == SegmentType.ARRAY_FILE:
  219. return ArrayFileSegment(value=value)
  220. else:
  221. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
  222. inferred_type = SegmentType.infer_segment_type(value)
  223. # Type compatibility checking
  224. if inferred_type is None:
  225. raise TypeMismatchError(
  226. f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
  227. )
  228. if inferred_type == segment_type:
  229. segment_class = _segment_factory[segment_type]
  230. return segment_class(value_type=segment_type, value=value)
  231. elif segment_type == SegmentType.NUMBER and inferred_type in (
  232. SegmentType.INTEGER,
  233. SegmentType.FLOAT,
  234. ):
  235. segment_class = _segment_factory[inferred_type]
  236. return segment_class(value_type=inferred_type, value=value)
  237. else:
  238. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
  239. def segment_to_variable(
  240. *,
  241. segment: Segment,
  242. selector: Sequence[str],
  243. id: str | None = None,
  244. name: str | None = None,
  245. description: str = "",
  246. ) -> Variable:
  247. if isinstance(segment, Variable):
  248. return segment
  249. name = name or selector[-1]
  250. id = id or str(uuid4())
  251. segment_type = type(segment)
  252. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  253. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  254. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  255. return cast(
  256. Variable,
  257. variable_class(
  258. id=id,
  259. name=name,
  260. description=description,
  261. value=segment.value,
  262. selector=list(selector),
  263. ),
  264. )