Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

variable_factory.py 8.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayFileSegment,
  10. ArrayNumberSegment,
  11. ArrayObjectSegment,
  12. ArraySegment,
  13. ArrayStringSegment,
  14. FileSegment,
  15. FloatSegment,
  16. IntegerSegment,
  17. NoneSegment,
  18. ObjectSegment,
  19. Segment,
  20. StringSegment,
  21. )
  22. from core.variables.types import SegmentType
  23. from core.variables.variables import (
  24. ArrayAnyVariable,
  25. ArrayFileVariable,
  26. ArrayNumberVariable,
  27. ArrayObjectVariable,
  28. ArrayStringVariable,
  29. FileVariable,
  30. FloatVariable,
  31. IntegerVariable,
  32. NoneVariable,
  33. ObjectVariable,
  34. SecretVariable,
  35. StringVariable,
  36. Variable,
  37. )
  38. from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
  39. class UnsupportedSegmentTypeError(Exception):
  40. pass
  41. class TypeMismatchError(Exception):
  42. pass
  43. # Define the constant
  44. SEGMENT_TO_VARIABLE_MAP = {
  45. StringSegment: StringVariable,
  46. IntegerSegment: IntegerVariable,
  47. FloatSegment: FloatVariable,
  48. ObjectSegment: ObjectVariable,
  49. FileSegment: FileVariable,
  50. ArrayStringSegment: ArrayStringVariable,
  51. ArrayNumberSegment: ArrayNumberVariable,
  52. ArrayObjectSegment: ArrayObjectVariable,
  53. ArrayFileSegment: ArrayFileVariable,
  54. ArrayAnySegment: ArrayAnyVariable,
  55. NoneSegment: NoneVariable,
  56. }
  57. def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  58. if not mapping.get("name"):
  59. raise VariableError("missing name")
  60. return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
  61. def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  62. if not mapping.get("name"):
  63. raise VariableError("missing name")
  64. return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
  65. def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
  66. """
  67. This factory function is used to create the environment variable or the conversation variable,
  68. not support the File type.
  69. """
  70. if (value_type := mapping.get("value_type")) is None:
  71. raise VariableError("missing value type")
  72. if (value := mapping.get("value")) is None:
  73. raise VariableError("missing value")
  74. result: Variable
  75. match value_type:
  76. case SegmentType.STRING:
  77. result = StringVariable.model_validate(mapping)
  78. case SegmentType.SECRET:
  79. result = SecretVariable.model_validate(mapping)
  80. case SegmentType.NUMBER if isinstance(value, int):
  81. result = IntegerVariable.model_validate(mapping)
  82. case SegmentType.NUMBER if isinstance(value, float):
  83. result = FloatVariable.model_validate(mapping)
  84. case SegmentType.NUMBER if not isinstance(value, float | int):
  85. raise VariableError(f"invalid number value {value}")
  86. case SegmentType.OBJECT if isinstance(value, dict):
  87. result = ObjectVariable.model_validate(mapping)
  88. case SegmentType.ARRAY_STRING if isinstance(value, list):
  89. result = ArrayStringVariable.model_validate(mapping)
  90. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  91. result = ArrayNumberVariable.model_validate(mapping)
  92. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  93. result = ArrayObjectVariable.model_validate(mapping)
  94. case _:
  95. raise VariableError(f"not supported value type {value_type}")
  96. if result.size > dify_config.MAX_VARIABLE_SIZE:
  97. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  98. if not result.selector:
  99. result = result.model_copy(update={"selector": selector})
  100. return cast(Variable, result)
  101. def infer_segment_type_from_value(value: Any, /) -> SegmentType:
  102. return build_segment(value).value_type
  103. def build_segment(value: Any, /) -> Segment:
  104. if value is None:
  105. return NoneSegment()
  106. if isinstance(value, str):
  107. return StringSegment(value=value)
  108. if isinstance(value, int):
  109. return IntegerSegment(value=value)
  110. if isinstance(value, float):
  111. return FloatSegment(value=value)
  112. if isinstance(value, dict):
  113. return ObjectSegment(value=value)
  114. if isinstance(value, File):
  115. return FileSegment(value=value)
  116. if isinstance(value, list):
  117. items = [build_segment(item) for item in value]
  118. types = {item.value_type for item in items}
  119. if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
  120. return ArrayAnySegment(value=value)
  121. match types.pop():
  122. case SegmentType.STRING:
  123. return ArrayStringSegment(value=value)
  124. case SegmentType.NUMBER:
  125. return ArrayNumberSegment(value=value)
  126. case SegmentType.OBJECT:
  127. return ArrayObjectSegment(value=value)
  128. case SegmentType.FILE:
  129. return ArrayFileSegment(value=value)
  130. case SegmentType.NONE:
  131. return ArrayAnySegment(value=value)
  132. case _:
  133. # This should be unreachable.
  134. raise ValueError(f"not supported value {value}")
  135. raise ValueError(f"not supported value {value}")
  136. def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
  137. """
  138. Build a segment with explicit type checking.
  139. This function creates a segment from a value while enforcing type compatibility
  140. with the specified segment_type. It provides stricter type validation compared
  141. to the standard build_segment function.
  142. Args:
  143. segment_type: The expected SegmentType for the resulting segment
  144. value: The value to be converted into a segment
  145. Returns:
  146. Segment: A segment instance of the appropriate type
  147. Raises:
  148. TypeMismatchError: If the value type doesn't match the expected segment_type
  149. Special Cases:
  150. - For empty list [] values, if segment_type is array[*], returns the corresponding array type
  151. - Type validation is performed before segment creation
  152. Examples:
  153. >>> build_segment_with_type(SegmentType.STRING, "hello")
  154. StringSegment(value="hello")
  155. >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
  156. ArrayStringSegment(value=[])
  157. >>> build_segment_with_type(SegmentType.STRING, 123)
  158. # Raises TypeMismatchError
  159. """
  160. # Handle None values
  161. if value is None:
  162. if segment_type == SegmentType.NONE:
  163. return NoneSegment()
  164. else:
  165. raise TypeMismatchError(f"Expected {segment_type}, but got None")
  166. # Handle empty list special case for array types
  167. if isinstance(value, list) and len(value) == 0:
  168. if segment_type == SegmentType.ARRAY_ANY:
  169. return ArrayAnySegment(value=value)
  170. elif segment_type == SegmentType.ARRAY_STRING:
  171. return ArrayStringSegment(value=value)
  172. elif segment_type == SegmentType.ARRAY_NUMBER:
  173. return ArrayNumberSegment(value=value)
  174. elif segment_type == SegmentType.ARRAY_OBJECT:
  175. return ArrayObjectSegment(value=value)
  176. elif segment_type == SegmentType.ARRAY_FILE:
  177. return ArrayFileSegment(value=value)
  178. else:
  179. raise TypeMismatchError(f"Expected {segment_type}, but got empty list")
  180. # Build segment using existing logic to infer actual type
  181. inferred_segment = build_segment(value)
  182. inferred_type = inferred_segment.value_type
  183. # Type compatibility checking
  184. if inferred_type == segment_type:
  185. return inferred_segment
  186. # Type mismatch - raise error with descriptive message
  187. raise TypeMismatchError(
  188. f"Type mismatch: expected {segment_type}, but value '{value}' "
  189. f"(type: {type(value).__name__}) corresponds to {inferred_type}"
  190. )
  191. def segment_to_variable(
  192. *,
  193. segment: Segment,
  194. selector: Sequence[str],
  195. id: str | None = None,
  196. name: str | None = None,
  197. description: str = "",
  198. ) -> Variable:
  199. if isinstance(segment, Variable):
  200. return segment
  201. name = name or selector[-1]
  202. id = id or str(uuid4())
  203. segment_type = type(segment)
  204. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  205. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  206. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  207. return cast(
  208. Variable,
  209. variable_class(
  210. id=id,
  211. name=name,
  212. description=description,
  213. value=segment.value,
  214. selector=selector,
  215. ),
  216. )