You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

types.py 2.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import enum
  2. from typing import Generic, TypeVar
  3. from sqlalchemy import CHAR, VARCHAR, TypeDecorator
  4. from sqlalchemy.dialects.postgresql import UUID
  5. class StringUUID(TypeDecorator):
  6. impl = CHAR
  7. cache_ok = True
  8. def process_bind_param(self, value, dialect):
  9. if value is None:
  10. return value
  11. elif dialect.name == "postgresql":
  12. return str(value)
  13. else:
  14. return value.hex
  15. def load_dialect_impl(self, dialect):
  16. if dialect.name == "postgresql":
  17. return dialect.type_descriptor(UUID())
  18. else:
  19. return dialect.type_descriptor(CHAR(36))
  20. def process_result_value(self, value, dialect):
  21. if value is None:
  22. return value
  23. return str(value)
  24. _E = TypeVar("_E", bound=enum.StrEnum)
  25. class EnumText(TypeDecorator, Generic[_E]):
  26. impl = VARCHAR
  27. cache_ok = True
  28. _length: int
  29. _enum_class: type[_E]
  30. def __init__(self, enum_class: type[_E], length: int | None = None):
  31. self._enum_class = enum_class
  32. max_enum_value_len = max(len(e.value) for e in enum_class)
  33. if length is not None:
  34. if length < max_enum_value_len:
  35. raise ValueError("length should be greater than enum value length.")
  36. self._length = length
  37. else:
  38. # leave some rooms for future longer enum values.
  39. self._length = max(max_enum_value_len, 20)
  40. def process_bind_param(self, value: _E | str | None, dialect):
  41. if value is None:
  42. return value
  43. if isinstance(value, self._enum_class):
  44. return value.value
  45. elif isinstance(value, str):
  46. self._enum_class(value)
  47. return value
  48. else:
  49. raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
  50. def load_dialect_impl(self, dialect):
  51. return dialect.type_descriptor(VARCHAR(self._length))
  52. def process_result_value(self, value, dialect) -> _E | None:
  53. if value is None:
  54. return value
  55. if not isinstance(value, str):
  56. raise TypeError(f"expected str, got {type(value)}")
  57. return self._enum_class(value)
  58. def compare_values(self, x, y):
  59. if x is None or y is None:
  60. return x is y
  61. return x == y