- Introduce `WorkflowDraftVariable` model and the corresponding migration. - Implement `EnumText`, a custom column type for SQLAlchemy designed to work seamlessly with enumeration classes based on `StrEnum`.tags/1.4.1
| @@ -4,13 +4,14 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. | |||
| import json | |||
| import logging | |||
| from collections.abc import Sequence | |||
| from typing import Optional, Union | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, Optional, Union, cast | |||
| from sqlalchemy import UnaryExpression, asc, delete, desc, select | |||
| from sqlalchemy.engine import Engine | |||
| from sqlalchemy.orm import sessionmaker | |||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||
| from core.workflow.entities.node_execution_entities import ( | |||
| NodeExecution, | |||
| NodeExecutionStatus, | |||
| @@ -122,7 +123,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) | |||
| status=status, | |||
| error=db_model.error, | |||
| elapsed_time=db_model.elapsed_time, | |||
| metadata=metadata, | |||
| # FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11. | |||
| # However, this problem is not occurred in Python 3.12. | |||
| # | |||
| # A case of this error is: | |||
| # https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24 | |||
| metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata), | |||
| created_at=db_model.created_at, | |||
| finished_at=db_model.finished_at, | |||
| ) | |||
| @@ -0,0 +1,7 @@ | |||
| # The minimal selector length for valid variables. | |||
| # | |||
| # The first element of the selector is the node id, and the second element is the variable name. | |||
| # | |||
| # If the selector length is more than 2, the remaining parts are the keys / indexes paths used | |||
| # to extract part of the variable value. | |||
| MIN_SELECTORS_LENGTH = 2 | |||
| @@ -0,0 +1,8 @@ | |||
| from collections.abc import Iterable, Sequence | |||
| def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: | |||
| selectors = [node_id, name] | |||
| if paths: | |||
| selectors.extend(paths) | |||
| return selectors | |||
| @@ -0,0 +1,51 @@ | |||
| """add WorkflowDraftVariable model | |||
| Revision ID: 2adcbe1f5dfb | |||
| Revises: d28f2004b072 | |||
| Create Date: 2025-05-15 15:31:03.128680 | |||
| """ | |||
| import sqlalchemy as sa | |||
| from alembic import op | |||
| import models as models | |||
| # revision identifiers, used by Alembic. | |||
| revision = "2adcbe1f5dfb" | |||
| down_revision = "d28f2004b072" | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table( | |||
| "workflow_draft_variables", | |||
| sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), | |||
| sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), | |||
| sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), | |||
| sa.Column("app_id", models.types.StringUUID(), nullable=False), | |||
| sa.Column("last_edited_at", sa.DateTime(), nullable=True), | |||
| sa.Column("node_id", sa.String(length=255), nullable=False), | |||
| sa.Column("name", sa.String(length=255), nullable=False), | |||
| sa.Column("description", sa.String(length=255), nullable=False), | |||
| sa.Column("selector", sa.String(length=255), nullable=False), | |||
| sa.Column("value_type", sa.String(length=20), nullable=False), | |||
| sa.Column("value", sa.Text(), nullable=False), | |||
| sa.Column("visible", sa.Boolean(), nullable=False), | |||
| sa.Column("editable", sa.Boolean(), nullable=False), | |||
| sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")), | |||
| sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")), | |||
| ) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| # Dropping `workflow_draft_variables` also drops any index associated with it. | |||
| op.drop_table("workflow_draft_variables") | |||
| # ### end Alembic commands ### | |||
| @@ -14,3 +14,10 @@ class UserFrom(StrEnum): | |||
| class WorkflowRunTriggeredFrom(StrEnum): | |||
| DEBUGGING = "debugging" | |||
| APP_RUN = "app-run" | |||
| class DraftVariableType(StrEnum): | |||
| # node means that the correspond variable | |||
| NODE = "node" | |||
| SYS = "sys" | |||
| CONVERSATION = "conversation" | |||
| @@ -1,4 +1,7 @@ | |||
| from sqlalchemy import CHAR, TypeDecorator | |||
| import enum | |||
| from typing import Generic, TypeVar | |||
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator | |||
| from sqlalchemy.dialects.postgresql import UUID | |||
| @@ -24,3 +27,51 @@ class StringUUID(TypeDecorator): | |||
| if value is None: | |||
| return value | |||
| return str(value) | |||
| _E = TypeVar("_E", bound=enum.StrEnum) | |||
| class EnumText(TypeDecorator, Generic[_E]): | |||
| impl = VARCHAR | |||
| cache_ok = True | |||
| _length: int | |||
| _enum_class: type[_E] | |||
| def __init__(self, enum_class: type[_E], length: int | None = None): | |||
| self._enum_class = enum_class | |||
| max_enum_value_len = max(len(e.value) for e in enum_class) | |||
| if length is not None: | |||
| if length < max_enum_value_len: | |||
| raise ValueError("length should be greater than enum value length.") | |||
| self._length = length | |||
| else: | |||
| # leave some rooms for future longer enum values. | |||
| self._length = max(max_enum_value_len, 20) | |||
| def process_bind_param(self, value: _E | str | None, dialect): | |||
| if value is None: | |||
| return value | |||
| if isinstance(value, self._enum_class): | |||
| return value.value | |||
| elif isinstance(value, str): | |||
| self._enum_class(value) | |||
| return value | |||
| else: | |||
| raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") | |||
| def load_dialect_impl(self, dialect): | |||
| return dialect.type_descriptor(VARCHAR(self._length)) | |||
| def process_result_value(self, value, dialect) -> _E | None: | |||
| if value is None: | |||
| return value | |||
| if not isinstance(value, str): | |||
| raise TypeError(f"expected str, got {type(value)}") | |||
| return self._enum_class(value) | |||
| def compare_values(self, x, y): | |||
| if x is None or y is None: | |||
| return x is y | |||
| return x == y | |||
| @@ -1,29 +1,36 @@ | |||
| import json | |||
| import logging | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| from enum import Enum, StrEnum | |||
| from typing import TYPE_CHECKING, Any, Optional, Self, Union | |||
| from uuid import uuid4 | |||
| from core.variables import utils as variable_utils | |||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||
| from factories.variable_factory import build_segment | |||
| if TYPE_CHECKING: | |||
| from models.model import AppMode | |||
| import sqlalchemy as sa | |||
| from sqlalchemy import func | |||
| from sqlalchemy import UniqueConstraint, func | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| import contexts | |||
| from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE | |||
| from core.helper import encrypter | |||
| from core.variables import SecretVariable, Variable | |||
| from core.variables import SecretVariable, Segment, SegmentType, Variable | |||
| from factories import variable_factory | |||
| from libs import helper | |||
| from .account import Account | |||
| from .base import Base | |||
| from .engine import db | |||
| from .enums import CreatorUserRole | |||
| from .types import StringUUID | |||
| from .enums import CreatorUserRole, DraftVariableType | |||
| from .types import EnumText, StringUUID | |||
| _logger = logging.getLogger(__name__) | |||
| if TYPE_CHECKING: | |||
| from models.model import AppMode | |||
| @@ -651,7 +658,7 @@ class WorkflowNodeExecution(Base): | |||
| return json.loads(self.inputs) if self.inputs else None | |||
| @property | |||
| def outputs_dict(self): | |||
| def outputs_dict(self) -> dict[str, Any] | None: | |||
| return json.loads(self.outputs) if self.outputs else None | |||
| @property | |||
| @@ -659,7 +666,7 @@ class WorkflowNodeExecution(Base): | |||
| return json.loads(self.process_data) if self.process_data else None | |||
| @property | |||
| def execution_metadata_dict(self): | |||
| def execution_metadata_dict(self) -> dict[str, Any] | None: | |||
| return json.loads(self.execution_metadata) if self.execution_metadata else None | |||
| @property | |||
| @@ -797,3 +804,202 @@ class ConversationVariable(Base): | |||
| def to_variable(self) -> Variable: | |||
| mapping = json.loads(self.data) | |||
| return variable_factory.build_conversation_variable_from_mapping(mapping) | |||
| # Only `sys.query` and `sys.files` could be modified. | |||
| _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) | |||
| def _naive_utc_datetime(): | |||
| return datetime.now(UTC).replace(tzinfo=None) | |||
| class WorkflowDraftVariable(Base): | |||
| @staticmethod | |||
| def unique_columns() -> list[str]: | |||
| return [ | |||
| "app_id", | |||
| "node_id", | |||
| "name", | |||
| ] | |||
| __tablename__ = "workflow_draft_variables" | |||
| __table_args__ = (UniqueConstraint(*unique_columns()),) | |||
| # id is the unique identifier of a draft variable. | |||
| id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| created_at = mapped_column( | |||
| db.DateTime, | |||
| nullable=False, | |||
| default=_naive_utc_datetime, | |||
| server_default=func.current_timestamp(), | |||
| ) | |||
| updated_at = mapped_column( | |||
| db.DateTime, | |||
| nullable=False, | |||
| default=_naive_utc_datetime, | |||
| server_default=func.current_timestamp(), | |||
| onupdate=func.current_timestamp(), | |||
| ) | |||
| # "`app_id` maps to the `id` field in the `model.App` model." | |||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| # `last_edited_at` records when the value of a given draft variable | |||
| # is edited. | |||
| # | |||
| # If it's not edited after creation, its value is `None`. | |||
| last_edited_at: Mapped[datetime | None] = mapped_column( | |||
| db.DateTime, | |||
| nullable=True, | |||
| default=None, | |||
| ) | |||
| # The `node_id` field is special. | |||
| # | |||
| # If the variable is a conversation variable or a system variable, then the value of `node_id` | |||
| # is `conversation` or `sys`, respective. | |||
| # | |||
| # Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is | |||
| # the identity of correspond node in graph definition. An example of node id is `"1745769620734"`. | |||
| # | |||
| # However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other | |||
| # "Answer" node conform the rules above.) | |||
| node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id") | |||
| # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than | |||
| # 80 chars. | |||
| # | |||
| # ref: api/core/workflow/entities/variable_pool.py:18 | |||
| name: Mapped[str] = mapped_column(sa.String(255), nullable=False) | |||
| description: Mapped[str] = mapped_column( | |||
| sa.String(255), | |||
| default="", | |||
| nullable=False, | |||
| ) | |||
| selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector") | |||
| value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20)) | |||
| # JSON string | |||
| value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value") | |||
| # visible | |||
| visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) | |||
| editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) | |||
| def get_selector(self) -> list[str]: | |||
| selector = json.loads(self.selector) | |||
| if not isinstance(selector, list): | |||
| _logger.error( | |||
| "invalid selector loaded from database, type=%s, value=%s", | |||
| type(selector), | |||
| self.selector, | |||
| ) | |||
| raise ValueError("invalid selector.") | |||
| return selector | |||
| def _set_selector(self, value: list[str]): | |||
| self.selector = json.dumps(value) | |||
| def get_value(self) -> Segment | None: | |||
| return build_segment(json.loads(self.value)) | |||
| def set_name(self, name: str): | |||
| self.name = name | |||
| self._set_selector([self.node_id, name]) | |||
| def set_value(self, value: Segment): | |||
| self.value = json.dumps(value.value) | |||
| self.value_type = value.value_type | |||
| def get_node_id(self) -> str | None: | |||
| if self.get_variable_type() == DraftVariableType.NODE: | |||
| return self.node_id | |||
| else: | |||
| return None | |||
| def get_variable_type(self) -> DraftVariableType: | |||
| match self.node_id: | |||
| case DraftVariableType.CONVERSATION: | |||
| return DraftVariableType.CONVERSATION | |||
| case DraftVariableType.SYS: | |||
| return DraftVariableType.SYS | |||
| case _: | |||
| return DraftVariableType.NODE | |||
| @classmethod | |||
| def _new( | |||
| cls, | |||
| *, | |||
| app_id: str, | |||
| node_id: str, | |||
| name: str, | |||
| value: Segment, | |||
| description: str = "", | |||
| ) -> "WorkflowDraftVariable": | |||
| variable = WorkflowDraftVariable() | |||
| variable.created_at = _naive_utc_datetime() | |||
| variable.updated_at = _naive_utc_datetime() | |||
| variable.description = description | |||
| variable.app_id = app_id | |||
| variable.node_id = node_id | |||
| variable.name = name | |||
| variable.app_id = app_id | |||
| variable.set_value(value) | |||
| variable._set_selector(list(variable_utils.to_selector(node_id, name))) | |||
| return variable | |||
| @classmethod | |||
| def new_conversation_variable( | |||
| cls, | |||
| *, | |||
| app_id: str, | |||
| name: str, | |||
| value: Segment, | |||
| ) -> "WorkflowDraftVariable": | |||
| variable = cls._new( | |||
| app_id=app_id, | |||
| node_id=CONVERSATION_VARIABLE_NODE_ID, | |||
| name=name, | |||
| value=value, | |||
| ) | |||
| return variable | |||
| @classmethod | |||
| def new_sys_variable( | |||
| cls, | |||
| *, | |||
| app_id: str, | |||
| name: str, | |||
| value: Segment, | |||
| editable: bool = False, | |||
| ) -> "WorkflowDraftVariable": | |||
| variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) | |||
| variable.editable = editable | |||
| return variable | |||
| @classmethod | |||
| def new_node_variable( | |||
| cls, | |||
| *, | |||
| app_id: str, | |||
| node_id: str, | |||
| name: str, | |||
| value: Segment, | |||
| visible: bool = True, | |||
| ) -> "WorkflowDraftVariable": | |||
| variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) | |||
| variable.visible = visible | |||
| variable.editable = True | |||
| return variable | |||
| @property | |||
| def edited(self): | |||
| return self.last_edited_at is not None | |||
| def is_system_variable_editable(name: str) -> bool: | |||
| return name in _EDITABLE_SYSTEM_VARIABLE | |||
| @@ -0,0 +1,187 @@ | |||
| from collections.abc import Callable, Iterable | |||
| from enum import StrEnum | |||
| from typing import Any, NamedTuple, TypeVar | |||
| import pytest | |||
| import sqlalchemy as sa | |||
| from sqlalchemy import exc as sa_exc | |||
| from sqlalchemy import insert | |||
| from sqlalchemy.orm import DeclarativeBase, Mapped, Session | |||
| from sqlalchemy.sql.sqltypes import VARCHAR | |||
| from models.types import EnumText | |||
| _user_type_admin = "admin" | |||
| _user_type_normal = "normal" | |||
| class _Base(DeclarativeBase): | |||
| pass | |||
| class _UserType(StrEnum): | |||
| admin = _user_type_admin | |||
| normal = _user_type_normal | |||
| class _EnumWithLongValue(StrEnum): | |||
| unknown = "unknown" | |||
| a_really_long_enum_values = "a_really_long_enum_values" | |||
| class _User(_Base): | |||
| __tablename__ = "users" | |||
| id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) | |||
| name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False) | |||
| user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) | |||
| user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True) | |||
| class _ColumnTest(_Base): | |||
| __tablename__ = "column_test" | |||
| id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) | |||
| user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) | |||
| explicit_length: Mapped[_UserType | None] = sa.Column( | |||
| EnumText(_UserType, length=50), nullable=True, default=_UserType.normal | |||
| ) | |||
| long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False) | |||
| _T = TypeVar("_T") | |||
| def _first(it: Iterable[_T]) -> _T: | |||
| ls = list(it) | |||
| if not ls: | |||
| raise ValueError("List is empty") | |||
| return ls[0] | |||
| class TestEnumText: | |||
| def test_column_impl(self): | |||
| engine = sa.create_engine("sqlite://", echo=False) | |||
| _Base.metadata.create_all(engine) | |||
| inspector = sa.inspect(engine) | |||
| columns = inspector.get_columns(_ColumnTest.__tablename__) | |||
| user_type_column = _first(c for c in columns if c["name"] == "user_type") | |||
| sql_type = user_type_column["type"] | |||
| assert isinstance(user_type_column["type"], VARCHAR) | |||
| assert sql_type.length == 20 | |||
| assert user_type_column["nullable"] is False | |||
| explicit_length_column = _first(c for c in columns if c["name"] == "explicit_length") | |||
| sql_type = explicit_length_column["type"] | |||
| assert isinstance(sql_type, VARCHAR) | |||
| assert sql_type.length == 50 | |||
| assert explicit_length_column["nullable"] is True | |||
| long_value_column = _first(c for c in columns if c["name"] == "long_value") | |||
| sql_type = long_value_column["type"] | |||
| assert isinstance(sql_type, VARCHAR) | |||
| assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values) | |||
| def test_insert_and_select(self): | |||
| engine = sa.create_engine("sqlite://", echo=False) | |||
| _Base.metadata.create_all(engine) | |||
| with Session(engine) as session: | |||
| admin_user = _User( | |||
| name="admin", | |||
| user_type=_UserType.admin, | |||
| user_type_nullable=None, | |||
| ) | |||
| session.add(admin_user) | |||
| session.flush() | |||
| admin_user_id = admin_user.id | |||
| normal_user = _User( | |||
| name="normal", | |||
| user_type=_UserType.normal.value, | |||
| user_type_nullable=_UserType.normal.value, | |||
| ) | |||
| session.add(normal_user) | |||
| session.flush() | |||
| normal_user_id = normal_user.id | |||
| session.commit() | |||
| with Session(engine) as session: | |||
| user = session.query(_User).filter(_User.id == admin_user_id).first() | |||
| assert user.user_type == _UserType.admin | |||
| assert user.user_type_nullable is None | |||
| with Session(engine) as session: | |||
| user = session.query(_User).filter(_User.id == normal_user_id).first() | |||
| assert user.user_type == _UserType.normal | |||
| assert user.user_type_nullable == _UserType.normal | |||
| def test_insert_invalid_values(self): | |||
| def _session_insert_with_value(sess: Session, user_type: Any): | |||
| user = _User(name="test_user", user_type=user_type) | |||
| sess.add(user) | |||
| sess.flush() | |||
| def _insert_with_user(sess: Session, user_type: Any): | |||
| stmt = insert(_User).values( | |||
| { | |||
| "name": "test_user", | |||
| "user_type": user_type, | |||
| } | |||
| ) | |||
| sess.execute(stmt) | |||
| class TestCase(NamedTuple): | |||
| name: str | |||
| action: Callable[[Session], None] | |||
| exc_type: type[Exception] | |||
| engine = sa.create_engine("sqlite://", echo=False) | |||
| _Base.metadata.create_all(engine) | |||
| cases = [ | |||
| TestCase( | |||
| name="session insert with invalid value", | |||
| action=lambda s: _session_insert_with_value(s, "invalid"), | |||
| exc_type=ValueError, | |||
| ), | |||
| TestCase( | |||
| name="session insert with invalid type", | |||
| action=lambda s: _session_insert_with_value(s, 1), | |||
| exc_type=TypeError, | |||
| ), | |||
| TestCase( | |||
| name="insert with invalid value", | |||
| action=lambda s: _insert_with_user(s, "invalid"), | |||
| exc_type=ValueError, | |||
| ), | |||
| TestCase( | |||
| name="insert with invalid type", | |||
| action=lambda s: _insert_with_user(s, 1), | |||
| exc_type=TypeError, | |||
| ), | |||
| ] | |||
| for idx, c in enumerate(cases, 1): | |||
| with pytest.raises(sa_exc.StatementError) as exc: | |||
| with Session(engine) as session: | |||
| c.action(session) | |||
| assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}" | |||
| def test_select_invalid_values(self): | |||
| engine = sa.create_engine("sqlite://", echo=False) | |||
| _Base.metadata.create_all(engine) | |||
| insertion_sql = """ | |||
| INSERT INTO users (id, name, user_type) VALUES | |||
| (1, 'invalid_value', 'invalid'); | |||
| """ | |||
| with Session(engine) as session: | |||
| session.execute(sa.text(insertion_sql)) | |||
| session.commit() | |||
| with pytest.raises(ValueError) as exc: | |||
| with Session(engine) as session: | |||
| _user = session.query(_User).filter(_User.id == 1).first() | |||