Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

db_utils.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import operator
  17. from functools import reduce
  18. from typing import Dict, Type, Union
  19. from playhouse.pool import PooledMySQLDatabase
  20. from api.utils import current_timestamp, timestamp_to_date
  21. from api.db.db_models import DB, DataBaseModel
  22. from api.db.runtime_config import RuntimeConfig
  23. from api.utils.log_utils import getLogger
  24. from enum import Enum
# Module-level logger obtained from api.utils.log_utils.getLogger().
# NOTE(review): not referenced by any function in the visible part of this file.
LOGGER = getLogger()
  26. @DB.connection_context()
  27. def bulk_insert_into_db(model, data_source, replace_on_conflict=False):
  28. DB.create_tables([model])
  29. for i, data in enumerate(data_source):
  30. current_time = current_timestamp() + i
  31. current_date = timestamp_to_date(current_time)
  32. if 'create_time' not in data:
  33. data['create_time'] = current_time
  34. data['create_date'] = timestamp_to_date(data['create_time'])
  35. data['update_time'] = current_time
  36. data['update_date'] = current_date
  37. preserve = tuple(data_source[0].keys() - {'create_time', 'create_date'})
  38. batch_size = 1000
  39. for i in range(0, len(data_source), batch_size):
  40. with DB.atomic():
  41. query = model.insert_many(data_source[i:i + batch_size])
  42. if replace_on_conflict:
  43. if isinstance(DB, PooledMySQLDatabase):
  44. query = query.on_conflict(preserve=preserve)
  45. else:
  46. query = query.on_conflict(conflict_target="id", preserve=preserve)
  47. query.execute()
  48. def get_dynamic_db_model(base, job_id):
  49. return type(base.model(
  50. table_index=get_dynamic_tracking_table_index(job_id=job_id)))
  51. def get_dynamic_tracking_table_index(job_id):
  52. return job_id[:8]
  53. def fill_db_model_object(model_object, human_model_dict):
  54. for k, v in human_model_dict.items():
  55. attr_name = 'f_%s' % k
  56. if hasattr(model_object.__class__, attr_name):
  57. setattr(model_object, attr_name, v)
  58. return model_object
  59. # https://docs.peewee-orm.com/en/latest/peewee/query_operators.html
  60. supported_operators = {
  61. '==': operator.eq,
  62. '<': operator.lt,
  63. '<=': operator.le,
  64. '>': operator.gt,
  65. '>=': operator.ge,
  66. '!=': operator.ne,
  67. '<<': operator.lshift,
  68. '>>': operator.rshift,
  69. '%': operator.mod,
  70. '**': operator.pow,
  71. '^': operator.xor,
  72. '~': operator.inv,
  73. }
  74. def query_dict2expression(
  75. model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]):
  76. expression = []
  77. for field, value in query.items():
  78. if not isinstance(value, (list, tuple)):
  79. value = ('==', value)
  80. op, *val = value
  81. field = getattr(model, f'f_{field}')
  82. value = supported_operators[op](
  83. field, val[0]) if op in supported_operators else getattr(
  84. field, op)(
  85. *val)
  86. expression.append(value)
  87. return reduce(operator.iand, expression)
  88. def query_db(model: Type[DataBaseModel], limit: int = 0, offset: int = 0,
  89. query: dict = None, order_by: Union[str, list, tuple] = None):
  90. data = model.select()
  91. if query:
  92. data = data.where(query_dict2expression(model, query))
  93. count = data.count()
  94. if not order_by:
  95. order_by = 'create_time'
  96. if not isinstance(order_by, (list, tuple)):
  97. order_by = (order_by, 'asc')
  98. order_by, order = order_by
  99. order_by = getattr(model, f'f_{order_by}')
  100. order_by = getattr(order_by, order)()
  101. data = data.order_by(order_by)
  102. if limit > 0:
  103. data = data.limit(limit)
  104. if offset > 0:
  105. data = data.offset(offset)
  106. return list(data), count