Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import enum
  2. import uuid
  3. from typing import Any, Generic, TypeVar
  4. from sqlalchemy import CHAR, VARCHAR, TypeDecorator
  5. from sqlalchemy.dialects.postgresql import UUID
  6. from sqlalchemy.engine.interfaces import Dialect
  7. from sqlalchemy.sql.type_api import TypeEngine
  8. class StringUUID(TypeDecorator[uuid.UUID | str | None]):
  9. impl = CHAR
  10. cache_ok = True
  11. def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
  12. if value is None:
  13. return value
  14. elif dialect.name == "postgresql":
  15. return str(value)
  16. else:
  17. if isinstance(value, uuid.UUID):
  18. return value.hex
  19. return value
  20. def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
  21. if dialect.name == "postgresql":
  22. return dialect.type_descriptor(UUID())
  23. else:
  24. return dialect.type_descriptor(CHAR(36))
  25. def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
  26. if value is None:
  27. return value
  28. return str(value)
  29. _E = TypeVar("_E", bound=enum.StrEnum)
  30. class EnumText(TypeDecorator[_E | None], Generic[_E]):
  31. impl = VARCHAR
  32. cache_ok = True
  33. _length: int
  34. _enum_class: type[_E]
  35. def __init__(self, enum_class: type[_E], length: int | None = None):
  36. self._enum_class = enum_class
  37. max_enum_value_len = max(len(e.value) for e in enum_class)
  38. if length is not None:
  39. if length < max_enum_value_len:
  40. raise ValueError("length should be greater than enum value length.")
  41. self._length = length
  42. else:
  43. # leave some rooms for future longer enum values.
  44. self._length = max(max_enum_value_len, 20)
  45. def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
  46. if value is None:
  47. return value
  48. if isinstance(value, self._enum_class):
  49. return value.value
  50. # Since _E is bound to StrEnum which inherits from str, at this point value must be str
  51. self._enum_class(value)
  52. return value
  53. def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
  54. return dialect.type_descriptor(VARCHAR(self._length))
  55. def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
  56. if value is None:
  57. return value
  58. # Type annotation guarantees value is str at this point
  59. return self._enum_class(value)
  60. def compare_values(self, x: _E | None, y: _E | None) -> bool:
  61. if x is None or y is None:
  62. return x is y
  63. return x == y