您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import json
  2. import sys
  3. from collections.abc import Mapping, Sequence
  4. from typing import Any
  5. from pydantic import BaseModel, ConfigDict, field_validator
  6. from core.file import File
  7. from .types import SegmentType
  8. class Segment(BaseModel):
  9. model_config = ConfigDict(frozen=True)
  10. value_type: SegmentType
  11. value: Any
  12. @field_validator("value_type")
  13. @classmethod
  14. def validate_value_type(cls, value):
  15. """
  16. This validator checks if the provided value is equal to the default value of the 'value_type' field.
  17. If the value is different, a ValueError is raised.
  18. """
  19. if value != cls.model_fields["value_type"].default:
  20. raise ValueError("Cannot modify 'value_type'")
  21. return value
  22. @property
  23. def text(self) -> str:
  24. return str(self.value)
  25. @property
  26. def log(self) -> str:
  27. return str(self.value)
  28. @property
  29. def markdown(self) -> str:
  30. return str(self.value)
  31. @property
  32. def size(self) -> int:
  33. """
  34. Return the size of the value in bytes.
  35. """
  36. return sys.getsizeof(self.value)
  37. def to_object(self) -> Any:
  38. return self.value
  39. class NoneSegment(Segment):
  40. value_type: SegmentType = SegmentType.NONE
  41. value: None = None
  42. @property
  43. def text(self) -> str:
  44. return ""
  45. @property
  46. def log(self) -> str:
  47. return ""
  48. @property
  49. def markdown(self) -> str:
  50. return ""
class StringSegment(Segment):
    # Plain string value; text/log/markdown use the inherited `str(value)`.
    value_type: SegmentType = SegmentType.STRING
    value: str
class FloatSegment(Segment):
    # Floating-point number. Shares the NUMBER tag with IntegerSegment, so
    # `value_type` alone does not tell the two apart.
    value_type: SegmentType = SegmentType.NUMBER
    value: float
    # NOTE(QuantumGhost): equality of FloatSegment instances holding `NaN` does
    # not follow float semantics (`nan != nan`); the tests below do not pass.
    # NOTE(review): a likely cause is that model equality compares the stored
    # field values inside containers, and Python container comparison checks
    # identity (`x is y`) before calling `__eq__`. Two segments holding the
    # *same* NaN object would therefore compare equal. Unconfirmed; TODO verify
    # against the installed pydantic version.
    #
    # def test_float_segment_and_nan():
    #     nan = float("nan")
    #     assert nan != nan
    #
    #     f1 = FloatSegment(value=float("nan"))
    #     f2 = FloatSegment(value=float("nan"))
    #     assert f1 != f2
    #
    #     f3 = FloatSegment(value=nan)
    #     f4 = FloatSegment(value=nan)
    #     assert f3 != f4
class IntegerSegment(Segment):
    # Integer number. Uses the same NUMBER tag as FloatSegment, so `value_type`
    # alone does not distinguish the two.
    value_type: SegmentType = SegmentType.NUMBER
    value: int
  74. class ObjectSegment(Segment):
  75. value_type: SegmentType = SegmentType.OBJECT
  76. value: Mapping[str, Any]
  77. @property
  78. def text(self) -> str:
  79. return json.dumps(self.model_dump()["value"], ensure_ascii=False)
  80. @property
  81. def log(self) -> str:
  82. return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
  83. @property
  84. def markdown(self) -> str:
  85. return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
  86. class ArraySegment(Segment):
  87. @property
  88. def markdown(self) -> str:
  89. items = []
  90. for item in self.value:
  91. items.append(str(item))
  92. return "\n".join(items)
  93. class FileSegment(Segment):
  94. value_type: SegmentType = SegmentType.FILE
  95. value: File
  96. @property
  97. def markdown(self) -> str:
  98. return self.value.markdown
  99. @property
  100. def log(self) -> str:
  101. return ""
  102. @property
  103. def text(self) -> str:
  104. return ""
class ArrayAnySegment(ArraySegment):
    # Array whose elements may be of any type; renderings are inherited.
    value_type: SegmentType = SegmentType.ARRAY_ANY
    value: Sequence[Any]
  108. class ArrayStringSegment(ArraySegment):
  109. value_type: SegmentType = SegmentType.ARRAY_STRING
  110. value: Sequence[str]
  111. @property
  112. def text(self) -> str:
  113. return json.dumps(self.value, ensure_ascii=False)
class ArrayNumberSegment(ArraySegment):
    # Array of numbers; the annotation accepts both floats and ints as elements.
    value_type: SegmentType = SegmentType.ARRAY_NUMBER
    value: Sequence[float | int]
class ArrayObjectSegment(ArraySegment):
    # Array of string-keyed mappings. Unlike ObjectSegment, no JSON rendering is
    # defined here: `text`/`log` use the inherited `str(value)` and `markdown`
    # uses `str()` per element.
    value_type: SegmentType = SegmentType.ARRAY_OBJECT
    value: Sequence[Mapping[str, Any]]
  120. class ArrayFileSegment(ArraySegment):
  121. value_type: SegmentType = SegmentType.ARRAY_FILE
  122. value: Sequence[File]
  123. @property
  124. def markdown(self) -> str:
  125. items = []
  126. for item in self.value:
  127. items.append(item.markdown)
  128. return "\n".join(items)
  129. @property
  130. def log(self) -> str:
  131. return ""
  132. @property
  133. def text(self) -> str:
  134. return ""