Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

redis_conn.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import uuid

import trio
import valkey as redis
from valkey.lock import Lock

from rag import settings
from rag.utils import singleton
  24. class RedisMsg:
  25. def __init__(self, consumer, queue_name, group_name, msg_id, message):
  26. self.__consumer = consumer
  27. self.__queue_name = queue_name
  28. self.__group_name = group_name
  29. self.__msg_id = msg_id
  30. self.__message = json.loads(message["message"])
  31. def ack(self):
  32. try:
  33. self.__consumer.xack(self.__queue_name, self.__group_name, self.__msg_id)
  34. return True
  35. except Exception as e:
  36. logging.warning("[EXCEPTION]ack" + str(self.__queue_name) + "||" + str(e))
  37. return False
  38. def get_message(self):
  39. return self.__message
  40. def get_msg_id(self):
  41. return self.__msg_id
@singleton
class RedisDB:
    """Best-effort wrapper around a Redis/Valkey connection.

    Connection parameters come from ``settings.REDIS``.  Most operations
    catch every exception, log a warning, attempt to reopen the
    connection, and return a fallback value instead of raising.
    """

    # Handle of the registered Lua script; filled in by register_scripts().
    lua_delete_if_equal = None
    # Atomically delete KEYS[1] only when its current value equals ARGV[1].
    LUA_DELETE_IF_EQUAL_SCRIPT = """
local current_value = redis.call('get', KEYS[1])
if current_value and current_value == ARGV[1] then
redis.call('del', KEYS[1])
return 1
end
return 0
"""

    def __init__(self):
        self.REDIS = None
        self.config = settings.REDIS
        self.__open__()

    def register_scripts(self) -> None:
        """Register the class-level Lua script(s) against the current client."""
        cls = self.__class__
        client = self.REDIS
        cls.lua_delete_if_equal = client.register_script(cls.LUA_DELETE_IF_EQUAL_SCRIPT)

    def __open__(self):
        """(Re)connect to Redis; on failure logs and returns the stale client (or None)."""
        try:
            self.REDIS = redis.StrictRedis(
                host=self.config["host"].split(":")[0],
                # NOTE(review): "host" is mandatory on the line above but
                # defaulted here; a host value without ":<port>" raises
                # IndexError — confirm the expected config shape.
                port=int(self.config.get("host", ":6379").split(":")[1]),
                db=int(self.config.get("db", 1)),
                password=self.config.get("password"),
                decode_responses=True,  # work with str rather than bytes
            )
            self.register_scripts()
        except Exception:
            logging.warning("Redis can't be connected.")
        return self.REDIS

    def health(self):
        """Round-trip probe: PING, then SET/GET a short-lived key."""
        self.REDIS.ping()
        a, b = "xx", "yy"
        self.REDIS.set(a, b, 3)  # 3-second expiry on the probe key
        if self.REDIS.get(a) == b:
            return True
        # Implicitly returns None when the probe value does not match.

    def is_alive(self):
        """True once a client object exists (does not verify connectivity)."""
        return self.REDIS is not None

    def exist(self, k):
        """EXISTS k; returns None when disconnected or on error."""
        if not self.REDIS:
            return
        try:
            return self.REDIS.exists(k)
        except Exception as e:
            logging.warning("RedisDB.exist " + str(k) + " got exception: " + str(e))
            self.__open__()  # try to recover the connection for next call

    def get(self, k):
        """GET k; returns None when disconnected or on error."""
        if not self.REDIS:
            return
        try:
            return self.REDIS.get(k)
        except Exception as e:
            logging.warning("RedisDB.get " + str(k) + " got exception: " + str(e))
            self.__open__()

    def set_obj(self, k, obj, exp=3600):
        """SET k to the JSON encoding of obj with expiry exp seconds; False on error."""
        try:
            self.REDIS.set(k, json.dumps(obj, ensure_ascii=False), exp)
            return True
        except Exception as e:
            logging.warning("RedisDB.set_obj " + str(k) + " got exception: " + str(e))
            self.__open__()
            return False

    def set(self, k, v, exp=3600):
        """SET k to v with expiry exp seconds; False on error."""
        try:
            self.REDIS.set(k, v, exp)
            return True
        except Exception as e:
            logging.warning("RedisDB.set " + str(k) + " got exception: " + str(e))
            self.__open__()
            return False

    def sadd(self, key: str, member: str):
        """SADD member to set key; False on error."""
        try:
            self.REDIS.sadd(key, member)
            return True
        except Exception as e:
            logging.warning("RedisDB.sadd " + str(key) + " got exception: " + str(e))
            self.__open__()
            return False

    def srem(self, key: str, member: str):
        """SREM member from set key; False on error."""
        try:
            self.REDIS.srem(key, member)
            return True
        except Exception as e:
            logging.warning("RedisDB.srem " + str(key) + " got exception: " + str(e))
            self.__open__()
            return False

    def smembers(self, key: str):
        """SMEMBERS of key; None on error."""
        try:
            res = self.REDIS.smembers(key)
            return res
        except Exception as e:
            logging.warning(
                "RedisDB.smembers " + str(key) + " got exception: " + str(e)
            )
            self.__open__()
            return None

    def zadd(self, key: str, member: str, score: float):
        """ZADD member with score to sorted set key; False on error."""
        try:
            self.REDIS.zadd(key, {member: score})
            return True
        except Exception as e:
            logging.warning("RedisDB.zadd " + str(key) + " got exception: " + str(e))
            self.__open__()
            return False

    def zcount(self, key: str, min: float, max: float):
        """ZCOUNT of key within [min, max]; 0 on error."""
        try:
            res = self.REDIS.zcount(key, min, max)
            return res
        except Exception as e:
            logging.warning("RedisDB.zcount " + str(key) + " got exception: " + str(e))
            self.__open__()
            return 0

    def zpopmin(self, key: str, count: int):
        """ZPOPMIN up to count members from key; None on error."""
        try:
            res = self.REDIS.zpopmin(key, count)
            return res
        except Exception as e:
            logging.warning("RedisDB.zpopmin " + str(key) + " got exception: " + str(e))
            self.__open__()
            return None

    def zrangebyscore(self, key: str, min: float, max: float):
        """ZRANGEBYSCORE of key within [min, max]; None on error."""
        try:
            res = self.REDIS.zrangebyscore(key, min, max)
            return res
        except Exception as e:
            logging.warning(
                "RedisDB.zrangebyscore " + str(key) + " got exception: " + str(e)
            )
            self.__open__()
            return None

    def transaction(self, key, value, exp=3600):
        """SET key NX with expiry inside a MULTI/EXEC pipeline; False on error."""
        try:
            pipeline = self.REDIS.pipeline(transaction=True)
            pipeline.set(key, value, exp, nx=True)
            pipeline.execute()
            return True
        except Exception as e:
            logging.warning(
                "RedisDB.transaction " + str(key) + " got exception: " + str(e)
            )
            self.__open__()
            return False

    def queue_product(self, queue, message) -> bool:
        """XADD a JSON-wrapped message to a stream, retrying up to 3 times."""
        for _ in range(3):
            try:
                payload = {"message": json.dumps(message)}
                self.REDIS.xadd(queue, payload)
                return True
            except Exception as e:
                logging.exception(
                    "RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
                )
        return False

    def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg | None:
        """Read one message for a consumer group, creating the group if absent.

        https://redis.io/docs/latest/commands/xreadgroup/

        Returns a RedisMsg, or None when there is nothing to read or on error.
        """
        try:
            group_info = self.REDIS.xinfo_groups(queue_name)
            if not any(gi["name"] == group_name for gi in group_info):
                self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
            args = {
                "groupname": group_name,
                "consumername": consumer_name,
                "count": 1,
                "block": 5,  # milliseconds
                "streams": {queue_name: msg_id},
            }
            messages = self.REDIS.xreadgroup(**args)
            if not messages:
                return None
            stream, element_list = messages[0]
            if not element_list:
                return None
            msg_id, payload = element_list[0]
            res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
            return res
        except Exception as e:
            if str(e) == 'no such key':
                # Stream does not exist yet — expected, not worth logging.
                pass
            else:
                logging.exception(
                    "RedisDB.queue_consumer "
                    + str(queue_name)
                    + " got exception: "
                    + str(e)
                )
        return None

    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
        """Yield messages delivered to this consumer but never acknowledged."""
        try:
            for queue_name in queue_names:
                try:
                    group_info = self.REDIS.xinfo_groups(queue_name)
                except Exception as e:
                    if str(e) == 'no such key':
                        logging.warning(f"RedisDB.get_unacked_iterator queue {queue_name} doesn't exist")
                        continue
                    # NOTE(review): any other error falls through here with
                    # group_info unset; the resulting NameError is swallowed
                    # by the outer except below — confirm this is intended.
                if not any(gi["name"] == group_name for gi in group_info):
                    logging.warning(f"RedisDB.get_unacked_iterator queue {queue_name} group {group_name} doesn't exist")
                    continue
                current_min = 0
                while True:
                    # Reading with a concrete id (not b">") walks this
                    # consumer's pending-entries list from current_min on.
                    payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
                    if not payload:
                        break
                    current_min = payload.get_msg_id()
                    logging.info(f"RedisDB.get_unacked_iterator {queue_name} {consumer_name} {current_min}")
                    yield payload
        except Exception:
            logging.exception(
                "RedisDB.get_unacked_iterator got exception: "
            )
            self.__open__()

    def get_pending_msg(self, queue, group_name):
        """XPENDING: up to 10 pending entries of the group; [] on error."""
        try:
            messages = self.REDIS.xpending_range(queue, group_name, '-', '+', 10)
            return messages
        except Exception as e:
            if 'No such key' not in (str(e) or ''):
                logging.warning(
                    "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
                )
        return []

    def requeue_msg(self, queue: str, group_name: str, msg_id: str):
        """Re-add a pending stream entry's payload, then ack the original entry."""
        try:
            messages = self.REDIS.xrange(queue, msg_id, msg_id)
            if messages:
                self.REDIS.xadd(queue, messages[0][1])
                self.REDIS.xack(queue, group_name, msg_id)
        except Exception as e:
            # NOTE(review): log label says "get_pending_msg" — looks like a
            # copy/paste; message text left unchanged here.
            logging.warning(
                "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
            )

    def queue_info(self, queue, group_name) -> dict | None:
        """XINFO GROUPS entry for group_name, or None if absent / on error."""
        try:
            groups = self.REDIS.xinfo_groups(queue)
            for group in groups:
                if group["name"] == group_name:
                    return group
        except Exception as e:
            logging.warning(
                "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
            )
        return None

    def delete_if_equal(self, key: str, expected_value: str) -> bool:
        """Atomically delete key if its value equals expected_value.

        Runs the registered Lua script; returns True when the key was deleted.
        """
        return bool(self.lua_delete_if_equal(keys=[key], args=[expected_value], client=self.REDIS))

    def delete(self, key) -> bool:
        """DEL key; False on error."""
        try:
            self.REDIS.delete(key)
            return True
        except Exception as e:
            logging.warning("RedisDB.delete " + str(key) + " got exception: " + str(e))
            self.__open__()
            return False
# Process-wide shared connection; RedisDB is wrapped in @singleton, so this
# is the single instance every importer shares.
REDIS_CONN = RedisDB()
  301. class RedisDistributedLock:
  302. def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
  303. self.lock_key = lock_key
  304. if lock_value:
  305. self.lock_value = lock_value
  306. else:
  307. self.lock_value = str(uuid.uuid4())
  308. self.timeout = timeout
  309. self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)
  310. def acquire(self):
  311. REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)
  312. return self.lock.acquire(token=self.lock_value)
  313. async def spin_acquire(self):
  314. REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)
  315. while True:
  316. if self.lock.acquire(token=self.lock_value):
  317. break
  318. await trio.sleep(10)
  319. def release(self):
  320. REDIS_CONN.delete_if_equal(self.lock_key, self.lock_value)