Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

variable_factory.py 8.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayFileSegment,
  10. ArrayNumberSegment,
  11. ArrayObjectSegment,
  12. ArraySegment,
  13. ArrayStringSegment,
  14. FileSegment,
  15. FloatSegment,
  16. IntegerSegment,
  17. NoneSegment,
  18. ObjectSegment,
  19. Segment,
  20. StringSegment,
  21. )
  22. from core.variables.types import SegmentType
  23. from core.variables.variables import (
  24. ArrayAnyVariable,
  25. ArrayFileVariable,
  26. ArrayNumberVariable,
  27. ArrayObjectVariable,
  28. ArrayStringVariable,
  29. FileVariable,
  30. FloatVariable,
  31. IntegerVariable,
  32. NoneVariable,
  33. ObjectVariable,
  34. SecretVariable,
  35. StringVariable,
  36. Variable,
  37. )
  38. from core.workflow.constants import (
  39. CONVERSATION_VARIABLE_NODE_ID,
  40. ENVIRONMENT_VARIABLE_NODE_ID,
  41. PIPELINE_VARIABLE_NODE_ID,
  42. )
  43. class InvalidSelectorError(ValueError):
  44. pass
  45. class UnsupportedSegmentTypeError(Exception):
  46. pass
  47. # Define the constant
  48. SEGMENT_TO_VARIABLE_MAP = {
  49. StringSegment: StringVariable,
  50. IntegerSegment: IntegerVariable,
  51. FloatSegment: FloatVariable,
  52. ObjectSegment: ObjectVariable,
  53. FileSegment: FileVariable,
  54. ArrayStringSegment: ArrayStringVariable,
  55. ArrayNumberSegment: ArrayNumberVariable,
  56. ArrayObjectSegment: ArrayObjectVariable,
  57. ArrayFileSegment: ArrayFileVariable,
  58. ArrayAnySegment: ArrayAnyVariable,
  59. NoneSegment: NoneVariable,
  60. }
  61. def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  62. if not mapping.get("name"):
  63. raise VariableError("missing name")
  64. return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
  65. def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  66. if not mapping.get("name"):
  67. raise VariableError("missing name")
  68. return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
  69. def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  70. if not mapping.get("variable"):
  71. raise VariableError("missing variable")
  72. return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]])
  73. def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
  74. """
  75. This factory function is used to create the environment variable or the conversation variable,
  76. not support the File type.
  77. """
  78. if (value_type := mapping.get("value_type")) is None:
  79. raise VariableError("missing value type")
  80. if (value := mapping.get("value")) is None:
  81. raise VariableError("missing value")
  82. # FIXME: using Any here, fix it later
  83. result: Any
  84. match value_type:
  85. case SegmentType.STRING:
  86. result = StringVariable.model_validate(mapping)
  87. case SegmentType.SECRET:
  88. result = SecretVariable.model_validate(mapping)
  89. case SegmentType.NUMBER if isinstance(value, int):
  90. result = IntegerVariable.model_validate(mapping)
  91. case SegmentType.NUMBER if isinstance(value, float):
  92. result = FloatVariable.model_validate(mapping)
  93. case SegmentType.NUMBER if not isinstance(value, float | int):
  94. raise VariableError(f"invalid number value {value}")
  95. case SegmentType.OBJECT if isinstance(value, dict):
  96. result = ObjectVariable.model_validate(mapping)
  97. case SegmentType.ARRAY_STRING if isinstance(value, list):
  98. result = ArrayStringVariable.model_validate(mapping)
  99. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  100. result = ArrayNumberVariable.model_validate(mapping)
  101. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  102. result = ArrayObjectVariable.model_validate(mapping)
  103. case _:
  104. raise VariableError(f"not supported value type {value_type}")
  105. if result.size > dify_config.MAX_VARIABLE_SIZE:
  106. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  107. if not result.selector:
  108. result = result.model_copy(update={"selector": selector})
  109. return cast(Variable, result)
  110. def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
  111. """
  112. This factory function is used to create the rag pipeline variable,
  113. not support the File type.
  114. """
  115. if (type := mapping.get("type")) is None:
  116. raise VariableError("missing type")
  117. if (value := mapping.get("value")) is None:
  118. raise VariableError("missing value")
  119. # FIXME: using Any here, fix it later
  120. result: Any
  121. match type:
  122. case SegmentType.STRING:
  123. result = StringVariable.model_validate(mapping)
  124. case SegmentType.SECRET:
  125. result = SecretVariable.model_validate(mapping)
  126. case SegmentType.NUMBER if isinstance(value, int):
  127. result = IntegerVariable.model_validate(mapping)
  128. case SegmentType.NUMBER if isinstance(value, float):
  129. result = FloatVariable.model_validate(mapping)
  130. case SegmentType.NUMBER if not isinstance(value, float | int):
  131. raise VariableError(f"invalid number value {value}")
  132. case SegmentType.OBJECT if isinstance(value, dict):
  133. result = ObjectVariable.model_validate(mapping)
  134. case SegmentType.ARRAY_STRING if isinstance(value, list):
  135. result = ArrayStringVariable.model_validate(mapping)
  136. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  137. result = ArrayNumberVariable.model_validate(mapping)
  138. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  139. result = ArrayObjectVariable.model_validate(mapping)
  140. case _:
  141. raise VariableError(f"not supported type {type}")
  142. if result.size > dify_config.MAX_VARIABLE_SIZE:
  143. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  144. if not result.selector:
  145. result = result.model_copy(update={"selector": selector})
  146. return cast(Variable, result)
  147. def build_segment(value: Any, /) -> Segment:
  148. if value is None:
  149. return NoneSegment()
  150. if isinstance(value, str):
  151. return StringSegment(value=value)
  152. if isinstance(value, int):
  153. return IntegerSegment(value=value)
  154. if isinstance(value, float):
  155. return FloatSegment(value=value)
  156. if isinstance(value, dict):
  157. return ObjectSegment(value=value)
  158. if isinstance(value, File):
  159. return FileSegment(value=value)
  160. if isinstance(value, list):
  161. items = [build_segment(item) for item in value]
  162. types = {item.value_type for item in items}
  163. if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
  164. return ArrayAnySegment(value=value)
  165. match types.pop():
  166. case SegmentType.STRING:
  167. return ArrayStringSegment(value=value)
  168. case SegmentType.NUMBER:
  169. return ArrayNumberSegment(value=value)
  170. case SegmentType.OBJECT:
  171. return ArrayObjectSegment(value=value)
  172. case SegmentType.FILE:
  173. return ArrayFileSegment(value=value)
  174. case SegmentType.NONE:
  175. return ArrayAnySegment(value=value)
  176. case _:
  177. raise ValueError(f"not supported value {value}")
  178. raise ValueError(f"not supported value {value}")
  179. def segment_to_variable(
  180. *,
  181. segment: Segment,
  182. selector: Sequence[str],
  183. id: str | None = None,
  184. name: str | None = None,
  185. description: str = "",
  186. ) -> Variable:
  187. if isinstance(segment, Variable):
  188. return segment
  189. name = name or selector[-1]
  190. id = id or str(uuid4())
  191. segment_type = type(segment)
  192. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  193. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  194. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  195. return cast(
  196. Variable,
  197. variable_class(
  198. id=id,
  199. name=name,
  200. description=description,
  201. value=segment.value,
  202. selector=selector,
  203. ),
  204. )