| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- import enum
- from typing import Generic, TypeVar
-
- from sqlalchemy import CHAR, VARCHAR, TypeDecorator
- from sqlalchemy.dialects.postgresql import UUID
-
-
class StringUUID(TypeDecorator):
    """UUID column type that always hands values back to Python as ``str``.

    On PostgreSQL the column is backed by the native ``UUID`` type; on every
    other dialect it is backed by ``CHAR(36)``.
    """

    impl = CHAR
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert a Python-side value into what is sent to the database.

        Args:
            value: ``None``, a string, or an object exposing ``.hex``
                (e.g. ``uuid.UUID``).
            dialect: The dialect in use; only ``dialect.name`` is consulted.

        Returns:
            ``None`` unchanged; ``str(value)`` on PostgreSQL; on other
            dialects a string is passed through as-is and anything else is
            bound as its ``.hex`` form.
        """
        if value is None:
            return value
        if dialect.name == "postgresql":
            return str(value)
        # process_result_value() returns plain strings, so values loaded
        # through this type (and string ids in general) come back here as
        # ``str``. Strings have no ``.hex`` attribute, so bind them unchanged
        # instead of failing with AttributeError.
        if isinstance(value, str):
            return value
        return value.hex

    def load_dialect_impl(self, dialect):
        """Pick the concrete column type: native ``UUID`` on PostgreSQL, ``CHAR(36)`` otherwise."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(CHAR(36))

    def process_result_value(self, value, dialect):
        """Return the database value as ``str``, leaving ``None`` untouched."""
        if value is None:
            return value
        return str(value)
-
-
# Type variable for the concrete ``StrEnum`` subclass an ``EnumText`` column maps to.
_E = TypeVar("_E", bound=enum.StrEnum)


class EnumText(TypeDecorator, Generic[_E]):
    """Store members of a ``StrEnum`` as their string value in a ``VARCHAR`` column.

    Values are validated against the enum when bound and converted back to
    enum members when loaded.
    """

    impl = VARCHAR
    cache_ok = True

    # Width of the backing VARCHAR column.
    _length: int
    # Enum class used to validate bound values and to build loaded values.
    _enum_class: type[_E]

    def __init__(self, enum_class: type[_E], length: int | None = None):
        """Configure the column for ``enum_class``.

        Args:
            enum_class: The ``StrEnum`` subclass stored in this column.
            length: Explicit VARCHAR width. When omitted, the width is the
                longest member value or 20, whichever is larger.

        Raises:
            ValueError: If ``enum_class`` has no members, or ``length`` is
                shorter than the longest member value.
        """
        self._enum_class = enum_class
        value_lengths = [len(member.value) for member in enum_class]
        if not value_lengths:
            # Previously surfaced as an opaque "max() ... empty" ValueError.
            raise ValueError(f"{enum_class} defines no members.")
        max_enum_value_len = max(value_lengths)
        if length is None:
            # leave some room for future longer enum values.
            self._length = max(max_enum_value_len, 20)
        elif length < max_enum_value_len:
            # A length exactly equal to the longest value is accepted.
            raise ValueError("length should be greater than enum value length.")
        else:
            self._length = length
        # TypeDecorator.__init__ instantiates ``impl``; without this call
        # ``self.impl`` stays the bare VARCHAR class rather than an instance.
        super().__init__(length=self._length)

    def process_bind_param(self, value: _E | str | None, dialect):
        """Convert an enum member (or its raw string) to the stored string.

        Raises:
            ValueError: If a string is given that is not a value of the enum.
            TypeError: If ``value`` is neither a string nor an enum member.
        """
        if value is None:
            return value
        if isinstance(value, self._enum_class):
            return value.value
        if isinstance(value, str):
            # Constructed only for validation: raises ValueError on an unknown value.
            self._enum_class(value)
            return value
        raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")

    def load_dialect_impl(self, dialect):
        """Back the column with a ``VARCHAR`` of the configured length."""
        return dialect.type_descriptor(VARCHAR(self._length))

    def process_result_value(self, value, dialect) -> _E | None:
        """Convert a stored string back to its enum member.

        Raises:
            TypeError: If the database returned a non-string, non-``None`` value.
            ValueError: If the stored string is not a value of the enum.
        """
        if value is None:
            return value
        if not isinstance(value, str):
            raise TypeError(f"expected str, got {type(value)}")
        return self._enum_class(value)

    def compare_values(self, x, y):
        """Compare two values; ``None`` is equal only to ``None``."""
        if x is None or y is None:
            return x is y
        return x == y
|