Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

variable_factory.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayBooleanSegment,
  10. ArrayFileSegment,
  11. ArrayNumberSegment,
  12. ArrayObjectSegment,
  13. ArraySegment,
  14. ArrayStringSegment,
  15. BooleanSegment,
  16. FileSegment,
  17. FloatSegment,
  18. IntegerSegment,
  19. NoneSegment,
  20. ObjectSegment,
  21. Segment,
  22. StringSegment,
  23. )
  24. from core.variables.types import SegmentType
  25. from core.variables.variables import (
  26. ArrayAnyVariable,
  27. ArrayBooleanVariable,
  28. ArrayFileVariable,
  29. ArrayNumberVariable,
  30. ArrayObjectVariable,
  31. ArrayStringVariable,
  32. BooleanVariable,
  33. FileVariable,
  34. FloatVariable,
  35. IntegerVariable,
  36. NoneVariable,
  37. ObjectVariable,
  38. SecretVariable,
  39. StringVariable,
  40. Variable,
  41. )
  42. from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
  43. class UnsupportedSegmentTypeError(Exception):
  44. pass
  45. class TypeMismatchError(Exception):
  46. pass
# Maps each concrete Segment class to the Variable class that `segment_to_variable`
# instantiates for it. That function looks entries up by exact `type(segment)`,
# so subclasses of these segment classes are not matched.
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Variable]] = {
    ArrayAnySegment: ArrayAnyVariable,
    ArrayBooleanSegment: ArrayBooleanVariable,
    ArrayFileSegment: ArrayFileVariable,
    ArrayNumberSegment: ArrayNumberVariable,
    ArrayObjectSegment: ArrayObjectVariable,
    ArrayStringSegment: ArrayStringVariable,
    BooleanSegment: BooleanVariable,
    FileSegment: FileVariable,
    FloatSegment: FloatVariable,
    IntegerSegment: IntegerVariable,
    NoneSegment: NoneVariable,
    ObjectSegment: ObjectVariable,
    StringSegment: StringVariable,
}
  63. def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  64. if not mapping.get("name"):
  65. raise VariableError("missing name")
  66. return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
  67. def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  68. if not mapping.get("name"):
  69. raise VariableError("missing name")
  70. return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
    """
    Build an environment or conversation variable from its serialized mapping.

    The File types (FILE / ARRAY_FILE) are not supported by this factory, nor is
    ARRAY_ANY; those value types fall through to the "not supported" error.

    Args:
        mapping: serialized variable. Must contain "value_type" and a non-None
            "value"; the whole mapping is passed to the chosen Variable model's
            ``model_validate``.
        selector: selector assigned to the result when the validated variable
            has an empty/falsy ``selector``.

    Returns:
        The validated Variable instance.

    Raises:
        VariableError: if "value_type" or "value" is missing/None, if a NUMBER
            value is neither int nor float, if the value_type/value combination
            matches no arm below, or if the variable's ``size`` exceeds
            ``dify_config.MAX_VARIABLE_SIZE``. Errors raised by
            ``model_validate`` itself propagate unchanged.
    """
    if (value_type := mapping.get("value_type")) is None:
        raise VariableError("missing value type")
    if (value := mapping.get("value")) is None:
        raise VariableError("missing value")
    result: Variable
    # Arm order matters: NUMBER appears in three guarded arms below.
    match value_type:
        case SegmentType.STRING:
            result = StringVariable.model_validate(mapping)
        case SegmentType.SECRET:
            result = SecretVariable.model_validate(mapping)
        # NUMBER is narrowed by the runtime type of `value`. Note isinstance(True, int)
        # is True, so bool values declared as NUMBER/INTEGER also land in this arm.
        case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
            # Copy so rewriting "value_type" does not mutate the caller's mapping.
            mapping = dict(mapping)
            mapping["value_type"] = SegmentType.INTEGER
            result = IntegerVariable.model_validate(mapping)
        # NOTE(review): a FLOAT value_type with an int value (e.g. 3) matches neither
        # this arm nor the one above and ends in "not supported value type" — confirm
        # that is intended.
        case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
            mapping = dict(mapping)
            mapping["value_type"] = SegmentType.FLOAT
            result = FloatVariable.model_validate(mapping)
        case SegmentType.BOOLEAN:
            result = BooleanVariable.model_validate(mapping)
        # int and float NUMBER values were consumed by the arms above, so any
        # NUMBER reaching this point carries a non-numeric value.
        case SegmentType.NUMBER if not isinstance(value, float | int):
            raise VariableError(f"invalid number value {value}")
        # Container types require the matching Python container; a mismatch falls
        # through to the catch-all error.
        case SegmentType.OBJECT if isinstance(value, dict):
            result = ObjectVariable.model_validate(mapping)
        case SegmentType.ARRAY_STRING if isinstance(value, list):
            result = ArrayStringVariable.model_validate(mapping)
        case SegmentType.ARRAY_NUMBER if isinstance(value, list):
            result = ArrayNumberVariable.model_validate(mapping)
        case SegmentType.ARRAY_OBJECT if isinstance(value, list):
            result = ArrayObjectVariable.model_validate(mapping)
        case SegmentType.ARRAY_BOOLEAN if isinstance(value, list):
            result = ArrayBooleanVariable.model_validate(mapping)
        case _:
            raise VariableError(f"not supported value type {value_type}")
    # Enforce the configured size limit on the validated variable.
    if result.size > dify_config.MAX_VARIABLE_SIZE:
        raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
    # A selector already present in the mapping wins; otherwise use the caller's.
    if not result.selector:
        result = result.model_copy(update={"selector": selector})
    return cast(Variable, result)
  115. def build_segment(value: Any, /) -> Segment:
  116. # NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
  117. # below
  118. if value is None:
  119. return NoneSegment()
  120. if isinstance(value, str):
  121. return StringSegment(value=value)
  122. if isinstance(value, bool):
  123. return BooleanSegment(value=value)
  124. if isinstance(value, int):
  125. return IntegerSegment(value=value)
  126. if isinstance(value, float):
  127. return FloatSegment(value=value)
  128. if isinstance(value, dict):
  129. return ObjectSegment(value=value)
  130. if isinstance(value, File):
  131. return FileSegment(value=value)
  132. if isinstance(value, list):
  133. items = [build_segment(item) for item in value]
  134. types = {item.value_type for item in items}
  135. if all(isinstance(item, ArraySegment) for item in items):
  136. return ArrayAnySegment(value=value)
  137. elif len(types) != 1:
  138. if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
  139. return ArrayNumberSegment(value=value)
  140. return ArrayAnySegment(value=value)
  141. match types.pop():
  142. case SegmentType.STRING:
  143. return ArrayStringSegment(value=value)
  144. case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
  145. return ArrayNumberSegment(value=value)
  146. case SegmentType.BOOLEAN:
  147. return ArrayBooleanSegment(value=value)
  148. case SegmentType.OBJECT:
  149. return ArrayObjectSegment(value=value)
  150. case SegmentType.FILE:
  151. return ArrayFileSegment(value=value)
  152. case SegmentType.NONE:
  153. return ArrayAnySegment(value=value)
  154. case _:
  155. # This should be unreachable.
  156. raise ValueError(f"not supported value {value}")
  157. raise ValueError(f"not supported value {value}")
