You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

segments.py 6.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import json
  2. import sys
  3. from collections.abc import Mapping, Sequence
  4. from typing import Annotated, Any, TypeAlias
  5. from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
  6. from core.file import File
  7. from .types import SegmentType
  8. class Segment(BaseModel):
  9. """Segment is runtime type used during the execution of workflow.
  10. Note: this class is abstract, you should use subclasses of this class instead.
  11. """
  12. model_config = ConfigDict(frozen=True)
  13. value_type: SegmentType
  14. value: Any = None
  15. @field_validator("value_type")
  16. @classmethod
  17. def validate_value_type(cls, value):
  18. """
  19. This validator checks if the provided value is equal to the default value of the 'value_type' field.
  20. If the value is different, a ValueError is raised.
  21. """
  22. if value != cls.model_fields["value_type"].default:
  23. raise ValueError("Cannot modify 'value_type'")
  24. return value
  25. @property
  26. def text(self) -> str:
  27. return str(self.value)
  28. @property
  29. def log(self) -> str:
  30. return str(self.value)
  31. @property
  32. def markdown(self) -> str:
  33. return str(self.value)
  34. @property
  35. def size(self) -> int:
  36. """
  37. Return the size of the value in bytes.
  38. """
  39. return sys.getsizeof(self.value)
  40. def to_object(self):
  41. return self.value
  42. class NoneSegment(Segment):
  43. value_type: SegmentType = SegmentType.NONE
  44. value: None = None
  45. @property
  46. def text(self) -> str:
  47. return ""
  48. @property
  49. def log(self) -> str:
  50. return ""
  51. @property
  52. def markdown(self) -> str:
  53. return ""
  54. class StringSegment(Segment):
  55. value_type: SegmentType = SegmentType.STRING
  56. value: str
  57. class FloatSegment(Segment):
  58. value_type: SegmentType = SegmentType.FLOAT
  59. value: float
  60. # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
  61. # The following tests cannot pass.
  62. #
  63. # def test_float_segment_and_nan():
  64. # nan = float("nan")
  65. # assert nan != nan
  66. #
  67. # f1 = FloatSegment(value=float("nan"))
  68. # f2 = FloatSegment(value=float("nan"))
  69. # assert f1 != f2
  70. #
  71. # f3 = FloatSegment(value=nan)
  72. # f4 = FloatSegment(value=nan)
  73. # assert f3 != f4
  74. class IntegerSegment(Segment):
  75. value_type: SegmentType = SegmentType.INTEGER
  76. value: int
  77. class ObjectSegment(Segment):
  78. value_type: SegmentType = SegmentType.OBJECT
  79. value: Mapping[str, Any]
  80. @property
  81. def text(self) -> str:
  82. return json.dumps(self.model_dump()["value"], ensure_ascii=False)
  83. @property
  84. def log(self) -> str:
  85. return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
  86. @property
  87. def markdown(self) -> str:
  88. return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
  89. class ArraySegment(Segment):
  90. @property
  91. def text(self) -> str:
  92. # Return empty string for empty arrays instead of "[]"
  93. if not self.value:
  94. return ""
  95. return super().text
  96. @property
  97. def markdown(self) -> str:
  98. items = []
  99. for item in self.value:
  100. items.append(str(item))
  101. return "\n".join(items)
  102. class FileSegment(Segment):
  103. value_type: SegmentType = SegmentType.FILE
  104. value: File
  105. @property
  106. def markdown(self) -> str:
  107. return self.value.markdown
  108. @property
  109. def log(self) -> str:
  110. return ""
  111. @property
  112. def text(self) -> str:
  113. return ""
  114. class BooleanSegment(Segment):
  115. value_type: SegmentType = SegmentType.BOOLEAN
  116. value: bool
  117. class ArrayAnySegment(ArraySegment):
  118. value_type: SegmentType = SegmentType.ARRAY_ANY
  119. value: Sequence[Any]
  120. class ArrayStringSegment(ArraySegment):
  121. value_type: SegmentType = SegmentType.ARRAY_STRING
  122. value: Sequence[str]
  123. @property
  124. def text(self) -> str:
  125. # Return empty string for empty arrays instead of "[]"
  126. if not self.value:
  127. return ""
  128. return json.dumps(self.value, ensure_ascii=False)
  129. class ArrayNumberSegment(ArraySegment):
  130. value_type: SegmentType = SegmentType.ARRAY_NUMBER
  131. value: Sequence[float | int]
  132. class ArrayObjectSegment(ArraySegment):
  133. value_type: SegmentType = SegmentType.ARRAY_OBJECT
  134. value: Sequence[Mapping[str, Any]]
  135. class ArrayFileSegment(ArraySegment):
  136. value_type: SegmentType = SegmentType.ARRAY_FILE
  137. value: Sequence[File]
  138. @property
  139. def markdown(self) -> str:
  140. items = []
  141. for item in self.value:
  142. items.append(item.markdown)
  143. return "\n".join(items)
  144. @property
  145. def log(self) -> str:
  146. return ""
  147. @property
  148. def text(self) -> str:
  149. return ""
  150. class ArrayBooleanSegment(ArraySegment):
  151. value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
  152. value: Sequence[bool]
  153. def get_segment_discriminator(v: Any) -> SegmentType | None:
  154. if isinstance(v, Segment):
  155. return v.value_type
  156. elif isinstance(v, dict):
  157. value_type = v.get("value_type")
  158. if value_type is None:
  159. return None
  160. try:
  161. seg_type = SegmentType(value_type)
  162. except ValueError:
  163. return None
  164. return seg_type
  165. else:
  166. # return None if the discriminator value isn't found
  167. return None
# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic.
# Each variant is tagged with its `SegmentType`, and `get_segment_discriminator`
# picks the variant from either a `Segment` instance or a dict's "value_type".
# Use `Segment` for type hinting when serialization is not required.
#
# Note:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - Each variant's `Tag` must match that class's `value_type` default.
# - The union must include all non-abstract subclasses of `Segment`, except:
#   - `SegmentGroup`, which is not added to the variable pool.
#   - `Variable` and its subclasses, which are handled by `VariableUnion`.
SegmentUnion: TypeAlias = Annotated[
    (
        Annotated[NoneSegment, Tag(SegmentType.NONE)]
        | Annotated[StringSegment, Tag(SegmentType.STRING)]
        | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
        | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
        | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
        | Annotated[FileSegment, Tag(SegmentType.FILE)]
        | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
        | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
        | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
        | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
        | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
        | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
        | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
    ),
    Discriminator(get_segment_discriminator),
]