您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

encoders.py 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import dataclasses
  2. import datetime
  3. from collections import defaultdict, deque
  4. from collections.abc import Callable
  5. from decimal import Decimal
  6. from enum import Enum
  7. from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
  8. from pathlib import Path, PurePath
  9. from re import Pattern
  10. from types import GeneratorType
  11. from typing import Any, Literal, Optional, Union
  12. from uuid import UUID
  13. from pydantic import BaseModel
  14. from pydantic.networks import AnyUrl, NameEmail
  15. from pydantic.types import SecretBytes, SecretStr
  16. from pydantic_core import Url
  17. from pydantic_extra_types.color import Color
  18. def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
  19. return model.model_dump(mode=mode, **kwargs)
  20. # Taken from Pydantic v1 as is
  21. def isoformat(o: Union[datetime.date, datetime.time]) -> str:
  22. return o.isoformat()
  23. # Taken from Pydantic v1 as is
  24. # TODO: pv2 should this return strings instead?
  25. def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
  26. """
  27. Encodes a Decimal as int of there's no exponent, otherwise float
  28. This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
  29. where a integer (but not int typed) is used. Encoding this as a float
  30. results in failed round-tripping between encode and parse.
  31. Our Id type is a prime example of this.
  32. >>> decimal_encoder(Decimal("1.0"))
  33. 1.0
  34. >>> decimal_encoder(Decimal("1"))
  35. 1
  36. """
  37. if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
  38. return int(dec_value)
  39. else:
  40. return float(dec_value)
# Default encoders keyed by type: each value is a callable that turns an
# instance of the key type into a JSON-compatible value.
# jsonable_encoder consults this table first by exact type and then, through
# encoders_by_class_tuples, by isinstance, so the entry order here determines
# the order of that fallback scan.
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
    bytes: lambda o: o.decode(),  # uses bytes.decode() defaults
    Color: str,
    # Dates and times are rendered through their own isoformat().
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda td: td.total_seconds(),  # duration as seconds
    Decimal: decimal_encoder,
    Enum: lambda o: o.value,
    # Non-list iterables are materialized as lists.
    frozenset: list,
    deque: list,
    GeneratorType: list,
    # IP address types are rendered with str().
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    Pattern: lambda o: o.pattern,  # compiled regex -> its source pattern
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
    Url: str,
    AnyUrl: str,
}
  69. def generate_encoders_by_class_tuples(
  70. type_encoder_map: dict[Any, Callable[[Any], Any]],
  71. ) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
  72. encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
  73. for type_, encoder in type_encoder_map.items():
  74. encoders_by_class_tuples[encoder] += (type_,)
  75. return encoders_by_class_tuples
# ENCODERS_BY_TYPE regrouped by encoder, built once at import time;
# jsonable_encoder scans it for its isinstance-based fallback lookup.
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
  77. def jsonable_encoder(
  78. obj: Any,
  79. by_alias: bool = True,
  80. exclude_unset: bool = False,
  81. exclude_defaults: bool = False,
  82. exclude_none: bool = False,
  83. custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
  84. sqlalchemy_safe: bool = True,
  85. ) -> Any:
  86. custom_encoder = custom_encoder or {}
  87. if custom_encoder:
  88. if type(obj) in custom_encoder:
  89. return custom_encoder[type(obj)](obj)
  90. else:
  91. for encoder_type, encoder_instance in custom_encoder.items():
  92. if isinstance(obj, encoder_type):
  93. return encoder_instance(obj)
  94. if isinstance(obj, BaseModel):
  95. obj_dict = _model_dump(
  96. obj,
  97. mode="json",
  98. include=None,
  99. exclude=None,
  100. by_alias=by_alias,
  101. exclude_unset=exclude_unset,
  102. exclude_none=exclude_none,
  103. exclude_defaults=exclude_defaults,
  104. )
  105. if "__root__" in obj_dict:
  106. obj_dict = obj_dict["__root__"]
  107. return jsonable_encoder(
  108. obj_dict,
  109. exclude_none=exclude_none,
  110. exclude_defaults=exclude_defaults,
  111. sqlalchemy_safe=sqlalchemy_safe,
  112. )
  113. if dataclasses.is_dataclass(obj):
  114. # Ensure obj is a dataclass instance, not a dataclass type
  115. if not isinstance(obj, type):
  116. obj_dict = dataclasses.asdict(obj)
  117. return jsonable_encoder(
  118. obj_dict,
  119. by_alias=by_alias,
  120. exclude_unset=exclude_unset,
  121. exclude_defaults=exclude_defaults,
  122. exclude_none=exclude_none,
  123. custom_encoder=custom_encoder,
  124. sqlalchemy_safe=sqlalchemy_safe,
  125. )
  126. if isinstance(obj, Enum):
  127. return obj.value
  128. if isinstance(obj, PurePath):
  129. return str(obj)
  130. if isinstance(obj, str | int | float | type(None)):
  131. return obj
  132. if isinstance(obj, Decimal):
  133. return format(obj, "f")
  134. if isinstance(obj, dict):
  135. encoded_dict = {}
  136. for key, value in obj.items():
  137. if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (
  138. value is not None or not exclude_none
  139. ):
  140. encoded_key = jsonable_encoder(
  141. key,
  142. by_alias=by_alias,
  143. exclude_unset=exclude_unset,
  144. exclude_none=exclude_none,
  145. custom_encoder=custom_encoder,
  146. sqlalchemy_safe=sqlalchemy_safe,
  147. )
  148. encoded_value = jsonable_encoder(
  149. value,
  150. by_alias=by_alias,
  151. exclude_unset=exclude_unset,
  152. exclude_none=exclude_none,
  153. custom_encoder=custom_encoder,
  154. sqlalchemy_safe=sqlalchemy_safe,
  155. )
  156. encoded_dict[encoded_key] = encoded_value
  157. return encoded_dict
  158. if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque):
  159. encoded_list = []
  160. for item in obj:
  161. encoded_list.append(
  162. jsonable_encoder(
  163. item,
  164. by_alias=by_alias,
  165. exclude_unset=exclude_unset,
  166. exclude_defaults=exclude_defaults,
  167. exclude_none=exclude_none,
  168. custom_encoder=custom_encoder,
  169. sqlalchemy_safe=sqlalchemy_safe,
  170. )
  171. )
  172. return encoded_list
  173. if type(obj) in ENCODERS_BY_TYPE:
  174. return ENCODERS_BY_TYPE[type(obj)](obj)
  175. for encoder, classes_tuple in encoders_by_class_tuples.items():
  176. if isinstance(obj, classes_tuple):
  177. return encoder(obj)
  178. try:
  179. data = dict(obj)
  180. except Exception as e:
  181. errors: list[Exception] = []
  182. errors.append(e)
  183. try:
  184. data = vars(obj)
  185. except Exception as e:
  186. errors.append(e)
  187. raise ValueError(errors) from e
  188. return jsonable_encoder(
  189. data,
  190. by_alias=by_alias,
  191. exclude_unset=exclude_unset,
  192. exclude_defaults=exclude_defaults,
  193. exclude_none=exclude_none,
  194. custom_encoder=custom_encoder,
  195. sqlalchemy_safe=sqlalchemy_safe,
  196. )