# Maps a concrete SegmentType to the Segment class that `build_segment_with_type`
# instantiates once the value's inferred type has been accepted. There are no
# entries for SegmentType.NUMBER or SegmentType.SECRET, so looking up either key
# raises KeyError.
_segment_factory: Mapping[SegmentType, type[Segment]] = {
    SegmentType.NONE: NoneSegment,
    SegmentType.STRING: StringSegment,
    SegmentType.INTEGER: IntegerSegment,
    SegmentType.FLOAT: FloatSegment,
    SegmentType.FILE: FileSegment,
    SegmentType.BOOLEAN: BooleanSegment,
    SegmentType.OBJECT: ObjectSegment,
    # Array types
    SegmentType.ARRAY_ANY: ArrayAnySegment,
    SegmentType.ARRAY_STRING: ArrayStringSegment,
    SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
    SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
    SegmentType.ARRAY_FILE: ArrayFileSegment,
    SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
}
  174. def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
  175. """
  176. Build a segment with explicit type checking.
  177. This function creates a segment from a value while enforcing type compatibility
  178. with the specified segment_type. It provides stricter type validation compared
  179. to the standard build_segment function.
  180. Args:
  181. segment_type: The expected SegmentType for the resulting segment
  182. value: The value to be converted into a segment
  183. Returns:
  184. Segment: A segment instance of the appropriate type
  185. Raises:
  186. TypeMismatchError: If the value type doesn't match the expected segment_type
  187. Special Cases:
  188. - For empty list [] values, if segment_type is array[*], returns the corresponding array type
  189. - Type validation is performed before segment creation
  190. Examples:
  191. >>> build_segment_with_type(SegmentType.STRING, "hello")
  192. StringSegment(value="hello")
  193. >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
  194. ArrayStringSegment(value=[])
  195. >>> build_segment_with_type(SegmentType.STRING, 123)
  196. # Raises TypeMismatchError
  197. """
  198. # Handle None values
  199. if value is None:
  200. if segment_type == SegmentType.NONE:
  201. return NoneSegment()
  202. else:
  203. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
  204. # Handle empty list special case for array types
  205. if isinstance(value, list) and len(value) == 0:
  206. if segment_type == SegmentType.ARRAY_ANY:
  207. return ArrayAnySegment(value=value)
  208. elif segment_type == SegmentType.ARRAY_STRING:
  209. return ArrayStringSegment(value=value)
  210. elif segment_type == SegmentType.ARRAY_BOOLEAN:
  211. return ArrayBooleanSegment(value=value)
  212. elif segment_type == SegmentType.ARRAY_NUMBER:
  213. return ArrayNumberSegment(value=value)
  214. elif segment_type == SegmentType.ARRAY_OBJECT:
  215. return ArrayObjectSegment(value=value)
  216. elif segment_type == SegmentType.ARRAY_FILE:
  217. return ArrayFileSegment(value=value)
  218. else:
  219. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
  220. inferred_type = SegmentType.infer_segment_type(value)
  221. # Type compatibility checking
  222. if inferred_type is None:
  223. raise TypeMismatchError(
  224. f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
  225. )
  226. if inferred_type == segment_type:
  227. segment_class = _segment_factory[segment_type]
  228. return segment_class(value_type=segment_type, value=value)
  229. elif segment_type == SegmentType.NUMBER and inferred_type in (
  230. SegmentType.INTEGER,
  231. SegmentType.FLOAT,
  232. ):
  233. segment_class = _segment_factory[inferred_type]
  234. return segment_class(value_type=inferred_type, value=value)
  235. else:
  236. raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
  237. def segment_to_variable(
  238. *,
  239. segment: Segment,
  240. selector: Sequence[str],
  241. id: str | None = None,
  242. name: str | None = None,
  243. description: str = "",
  244. ) -> Variable:
  245. if isinstance(segment, Variable):
  246. return segment
  247. name = name or selector[-1]
  248. id = id or str(uuid4())
  249. segment_type = type(segment)
  250. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  251. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  252. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  253. return cast(
  254. Variable,
  255. variable_class(
  256. id=id,
  257. name=name,
  258. description=description,
  259. value=segment.value,
  260. selector=list(selector),
  261. ),
  262. )