- Introduce `WorkflowDraftVariable` model and the corresponding migration. - Implement `EnumText`, a custom column type for SQLAlchemy designed to work seamlessly with enumeration classes based on `StrEnum`.tags/1.4.1
| import json | import json | ||||
| import logging | import logging | ||||
| from collections.abc import Sequence | |||||
| from typing import Optional, Union | |||||
| from collections.abc import Mapping, Sequence | |||||
| from typing import Any, Optional, Union, cast | |||||
| from sqlalchemy import UnaryExpression, asc, delete, desc, select | from sqlalchemy import UnaryExpression, asc, delete, desc, select | ||||
| from sqlalchemy.engine import Engine | from sqlalchemy.engine import Engine | ||||
| from sqlalchemy.orm import sessionmaker | from sqlalchemy.orm import sessionmaker | ||||
| from core.workflow.entities.node_entities import NodeRunMetadataKey | |||||
| from core.workflow.entities.node_execution_entities import ( | from core.workflow.entities.node_execution_entities import ( | ||||
| NodeExecution, | NodeExecution, | ||||
| NodeExecutionStatus, | NodeExecutionStatus, | ||||
| status=status, | status=status, | ||||
| error=db_model.error, | error=db_model.error, | ||||
| elapsed_time=db_model.elapsed_time, | elapsed_time=db_model.elapsed_time, | ||||
| metadata=metadata, | |||||
| # FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11. | |||||
| # However, this problem is not occurred in Python 3.12. | |||||
| # | |||||
| # A case of this error is: | |||||
| # https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24 | |||||
| metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata), | |||||
| created_at=db_model.created_at, | created_at=db_model.created_at, | ||||
| finished_at=db_model.finished_at, | finished_at=db_model.finished_at, | ||||
| ) | ) |
| # The minimal selector length for valid variables. | |||||
| # | |||||
| # The first element of the selector is the node id, and the second element is the variable name. | |||||
| # | |||||
| # If the selector length is more than 2, the remaining parts are the keys / indexes paths used | |||||
| # to extract part of the variable value. | |||||
| MIN_SELECTORS_LENGTH = 2 |
| from collections.abc import Iterable, Sequence | |||||
| def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: | |||||
| selectors = [node_id, name] | |||||
| if paths: | |||||
| selectors.extend(paths) | |||||
| return selectors |
| """add WorkflowDraftVariable model | |||||
| Revision ID: 2adcbe1f5dfb | |||||
| Revises: d28f2004b072 | |||||
| Create Date: 2025-05-15 15:31:03.128680 | |||||
| """ | |||||
| import sqlalchemy as sa | |||||
| from alembic import op | |||||
| import models as models | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = "2adcbe1f5dfb" | |||||
| down_revision = "d28f2004b072" | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table( | |||||
| "workflow_draft_variables", | |||||
| sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), | |||||
| sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), | |||||
| sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), | |||||
| sa.Column("app_id", models.types.StringUUID(), nullable=False), | |||||
| sa.Column("last_edited_at", sa.DateTime(), nullable=True), | |||||
| sa.Column("node_id", sa.String(length=255), nullable=False), | |||||
| sa.Column("name", sa.String(length=255), nullable=False), | |||||
| sa.Column("description", sa.String(length=255), nullable=False), | |||||
| sa.Column("selector", sa.String(length=255), nullable=False), | |||||
| sa.Column("value_type", sa.String(length=20), nullable=False), | |||||
| sa.Column("value", sa.Text(), nullable=False), | |||||
| sa.Column("visible", sa.Boolean(), nullable=False), | |||||
| sa.Column("editable", sa.Boolean(), nullable=False), | |||||
| sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")), | |||||
| sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")), | |||||
| ) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| # Dropping `workflow_draft_variables` also drops any index associated with it. | |||||
| op.drop_table("workflow_draft_variables") | |||||
| # ### end Alembic commands ### |
| class WorkflowRunTriggeredFrom(StrEnum): | class WorkflowRunTriggeredFrom(StrEnum): | ||||
| DEBUGGING = "debugging" | DEBUGGING = "debugging" | ||||
| APP_RUN = "app-run" | APP_RUN = "app-run" | ||||
| class DraftVariableType(StrEnum): | |||||
| # node means that the correspond variable | |||||
| NODE = "node" | |||||
| SYS = "sys" | |||||
| CONVERSATION = "conversation" |
| from sqlalchemy import CHAR, TypeDecorator | |||||
| import enum | |||||
| from typing import Generic, TypeVar | |||||
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator | |||||
| from sqlalchemy.dialects.postgresql import UUID | from sqlalchemy.dialects.postgresql import UUID | ||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| return str(value) | return str(value) | ||||
| _E = TypeVar("_E", bound=enum.StrEnum) | |||||
| class EnumText(TypeDecorator, Generic[_E]): | |||||
| impl = VARCHAR | |||||
| cache_ok = True | |||||
| _length: int | |||||
| _enum_class: type[_E] | |||||
| def __init__(self, enum_class: type[_E], length: int | None = None): | |||||
| self._enum_class = enum_class | |||||
| max_enum_value_len = max(len(e.value) for e in enum_class) | |||||
| if length is not None: | |||||
| if length < max_enum_value_len: | |||||
| raise ValueError("length should be greater than enum value length.") | |||||
| self._length = length | |||||
| else: | |||||
| # leave some rooms for future longer enum values. | |||||
| self._length = max(max_enum_value_len, 20) | |||||
| def process_bind_param(self, value: _E | str | None, dialect): | |||||
| if value is None: | |||||
| return value | |||||
| if isinstance(value, self._enum_class): | |||||
| return value.value | |||||
| elif isinstance(value, str): | |||||
| self._enum_class(value) | |||||
| return value | |||||
| else: | |||||
| raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") | |||||
| def load_dialect_impl(self, dialect): | |||||
| return dialect.type_descriptor(VARCHAR(self._length)) | |||||
| def process_result_value(self, value, dialect) -> _E | None: | |||||
| if value is None: | |||||
| return value | |||||
| if not isinstance(value, str): | |||||
| raise TypeError(f"expected str, got {type(value)}") | |||||
| return self._enum_class(value) | |||||
| def compare_values(self, x, y): | |||||
| if x is None or y is None: | |||||
| return x is y | |||||
| return x == y |
| import json | import json | ||||
| import logging | |||||
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import TYPE_CHECKING, Any, Optional, Self, Union | from typing import TYPE_CHECKING, Any, Optional, Self, Union | ||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| from core.variables import utils as variable_utils | |||||
| from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID | |||||
| from factories.variable_factory import build_segment | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from models.model import AppMode | from models.model import AppMode | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from sqlalchemy import func | |||||
| from sqlalchemy import UniqueConstraint, func | |||||
| from sqlalchemy.orm import Mapped, mapped_column | from sqlalchemy.orm import Mapped, mapped_column | ||||
| import contexts | import contexts | ||||
| from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE | from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.variables import SecretVariable, Variable | |||||
| from core.variables import SecretVariable, Segment, SegmentType, Variable | |||||
| from factories import variable_factory | from factories import variable_factory | ||||
| from libs import helper | from libs import helper | ||||
| from .account import Account | from .account import Account | ||||
| from .base import Base | from .base import Base | ||||
| from .engine import db | from .engine import db | ||||
| from .enums import CreatorUserRole | |||||
| from .types import StringUUID | |||||
| from .enums import CreatorUserRole, DraftVariableType | |||||
| from .types import EnumText, StringUUID | |||||
| _logger = logging.getLogger(__name__) | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from models.model import AppMode | from models.model import AppMode | ||||
| return json.loads(self.inputs) if self.inputs else None | return json.loads(self.inputs) if self.inputs else None | ||||
| @property | @property | ||||
| def outputs_dict(self): | |||||
| def outputs_dict(self) -> dict[str, Any] | None: | |||||
| return json.loads(self.outputs) if self.outputs else None | return json.loads(self.outputs) if self.outputs else None | ||||
| @property | @property | ||||
| return json.loads(self.process_data) if self.process_data else None | return json.loads(self.process_data) if self.process_data else None | ||||
| @property | @property | ||||
| def execution_metadata_dict(self): | |||||
| def execution_metadata_dict(self) -> dict[str, Any] | None: | |||||
| return json.loads(self.execution_metadata) if self.execution_metadata else None | return json.loads(self.execution_metadata) if self.execution_metadata else None | ||||
| @property | @property | ||||
| def to_variable(self) -> Variable: | def to_variable(self) -> Variable: | ||||
| mapping = json.loads(self.data) | mapping = json.loads(self.data) | ||||
| return variable_factory.build_conversation_variable_from_mapping(mapping) | return variable_factory.build_conversation_variable_from_mapping(mapping) | ||||
| # Only `sys.query` and `sys.files` could be modified. | |||||
| _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) | |||||
| def _naive_utc_datetime(): | |||||
| return datetime.now(UTC).replace(tzinfo=None) | |||||
| class WorkflowDraftVariable(Base): | |||||
| @staticmethod | |||||
| def unique_columns() -> list[str]: | |||||
| return [ | |||||
| "app_id", | |||||
| "node_id", | |||||
| "name", | |||||
| ] | |||||
| __tablename__ = "workflow_draft_variables" | |||||
| __table_args__ = (UniqueConstraint(*unique_columns()),) | |||||
| # id is the unique identifier of a draft variable. | |||||
| id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||||
| created_at = mapped_column( | |||||
| db.DateTime, | |||||
| nullable=False, | |||||
| default=_naive_utc_datetime, | |||||
| server_default=func.current_timestamp(), | |||||
| ) | |||||
| updated_at = mapped_column( | |||||
| db.DateTime, | |||||
| nullable=False, | |||||
| default=_naive_utc_datetime, | |||||
| server_default=func.current_timestamp(), | |||||
| onupdate=func.current_timestamp(), | |||||
| ) | |||||
| # "`app_id` maps to the `id` field in the `model.App` model." | |||||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| # `last_edited_at` records when the value of a given draft variable | |||||
| # is edited. | |||||
| # | |||||
| # If it's not edited after creation, its value is `None`. | |||||
| last_edited_at: Mapped[datetime | None] = mapped_column( | |||||
| db.DateTime, | |||||
| nullable=True, | |||||
| default=None, | |||||
| ) | |||||
| # The `node_id` field is special. | |||||
| # | |||||
| # If the variable is a conversation variable or a system variable, then the value of `node_id` | |||||
| # is `conversation` or `sys`, respective. | |||||
| # | |||||
| # Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is | |||||
| # the identity of correspond node in graph definition. An example of node id is `"1745769620734"`. | |||||
| # | |||||
| # However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other | |||||
| # "Answer" node conform the rules above.) | |||||
| node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id") | |||||
| # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than | |||||
| # 80 chars. | |||||
| # | |||||
| # ref: api/core/workflow/entities/variable_pool.py:18 | |||||
| name: Mapped[str] = mapped_column(sa.String(255), nullable=False) | |||||
| description: Mapped[str] = mapped_column( | |||||
| sa.String(255), | |||||
| default="", | |||||
| nullable=False, | |||||
| ) | |||||
| selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector") | |||||
| value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20)) | |||||
| # JSON string | |||||
| value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value") | |||||
| # visible | |||||
| visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) | |||||
| editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) | |||||
| def get_selector(self) -> list[str]: | |||||
| selector = json.loads(self.selector) | |||||
| if not isinstance(selector, list): | |||||
| _logger.error( | |||||
| "invalid selector loaded from database, type=%s, value=%s", | |||||
| type(selector), | |||||
| self.selector, | |||||
| ) | |||||
| raise ValueError("invalid selector.") | |||||
| return selector | |||||
| def _set_selector(self, value: list[str]): | |||||
| self.selector = json.dumps(value) | |||||
| def get_value(self) -> Segment | None: | |||||
| return build_segment(json.loads(self.value)) | |||||
| def set_name(self, name: str): | |||||
| self.name = name | |||||
| self._set_selector([self.node_id, name]) | |||||
| def set_value(self, value: Segment): | |||||
| self.value = json.dumps(value.value) | |||||
| self.value_type = value.value_type | |||||
| def get_node_id(self) -> str | None: | |||||
| if self.get_variable_type() == DraftVariableType.NODE: | |||||
| return self.node_id | |||||
| else: | |||||
| return None | |||||
| def get_variable_type(self) -> DraftVariableType: | |||||
| match self.node_id: | |||||
| case DraftVariableType.CONVERSATION: | |||||
| return DraftVariableType.CONVERSATION | |||||
| case DraftVariableType.SYS: | |||||
| return DraftVariableType.SYS | |||||
| case _: | |||||
| return DraftVariableType.NODE | |||||
| @classmethod | |||||
| def _new( | |||||
| cls, | |||||
| *, | |||||
| app_id: str, | |||||
| node_id: str, | |||||
| name: str, | |||||
| value: Segment, | |||||
| description: str = "", | |||||
| ) -> "WorkflowDraftVariable": | |||||
| variable = WorkflowDraftVariable() | |||||
| variable.created_at = _naive_utc_datetime() | |||||
| variable.updated_at = _naive_utc_datetime() | |||||
| variable.description = description | |||||
| variable.app_id = app_id | |||||
| variable.node_id = node_id | |||||
| variable.name = name | |||||
| variable.app_id = app_id | |||||
| variable.set_value(value) | |||||
| variable._set_selector(list(variable_utils.to_selector(node_id, name))) | |||||
| return variable | |||||
| @classmethod | |||||
| def new_conversation_variable( | |||||
| cls, | |||||
| *, | |||||
| app_id: str, | |||||
| name: str, | |||||
| value: Segment, | |||||
| ) -> "WorkflowDraftVariable": | |||||
| variable = cls._new( | |||||
| app_id=app_id, | |||||
| node_id=CONVERSATION_VARIABLE_NODE_ID, | |||||
| name=name, | |||||
| value=value, | |||||
| ) | |||||
| return variable | |||||
| @classmethod | |||||
| def new_sys_variable( | |||||
| cls, | |||||
| *, | |||||
| app_id: str, | |||||
| name: str, | |||||
| value: Segment, | |||||
| editable: bool = False, | |||||
| ) -> "WorkflowDraftVariable": | |||||
| variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) | |||||
| variable.editable = editable | |||||
| return variable | |||||
| @classmethod | |||||
| def new_node_variable( | |||||
| cls, | |||||
| *, | |||||
| app_id: str, | |||||
| node_id: str, | |||||
| name: str, | |||||
| value: Segment, | |||||
| visible: bool = True, | |||||
| ) -> "WorkflowDraftVariable": | |||||
| variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) | |||||
| variable.visible = visible | |||||
| variable.editable = True | |||||
| return variable | |||||
| @property | |||||
| def edited(self): | |||||
| return self.last_edited_at is not None | |||||
| def is_system_variable_editable(name: str) -> bool: | |||||
| return name in _EDITABLE_SYSTEM_VARIABLE |
| from collections.abc import Callable, Iterable | |||||
| from enum import StrEnum | |||||
| from typing import Any, NamedTuple, TypeVar | |||||
| import pytest | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy import exc as sa_exc | |||||
| from sqlalchemy import insert | |||||
| from sqlalchemy.orm import DeclarativeBase, Mapped, Session | |||||
| from sqlalchemy.sql.sqltypes import VARCHAR | |||||
| from models.types import EnumText | |||||
| _user_type_admin = "admin" | |||||
| _user_type_normal = "normal" | |||||
| class _Base(DeclarativeBase): | |||||
| pass | |||||
| class _UserType(StrEnum): | |||||
| admin = _user_type_admin | |||||
| normal = _user_type_normal | |||||
| class _EnumWithLongValue(StrEnum): | |||||
| unknown = "unknown" | |||||
| a_really_long_enum_values = "a_really_long_enum_values" | |||||
| class _User(_Base): | |||||
| __tablename__ = "users" | |||||
| id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) | |||||
| name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False) | |||||
| user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) | |||||
| user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True) | |||||
| class _ColumnTest(_Base): | |||||
| __tablename__ = "column_test" | |||||
| id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) | |||||
| user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) | |||||
| explicit_length: Mapped[_UserType | None] = sa.Column( | |||||
| EnumText(_UserType, length=50), nullable=True, default=_UserType.normal | |||||
| ) | |||||
| long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False) | |||||
| _T = TypeVar("_T") | |||||
| def _first(it: Iterable[_T]) -> _T: | |||||
| ls = list(it) | |||||
| if not ls: | |||||
| raise ValueError("List is empty") | |||||
| return ls[0] | |||||
| class TestEnumText: | |||||
| def test_column_impl(self): | |||||
| engine = sa.create_engine("sqlite://", echo=False) | |||||
| _Base.metadata.create_all(engine) | |||||
| inspector = sa.inspect(engine) | |||||
| columns = inspector.get_columns(_ColumnTest.__tablename__) | |||||
| user_type_column = _first(c for c in columns if c["name"] == "user_type") | |||||
| sql_type = user_type_column["type"] | |||||
| assert isinstance(user_type_column["type"], VARCHAR) | |||||
| assert sql_type.length == 20 | |||||
| assert user_type_column["nullable"] is False | |||||
| explicit_length_column = _first(c for c in columns if c["name"] == "explicit_length") | |||||
| sql_type = explicit_length_column["type"] | |||||
| assert isinstance(sql_type, VARCHAR) | |||||
| assert sql_type.length == 50 | |||||
| assert explicit_length_column["nullable"] is True | |||||
| long_value_column = _first(c for c in columns if c["name"] == "long_value") | |||||
| sql_type = long_value_column["type"] | |||||
| assert isinstance(sql_type, VARCHAR) | |||||
| assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values) | |||||
| def test_insert_and_select(self): | |||||
| engine = sa.create_engine("sqlite://", echo=False) | |||||
| _Base.metadata.create_all(engine) | |||||
| with Session(engine) as session: | |||||
| admin_user = _User( | |||||
| name="admin", | |||||
| user_type=_UserType.admin, | |||||
| user_type_nullable=None, | |||||
| ) | |||||
| session.add(admin_user) | |||||
| session.flush() | |||||
| admin_user_id = admin_user.id | |||||
| normal_user = _User( | |||||
| name="normal", | |||||
| user_type=_UserType.normal.value, | |||||
| user_type_nullable=_UserType.normal.value, | |||||
| ) | |||||
| session.add(normal_user) | |||||
| session.flush() | |||||
| normal_user_id = normal_user.id | |||||
| session.commit() | |||||
| with Session(engine) as session: | |||||
| user = session.query(_User).filter(_User.id == admin_user_id).first() | |||||
| assert user.user_type == _UserType.admin | |||||
| assert user.user_type_nullable is None | |||||
| with Session(engine) as session: | |||||
| user = session.query(_User).filter(_User.id == normal_user_id).first() | |||||
| assert user.user_type == _UserType.normal | |||||
| assert user.user_type_nullable == _UserType.normal | |||||
| def test_insert_invalid_values(self): | |||||
| def _session_insert_with_value(sess: Session, user_type: Any): | |||||
| user = _User(name="test_user", user_type=user_type) | |||||
| sess.add(user) | |||||
| sess.flush() | |||||
| def _insert_with_user(sess: Session, user_type: Any): | |||||
| stmt = insert(_User).values( | |||||
| { | |||||
| "name": "test_user", | |||||
| "user_type": user_type, | |||||
| } | |||||
| ) | |||||
| sess.execute(stmt) | |||||
| class TestCase(NamedTuple): | |||||
| name: str | |||||
| action: Callable[[Session], None] | |||||
| exc_type: type[Exception] | |||||
| engine = sa.create_engine("sqlite://", echo=False) | |||||
| _Base.metadata.create_all(engine) | |||||
| cases = [ | |||||
| TestCase( | |||||
| name="session insert with invalid value", | |||||
| action=lambda s: _session_insert_with_value(s, "invalid"), | |||||
| exc_type=ValueError, | |||||
| ), | |||||
| TestCase( | |||||
| name="session insert with invalid type", | |||||
| action=lambda s: _session_insert_with_value(s, 1), | |||||
| exc_type=TypeError, | |||||
| ), | |||||
| TestCase( | |||||
| name="insert with invalid value", | |||||
| action=lambda s: _insert_with_user(s, "invalid"), | |||||
| exc_type=ValueError, | |||||
| ), | |||||
| TestCase( | |||||
| name="insert with invalid type", | |||||
| action=lambda s: _insert_with_user(s, 1), | |||||
| exc_type=TypeError, | |||||
| ), | |||||
| ] | |||||
| for idx, c in enumerate(cases, 1): | |||||
| with pytest.raises(sa_exc.StatementError) as exc: | |||||
| with Session(engine) as session: | |||||
| c.action(session) | |||||
| assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}" | |||||
| def test_select_invalid_values(self): | |||||
| engine = sa.create_engine("sqlite://", echo=False) | |||||
| _Base.metadata.create_all(engine) | |||||
| insertion_sql = """ | |||||
| INSERT INTO users (id, name, user_type) VALUES | |||||
| (1, 'invalid_value', 'invalid'); | |||||
| """ | |||||
| with Session(engine) as session: | |||||
| session.execute(sa.text(insertion_sql)) | |||||
| session.commit() | |||||
| with pytest.raises(ValueError) as exc: | |||||
| with Session(engine) as session: | |||||
| _user = session.query(_User).filter(_User.id == 1).first() |