| @@ -17,6 +17,7 @@ from .segments import ( | |||
| from .types import SegmentType | |||
| from .variables import ( | |||
| ArrayAnyVariable, | |||
| ArrayFileVariable, | |||
| ArrayNumberVariable, | |||
| ArrayObjectVariable, | |||
| ArrayStringVariable, | |||
| @@ -58,4 +59,5 @@ __all__ = [ | |||
| "ArrayStringSegment", | |||
| "FileSegment", | |||
| "FileVariable", | |||
| "ArrayFileVariable", | |||
| ] | |||
| @@ -1,9 +1,13 @@ | |||
| from collections.abc import Sequence | |||
| from uuid import uuid4 | |||
| from pydantic import Field | |||
| from core.helper import encrypter | |||
| from .segments import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| ArrayNumberSegment, | |||
| ArrayObjectSegment, | |||
| ArrayStringSegment, | |||
| @@ -24,11 +28,12 @@ class Variable(Segment): | |||
| """ | |||
| id: str = Field( | |||
| default="", | |||
| description="Unique identity for variable. It's only used by environment variables now.", | |||
| default=lambda _: str(uuid4()), | |||
| description="Unique identity for variable.", | |||
| ) | |||
| name: str | |||
| description: str = Field(default="", description="Description of the variable.") | |||
| selector: Sequence[str] = Field(default_factory=list) | |||
| class StringVariable(StringSegment, Variable): | |||
| @@ -78,3 +83,7 @@ class NoneVariable(NoneSegment, Variable): | |||
| class FileVariable(FileSegment, Variable): | |||
| pass | |||
| class ArrayFileVariable(ArrayFileSegment, Variable): | |||
| pass | |||
| @@ -95,13 +95,16 @@ class VariablePool(BaseModel): | |||
| if len(selector) < 2: | |||
| raise ValueError("Invalid selector") | |||
| if isinstance(value, Variable): | |||
| variable = value | |||
| if isinstance(value, Segment): | |||
| v = value | |||
| variable = variable_factory.segment_to_variable(segment=value, selector=selector) | |||
| else: | |||
| v = variable_factory.build_segment(value) | |||
| segment = variable_factory.build_segment(value) | |||
| variable = variable_factory.segment_to_variable(segment=segment, selector=selector) | |||
| hash_key = hash(tuple(selector[1:])) | |||
| self.variable_dictionary[selector[0]][hash_key] = v | |||
| self.variable_dictionary[selector[0]][hash_key] = variable | |||
| def get(self, selector: Sequence[str], /) -> Segment | None: | |||
| """ | |||
| @@ -1,34 +1,65 @@ | |||
| from collections.abc import Mapping | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any | |||
| from uuid import uuid4 | |||
| from configs import dify_config | |||
| from core.file import File | |||
| from core.variables import ( | |||
| from core.variables.exc import VariableError | |||
| from core.variables.segments import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| ArrayNumberSegment, | |||
| ArrayNumberVariable, | |||
| ArrayObjectSegment, | |||
| ArrayObjectVariable, | |||
| ArraySegment, | |||
| ArrayStringSegment, | |||
| ArrayStringVariable, | |||
| FileSegment, | |||
| FloatSegment, | |||
| FloatVariable, | |||
| IntegerSegment, | |||
| IntegerVariable, | |||
| NoneSegment, | |||
| ObjectSegment, | |||
| ObjectVariable, | |||
| SecretVariable, | |||
| Segment, | |||
| SegmentType, | |||
| StringSegment, | |||
| ) | |||
| from core.variables.types import SegmentType | |||
| from core.variables.variables import ( | |||
| ArrayAnyVariable, | |||
| ArrayFileVariable, | |||
| ArrayNumberVariable, | |||
| ArrayObjectVariable, | |||
| ArrayStringVariable, | |||
| FileVariable, | |||
| FloatVariable, | |||
| IntegerVariable, | |||
| NoneVariable, | |||
| ObjectVariable, | |||
| SecretVariable, | |||
| StringVariable, | |||
| Variable, | |||
| ) | |||
| from core.variables.exc import VariableError | |||
| class InvalidSelectorError(ValueError): | |||
| pass | |||
| class UnsupportedSegmentTypeError(Exception): | |||
| pass | |||
| # Define the constant | |||
| SEGMENT_TO_VARIABLE_MAP = { | |||
| StringSegment: StringVariable, | |||
| IntegerSegment: IntegerVariable, | |||
| FloatSegment: FloatVariable, | |||
| ObjectSegment: ObjectVariable, | |||
| FileSegment: FileVariable, | |||
| ArrayStringSegment: ArrayStringVariable, | |||
| ArrayNumberSegment: ArrayNumberVariable, | |||
| ArrayObjectSegment: ArrayObjectVariable, | |||
| ArrayFileSegment: ArrayFileVariable, | |||
| ArrayAnySegment: ArrayAnyVariable, | |||
| NoneSegment: NoneVariable, | |||
| } | |||
| def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: | |||
| @@ -96,3 +127,30 @@ def build_segment(value: Any, /) -> Segment: | |||
| case _: | |||
| raise ValueError(f"not supported value {value}") | |||
| raise ValueError(f"not supported value {value}") | |||
| def segment_to_variable( | |||
| *, | |||
| segment: Segment, | |||
| selector: Sequence[str], | |||
| id: str | None = None, | |||
| name: str | None = None, | |||
| description: str = "", | |||
| ) -> Variable: | |||
| if isinstance(segment, Variable): | |||
| return segment | |||
| name = name or selector[-1] | |||
| id = id or str(uuid4()) | |||
| segment_type = type(segment) | |||
| if segment_type not in SEGMENT_TO_VARIABLE_MAP: | |||
| raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") | |||
| variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] | |||
| return variable_class( | |||
| id=id, | |||
| name=name, | |||
| description=description, | |||
| value=segment.value, | |||
| selector=selector, | |||
| ) | |||
| @@ -1,5 +1,5 @@ | |||
| from core.helper import encrypter | |||
| from core.variables import SecretVariable, StringSegment | |||
| from core.variables import SecretVariable, StringVariable | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| @@ -54,4 +54,5 @@ def test_convert_variable_to_segment_group(): | |||
| segments_group = variable_pool.convert_template(template) | |||
| assert segments_group.text == "fake-user-id" | |||
| assert segments_group.log == "fake-user-id" | |||
| assert segments_group.value == [StringSegment(value="fake-user-id")] | |||
| assert isinstance(segments_group.value[0], StringVariable) | |||
| assert segments_group.value[0].value == "fake-user-id" | |||