Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

redis_conn.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import json
  18. import uuid
  19. import valkey as redis
  20. from rag import settings
  21. from rag.utils import singleton
  22. from valkey.lock import Lock
  23. class RedisMsg:
  24. def __init__(self, consumer, queue_name, group_name, msg_id, message):
  25. self.__consumer = consumer
  26. self.__queue_name = queue_name
  27. self.__group_name = group_name
  28. self.__msg_id = msg_id
  29. self.__message = json.loads(message["message"])
  30. def ack(self):
  31. try:
  32. self.__consumer.xack(self.__queue_name, self.__group_name, self.__msg_id)
  33. return True
  34. except Exception as e:
  35. logging.warning("[EXCEPTION]ack" + str(self.__queue_name) + "||" + str(e))
  36. return False
  37. def get_message(self):
  38. return self.__message
  39. def get_msg_id(self):
  40. return self.__msg_id
  41. @singleton
  42. class RedisDB:
  43. lua_delete_if_equal = None
  44. LUA_DELETE_IF_EQUAL_SCRIPT = """
  45. local current_value = redis.call('get', KEYS[1])
  46. if current_value and current_value == ARGV[1] then
  47. redis.call('del', KEYS[1])
  48. return 1
  49. end
  50. return 0
  51. """
  52. def __init__(self):
  53. self.REDIS = None
  54. self.config = settings.REDIS
  55. self.__open__()
  56. def register_scripts(self) -> None:
  57. cls = self.__class__
  58. client = self.REDIS
  59. cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT)
  60. def __open__(self):
  61. try:
  62. self.REDIS = redis.StrictRedis(
  63. host=self.config["host"].split(":")[0],
  64. port=int(self.config.get("host", ":6379").split(":")[1]),
  65. db=int(self.config.get("db", 1)),
  66. password=self.config.get("password"),
  67. decode_responses=True,
  68. )
  69. self.register_scripts()
  70. except Exception:
  71. logging.warning("Redis can't be connected.")
  72. return self.REDIS
  73. def health(self):
  74. self.REDIS.ping()
  75. a, b = "xx", "yy"
  76. self.REDIS.set(a, b, 3)
  77. if self.REDIS.get(a) == b:
  78. return True
  79. def is_alive(self):
  80. return self.REDIS is not None
  81. def exist(self, k):
  82. if not self.REDIS:
  83. return
  84. try:
  85. return self.REDIS.exists(k)
  86. except Exception as e:
  87. logging.warning("RedisDB.exist " + str(k) + " got exception: " + str(e))
  88. self.__open__()
  89. def get(self, k):
  90. if not self.REDIS:
  91. return
  92. try:
  93. return self.REDIS.get(k)
  94. except Exception as e:
  95. logging.warning("RedisDB.get " + str(k) + " got exception: " + str(e))
  96. self.__open__()
  97. def set_obj(self, k, obj, exp=3600):
  98. try:
  99. self.REDIS.set(k, json.dumps(obj, ensure_ascii=False), exp)
  100. return True
  101. except Exception as e:
  102. logging.warning("RedisDB.set_obj " + str(k) + " got exception: " + str(e))
  103. self.__open__()
  104. return False
  105. def set(self, k, v, exp=3600):
  106. try:
  107. self.REDIS.set(k, v, exp)
  108. return True
  109. except Exception as e:
  110. logging.warning("RedisDB.set " + str(k) + " got exception: " + str(e))
  111. self.__open__()
  112. return False
  113. def sadd(self, key: str, member: str):
  114. try:
  115. self.REDIS.sadd(key, member)
  116. return True
  117. except Exception as e:
  118. logging.warning("RedisDB.sadd " + str(key) + " got exception: " + str(e))
  119. self.__open__()
  120. return False
  121. def srem(self, key: str, member: str):
  122. try:
  123. self.REDIS.srem(key, member)
  124. return True
  125. except Exception as e:
  126. logging.warning("RedisDB.srem " + str(key) + " got exception: " + str(e))
  127. self.__open__()
  128. return False
  129. def smembers(self, key: str):
  130. try:
  131. res = self.REDIS.smembers(key)
  132. return res
  133. except Exception as e:
  134. logging.warning(
  135. "RedisDB.smembers " + str(key) + " got exception: " + str(e)
  136. )
  137. self.__open__()
  138. return None
  139. def zadd(self, key: str, member: str, score: float):
  140. try:
  141. self.REDIS.zadd(key, {member: score})
  142. return True
  143. except Exception as e:
  144. logging.warning("RedisDB.zadd " + str(key) + " got exception: " + str(e))
  145. self.__open__()
  146. return False
  147. def zcount(self, key: str, min: float, max: float):
  148. try:
  149. res = self.REDIS.zcount(key, min, max)
  150. return res
  151. except Exception as e:
  152. logging.warning("RedisDB.zcount " + str(key) + " got exception: " + str(e))
  153. self.__open__()
  154. return 0
  155. def zpopmin(self, key: str, count: int):
  156. try:
  157. res = self.REDIS.zpopmin(key, count)
  158. return res
  159. except Exception as e:
  160. logging.warning("RedisDB.zpopmin " + str(key) + " got exception: " + str(e))
  161. self.__open__()
  162. return None
  163. def zrangebyscore(self, key: str, min: float, max: float):
  164. try:
  165. res = self.REDIS.zrangebyscore(key, min, max)
  166. return res
  167. except Exception as e:
  168. logging.warning(
  169. "RedisDB.zrangebyscore " + str(key) + " got exception: " + str(e)
  170. )
  171. self.__open__()
  172. return None
  173. def transaction(self, key, value, exp=3600):
  174. try:
  175. pipeline = self.REDIS.pipeline(transaction=True)
  176. pipeline.set(key, value, exp, nx=True)
  177. pipeline.execute()
  178. return True
  179. except Exception as e:
  180. logging.warning(
  181. "RedisDB.transaction " + str(key) + " got exception: " + str(e)
  182. )
  183. self.__open__()
  184. return False
  185. def queue_product(self, queue, message) -> bool:
  186. for _ in range(3):
  187. try:
  188. payload = {"message": json.dumps(message)}
  189. self.REDIS.xadd(queue, payload)
  190. return True
  191. except Exception as e:
  192. logging.exception(
  193. "RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
  194. )
  195. return False
  196. def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
  197. """https://redis.io/docs/latest/commands/xreadgroup/"""
  198. try:
  199. group_info = self.REDIS.xinfo_groups(queue_name)
  200. if not any(gi["name"] == group_name for gi in group_info):
  201. self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
  202. args = {
  203. "groupname": group_name,
  204. "consumername": consumer_name,
  205. "count": 1,
  206. "block": 5,
  207. "streams": {queue_name: msg_id},
  208. }
  209. messages = self.REDIS.xreadgroup(**args)
  210. if not messages:
  211. return None
  212. stream, element_list = messages[0]
  213. if not element_list:
  214. return None
  215. msg_id, payload = element_list[0]
  216. res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
  217. return res
  218. except Exception as e:
  219. if str(e) == 'no such key':
  220. pass
  221. else:
  222. logging.exception(
  223. "RedisDB.queue_consumer "
  224. + str(queue_name)
  225. + " got exception: "
  226. + str(e)
  227. )
  228. return None
  229. def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
  230. try:
  231. for queue_name in queue_names:
  232. try:
  233. group_info = self.REDIS.xinfo_groups(queue_name)
  234. except Exception as e:
  235. if str(e) == 'no such key':
  236. logging.warning(f"RedisDB.get_unacked_iterator queue {queue_name} doesn't exist")
  237. continue
  238. if not any(gi["name"] == group_name for gi in group_info):
  239. logging.warning(f"RedisDB.get_unacked_iterator queue {queue_name} group {group_name} doesn't exist")
  240. continue
  241. current_min = 0
  242. while True:
  243. payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
  244. if not payload:
  245. break
  246. current_min = payload.get_msg_id()
  247. logging.info(f"RedisDB.get_unacked_iterator {queue_name} {consumer_name} {current_min}")
  248. yield payload
  249. except Exception:
  250. logging.exception(
  251. "RedisDB.get_unacked_iterator got exception: "
  252. )
  253. self.__open__()
  254. def queue_info(self, queue, group_name) -> dict | None:
  255. try:
  256. groups = self.REDIS.xinfo_groups(queue)
  257. for group in groups:
  258. if group["name"] == group_name:
  259. return group
  260. except Exception as e:
  261. logging.warning(
  262. "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
  263. )
  264. return None
  265. def delete_if_equal(self, key: str, expected_value: str) -> bool:
  266. """
  267. Do follwing atomically:
  268. Delete a key if its value is equals to the given one, do nothing otherwise.
  269. """
  270. return bool(self.lua_delete_if_equal(keys=[key], args=[expected_value], client=self.REDIS))
# Module-level shared client wrapper; constructing it runs RedisDB.__open__
# at import time. `RedisDB` is wrapped by `singleton` from rag.utils —
# presumably repeated RedisDB() calls return this same object; confirm there.
REDIS_CONN = RedisDB()
  272. class RedisDistributedLock:
  273. def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
  274. self.lock_key = lock_key
  275. if lock_value:
  276. self.lock_value = lock_value
  277. else:
  278. self.lock_value = str(uuid.uuid4())
  279. self.timeout = timeout
  280. self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)
  281. def acquire(self):
  282. REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)
  283. return self.lock.acquire(token=self.lock_value)
  284. def release(self):
  285. return self.lock.release()
  286. def __enter__(self):
  287. self.acquire()
  288. def __exit__(self, exception_type, exception_value, exception_traceback):
  289. self.release()