
es_conn.py

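"""
Thin wrapper around the Elasticsearch client used by the RAG pipeline.

HuEs centralizes connection handling, retries transient failures, and papers
over API differences between pre-7, 7.x, and 8.x clusters (mapping types, the
`body` vs. `operations` bulk argument). A module-level singleton instance,
ELASTICSEARCH, is created at import time.
"""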
import re
import json
import time
import copy

import elasticsearch
from elastic_transport import ConnectionTimeout
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index

from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

es_logger.info("Elasticsearch version: " + str(elasticsearch.__version__))


@singleton
class HuEs:
    def __init__(self):
        self.info = {}
        self.conn()
        self.idxnm = settings.ES.get("index_name", "")
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connect to es.")
                    break
            except Exception as e:
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

    def version(self):
        # True when the cluster major version is 7 or newer, i.e. the cluster
        # no longer uses mapping types (doc_type).
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def upsert(self, df, idxnm=""):
        res = []
        for d in df:
            id = d["id"]
            del d["id"]
            d = {"doc": d, "doc_as_upsert": True}
            T = False
            for _ in range(10):
                try:
                    if not self.version():
                        # Pre-7 clusters still require an explicit doc type.
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            doc_type="doc",
                            refresh=True,
                            retry_on_conflict=100)
                    else:
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            refresh=True,
                            retry_on_conflict=100)
                    es_logger.info("Successfully upsert: %s" % id)
                    T = True
                    break
                except Exception as e:
                    es_logger.warning("Fail to index: " +
                                      json.dumps(d, ensure_ascii=False) + str(e))
                    if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                        time.sleep(3)
                        continue
                    self.conn()
                    T = False
            if not T:
                res.append(d)
                es_logger.error(
                    "Fail to index: " +
                    re.sub(r"[\r\n]", "", json.dumps(d, ensure_ascii=False)))
                d["id"] = id
                d["_index"] = self.idxnm
        if not res:
            return True
        return False

    def bulk(self, df, idx_nm=None):
        ids, acts = {}, []
        for d in df:
            id = d["id"] if "id" in d else d["_id"]
            ids[id] = copy.deepcopy(d)
            ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
            if "id" in d:
                del d["id"]
            if "_id" in d:
                del d["_id"]
            # Bulk actions come in pairs: action metadata, then the partial doc.
            acts.append(
                {"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
            acts.append({"doc": d, "doc_as_upsert": True})

        res = []
        for _ in range(100):
            try:
                if elasticsearch.__version__[0] < 8:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        body=acts,
                        refresh=False,
                        timeout="600s")
                else:
                    # The 8.x client renamed the `body` argument to `operations`.
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        operations=acts,
                        refresh=False,
                        timeout="600s")
                # errors == False means every action succeeded.
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]) +
                                   ":" + str(it["update"]["error"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()
        return res

    def bulk4script(self, df):
        ids, acts = {}, []
        for d in df:
            id = d["id"]
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s",
                        doc_type="doc")
                else:
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()
        return res

    def rm(self, d):
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        doc_type="doc",
                        refresh=True)
                else:
                    # 7.x+ clusters have no mapping types, so no doc_type here.
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        refresh=True)
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                es_logger.warning("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return True
                self.conn()
        es_logger.error("Fail to delete: " + str(d))
        return False

    def search(self, q, idxnm=None, src=False, timeout="2s"):
        if not isinstance(q, dict):
            q = Search().query(q).to_dict()
        for i in range(3):
            try:
                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
                                     body=q,
                                     timeout=timeout,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=src)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES search exception: " + str(e) + "【Q】:" + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES search timed out 3 times!")
        raise Exception("ES search timeout.")

    def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
        for i in range(3):
            try:
                res = self.es.sql.query(
                    body={"query": sql, "fetch_size": fetch_size},
                    format=format,
                    request_timeout=timeout)
                return res
            except ConnectionTimeout:
                es_logger.error("Timeout【Q】:" + sql)
                continue
            except Exception as e:
                raise e
        es_logger.error("ES sql timed out 3 times!")
        raise ConnectionTimeout()

    def get(self, doc_id, idxnm=None):
        for i in range(3):
            try:
                res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
                                  id=doc_id)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES get exception: " + str(e) + "【Q】:" + doc_id)
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES get timed out 3 times!")
        raise Exception("ES get timeout.")

    def updateByQuery(self, q, d):
        ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
        # Build a Painless script that copies every field in `d` from params.
        scripts = ""
        for k, v in d.items():
            scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
        ubq = ubq.script(source=scripts, params=d)
        ubq = ubq.params(refresh=False)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()
        return False

    def updateScriptByQuery(self, q, scripts, idxnm=None):
        ubq = UpdateByQuery(
            index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
        ubq = ubq.script(source=scripts)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateScriptByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()
        return False

    def deleteByQuery(self, query, idxnm=""):
        for i in range(3):
            try:
                r = self.es.delete_by_query(
                    index=idxnm if idxnm else self.idxnm,
                    refresh=True,
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                es_logger.error("ES deleteByQuery exception: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("NotFoundError") > 0:
                    return True
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def update(self, id, script, routing=None):
        for i in range(3):
            try:
                if not self.version():
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        doc_type="doc",
                        routing=routing,
                        refresh=False)
                else:
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        routing=routing,
                        refresh=False)
                return True
            except Exception as e:
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) +
                    ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") > 0:
                    continue
        return False

    def indexExist(self, idxnm):
        s = Index(idxnm if idxnm else self.idxnm, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                es_logger.error("ES indexExist: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def docExist(self, docid, idxnm=None):
        for i in range(3):
            try:
                return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                      id=docid)
            except Exception as e:
                es_logger.error("ES Doc Exist: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def createIdx(self, idxnm, mapping):
        try:
            if elasticsearch.__version__[0] < 8:
                return self.es.indices.create(idxnm, body=mapping)
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=idxnm,
                                                 settings=mapping["settings"],
                                                 mappings=mapping["mappings"])
        except Exception as e:
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))

    def deleteIdx(self, idxnm):
        try:
            return self.es.indices.delete(index=idxnm, allow_no_indices=True)
        except Exception as e:
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        # ES 7+ reports hits.total as an object; older versions as a bare int.
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getDocIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def scrollIter(self, pagesize=100, scroll_time='2m', q=None):
        # Avoid a mutable default argument; fall back to a match-all query.
        if q is None:
            q = {"query": {"match_all": {}},
                 "sort": [{"updated_at": {"order": "desc"}}]}
        for _ in range(100):
            try:
                page = self.es.search(
                    index=self.idxnm,
                    scroll=scroll_time,
                    size=pagesize,
                    body=q,
                    _source=None
                )
                break
            except Exception as e:
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
            for _ in range(100):
                try:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)
            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results returned by the last scroll
            scroll_size = len(page['hits']['hits'])


ELASTICSEARCH = HuEs()
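

# --- Usage sketch (illustrative only, not exercised by the application) ---
# A minimal walkthrough assuming a reachable cluster and an index name
# configured in rag.settings; the ids and the "title" field are hypothetical.
if __name__ == "__main__":
    from elasticsearch_dsl import Q

    # Upsert two partial documents; upsert() returns True only if all succeed.
    ok = ELASTICSEARCH.upsert([
        {"id": "doc_1", "title": "hello"},
        {"id": "doc_2", "title": "world"},
    ])

    # Term query via elasticsearch_dsl; getSource() flattens hits into dicts.
    res = ELASTICSEARCH.search(Q("term", title="hello"))
    print(ok, ELASTICSEARCH.getTotal(res), ELASTICSEARCH.getSource(res))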