You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

variables.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from collections.abc import Sequence
  2. from typing import Any, cast
  3. from uuid import uuid4
  4. from pydantic import BaseModel, Field
  5. from typing import Annotated, TypeAlias, cast
  6. from uuid import uuid4
  7. from pydantic import Discriminator, Field, Tag
  8. from core.helper import encrypter
  9. from .segments import (
  10. ArrayAnySegment,
  11. ArrayFileSegment,
  12. ArrayNumberSegment,
  13. ArrayObjectSegment,
  14. ArraySegment,
  15. ArrayStringSegment,
  16. FileSegment,
  17. FloatSegment,
  18. IntegerSegment,
  19. NoneSegment,
  20. ObjectSegment,
  21. Segment,
  22. StringSegment,
  23. get_segment_discriminator,
  24. )
  25. from .types import SegmentType
  26. class Variable(Segment):
  27. """
  28. A variable is a segment that has a name.
  29. It is mainly used to store segments and their selector in VariablePool.
  30. Note: this class is abstract, you should use subclasses of this class instead.
  31. """
  32. id: str = Field(
  33. default_factory=lambda: str(uuid4()),
  34. description="Unique identity for variable.",
  35. )
  36. name: str
  37. description: str = Field(default="", description="Description of the variable.")
  38. selector: Sequence[str] = Field(default_factory=list)
  39. class StringVariable(StringSegment, Variable):
  40. pass
  41. class FloatVariable(FloatSegment, Variable):
  42. pass
  43. class IntegerVariable(IntegerSegment, Variable):
  44. pass
  45. class ObjectVariable(ObjectSegment, Variable):
  46. pass
  47. class ArrayVariable(ArraySegment, Variable):
  48. pass
  49. class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
  50. pass
  51. class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
  52. pass
  53. class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
  54. pass
  55. class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
  56. pass
  57. class SecretVariable(StringVariable):
  58. value_type: SegmentType = SegmentType.SECRET
  59. @property
  60. def log(self) -> str:
  61. return cast(str, encrypter.obfuscated_token(self.value))
  62. class NoneVariable(NoneSegment, Variable):
  63. value_type: SegmentType = SegmentType.NONE
  64. value: None = None
  65. class FileVariable(FileSegment, Variable):
  66. pass
  67. class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
  68. pass
  69. class RAGPipelineVariable(BaseModel):
  70. belong_to_node_id: str = Field(description="belong to which node id, shared means public")
  71. type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
  72. label: str = Field(description="label")
  73. description: str | None = Field(description="description", default="")
  74. variable: str = Field(description="variable key", default="")
  75. max_length: int | None = Field(
  76. description="max length, applicable to text-input, paragraph, and file-list", default=0
  77. )
  78. default_value: Any = Field(description="default value", default="")
  79. placeholder: str | None = Field(description="placeholder", default="")
  80. unit: str | None = Field(description="unit, applicable to Number", default="")
  81. tooltips: str | None = Field(description="helpful text", default="")
  82. allowed_file_types: list[str] | None = Field(
  83. description="image, document, audio, video, custom.", default_factory=list
  84. )
  85. allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list)
  86. allowed_file_upload_methods: list[str] | None = Field(
  87. description="remote_url, local_file, tool_file.", default_factory=list
  88. )
  89. required: bool = Field(description="optional, default false", default=False)
  90. options: list[str] | None = Field(default_factory=list)
  91. class RAGPipelineVariableInput(BaseModel):
  92. variable: RAGPipelineVariable
  93. value: Any
  94. # The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
  95. # Use `Variable` for type hinting when serialization is not required.
  96. #
  97. # Note:
  98. # - All variants in `VariableUnion` must inherit from the `Variable` class.
  99. # - The union must include all non-abstract subclasses of `Segment`, except:
  100. VariableUnion: TypeAlias = Annotated[
  101. (
  102. Annotated[NoneVariable, Tag(SegmentType.NONE)]
  103. | Annotated[StringVariable, Tag(SegmentType.STRING)]
  104. | Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
  105. | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
  106. | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
  107. | Annotated[FileVariable, Tag(SegmentType.FILE)]
  108. | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
  109. | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
  110. | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
  111. | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
  112. | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
  113. | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
  114. ),
  115. Discriminator(get_segment_discriminator),
  116. ]