
es_conn.py 16KB

import re
import json
import time
import copy
import elasticsearch
from elastic_transport import ConnectionTimeout
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index
from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

es_logger.info("Elasticsearch version: " + str(elasticsearch.__version__))


@singleton
class ESConnection:
    def __init__(self):
        self.info = {}
        self.conn()
        self.idxnm = settings.ES.get("index_name", "")
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        # Retry the initial connection up to 10 times, one second apart.
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connect to es.")
                    break
            except Exception as e:
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

    def version(self):
        # True when the cluster major version is 7 or newer, i.e. when the
        # legacy doc_type parameter must be omitted from requests.
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def health(self):
        return dict(self.es.cluster.health())

    def upsert(self, df, idxnm=""):
        # Upsert each document individually, retrying up to 10 times per
        # document; returns True only if every document was indexed.
        res = []
        for d in df:
            id = d["id"]
            del d["id"]
            d = {"doc": d, "doc_as_upsert": "true"}
            T = False
            for _ in range(10):
                try:
                    if not self.version():
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            doc_type="doc",
                            refresh=True,
                            retry_on_conflict=100)
                    else:
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            refresh=True,
                            retry_on_conflict=100)
                    es_logger.info("Successfully upsert: %s" % id)
                    T = True
                    break
                except Exception as e:
                    es_logger.warning("Fail to index: " +
                                      json.dumps(d, ensure_ascii=False) + str(e))
                    if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                        time.sleep(3)
                        continue
                    self.conn()
                    T = False
            if not T:
                res.append(d)
                es_logger.error(
                    "Fail to index: " +
                    re.sub(r"[\r\n]", "", json.dumps(d, ensure_ascii=False)))
                d["id"] = id
                d["_index"] = self.idxnm
        if not res:
            return True
        return False
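
    # Illustrative call (the document fields besides "id" are hypothetical):
    #   ok = ELASTICSEARCH.upsert([{"id": "doc-1", "title": "hello"}])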

    def bulk(self, df, idx_nm=None):
        ids, acts = {}, []
        for d in df:
            id = d["id"] if "id" in d else d["_id"]
            ids[id] = copy.deepcopy(d)
            ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
            if "id" in d:
                del d["id"]
            if "_id" in d:
                del d["_id"]
            # The Bulk API expects alternating metadata and document lines;
            # retry_on_conflict belongs inside the "update" metadata object.
            acts.append({"update": {"_id": id,
                                    "_index": ids[id]["_index"],
                                    "retry_on_conflict": 100}})
            acts.append({"doc": d, "doc_as_upsert": "true"})

        res = []
        for _ in range(100):
            try:
                if elasticsearch.__version__[0] < 8:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        body=acts,
                        refresh=False,
                        timeout="600s")
                else:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        operations=acts,
                        refresh=False,
                        timeout="600s")
                # r["errors"] is False when every action succeeded.
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]) +
                                   ":" + str(it["update"]["error"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()
        return res
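
    # The action list built in bulk() pairs each metadata line with a
    # partial document, e.g. (field values are illustrative):
    #   {"update": {"_id": "doc-1", "_index": "idx", "retry_on_conflict": 100}}
    #   {"doc": {"title": "hello"}, "doc_as_upsert": "true"}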

    def bulk4script(self, df):
        # Like bulk(), but each entry supplies its own update script under
        # d["script"], with the original document kept under d["raw"].
        ids, acts = {}, []
        for d in df:
            id = d["id"]
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s",
                        doc_type="doc")
                else:
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()
        return res

    def rm(self, d):
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        doc_type="doc",
                        refresh=True)
                else:
                    # doc_type is gone from the delete API on 7.x+ and must
                    # not be passed here.
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        refresh=True)
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                es_logger.warning("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return True
                self.conn()
        es_logger.error("Fail to delete: " + str(d))
        return False

    def search(self, q, idxnm=None, src=False, timeout="2s"):
        if not isinstance(q, dict):
            q = Search().query(q).to_dict()
        for i in range(3):
            try:
                res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
                                     body=q,
                                     timeout=timeout,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=src)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES search exception: " + str(e) + "【Q】:" + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")
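
    # Illustrative call with an elasticsearch_dsl query object (the field
    # name is hypothetical):
    #   from elasticsearch_dsl import Q
    #   res = ELASTICSEARCH.search(Q("match", title="hello"), src=True)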

    def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
        for i in range(3):
            try:
                res = self.es.sql.query(
                    body={"query": sql, "fetch_size": fetch_size},
                    format=format,
                    request_timeout=timeout)
                return res
            except ConnectionTimeout:
                es_logger.error("Timeout【Q】:" + sql)
                continue
            except Exception as e:
                raise e
        es_logger.error("ES SQL timeout for 3 times!")
        raise ConnectionTimeout()
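
    # Illustrative call (the index and column names are hypothetical):
    #   rows = ELASTICSEARCH.sql("SELECT doc_id FROM my_index LIMIT 10")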

    def get(self, doc_id, idxnm=None):
        for i in range(3):
            try:
                res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
                                  id=doc_id)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES get exception: " + str(e) + "【Q】:" + doc_id)
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES get timeout for 3 times!")
        raise Exception("ES get timeout.")

    def updateByQuery(self, q, d):
        # Build a painless script that copies every field of d onto the
        # matched documents.
        ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
        scripts = ""
        for k, v in d.items():
            scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
        ubq = ubq.script(source=scripts, params=d)
        ubq = ubq.params(refresh=False)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()
        return False
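
    # Illustrative call (field names are hypothetical); this builds the
    # script "ctx._source.status = params.status;" with params {"status": 1}:
    #   from elasticsearch_dsl import Q
    #   ELASTICSEARCH.updateByQuery(Q("term", kb_id="kb-1"), {"status": 1})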

    def updateScriptByQuery(self, q, scripts, idxnm=None):
        ubq = UpdateByQuery(
            index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
        ubq = ubq.script(source=scripts)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateScriptByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()
        return False

    def deleteByQuery(self, query, idxnm=""):
        for i in range(3):
            try:
                r = self.es.delete_by_query(
                    index=idxnm if idxnm else self.idxnm,
                    refresh=True,
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                es_logger.error("ES deleteByQuery exception: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("NotFoundError") > 0:
                    return True
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False
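
    # Illustrative call (the field name is hypothetical):
    #   from elasticsearch_dsl import Q
    #   ELASTICSEARCH.deleteByQuery(Q("term", doc_id="doc-1"))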

    def update(self, id, script, routing=None):
        # `script` is passed through as the full update request body
        # (e.g. {"doc": {...}}), serialized before being sent.
        for i in range(3):
            try:
                if not self.version():
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        doc_type="doc",
                        routing=routing,
                        refresh=False)
                else:
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        routing=routing,
                        refresh=False)
                return True
            except Exception as e:
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) +
                    ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") > 0:
                    continue
        return False

    def indexExist(self, idxnm):
        s = Index(idxnm if idxnm else self.idxnm, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                es_logger.error("ES indexExist: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def docExist(self, docid, idxnm=None):
        for i in range(3):
            try:
                return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                      id=docid)
            except Exception as e:
                es_logger.error("ES Doc Exist: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def createIdx(self, idxnm, mapping):
        try:
            if elasticsearch.__version__[0] < 8:
                return self.es.indices.create(idxnm, body=mapping)
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=idxnm,
                                                 settings=mapping["settings"],
                                                 mappings=mapping["mappings"])
        except Exception as e:
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
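
    # `mapping` carries the standard index-creation sections, e.g.
    # (values are illustrative):
    #   {"settings": {"number_of_shards": 1},
    #    "mappings": {"properties": {"title": {"type": "text"}}}}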

    def deleteIdx(self, idxnm):
        try:
            # 8.x clients require the index to be passed as a keyword.
            return self.es.indices.delete(index=idxnm, allow_no_indices=True)
        except Exception as e:
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getDocIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def scrollIter(self, pagesize=100, scroll_time='2m', q={
            "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
        for _ in range(100):
            try:
                page = self.es.search(
                    index=self.idxnm,
                    scroll=scroll_time,
                    size=pagesize,
                    body=q,
                    _source=None
                )
                break
            except Exception as e:
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
            for _ in range(100):
                try:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)
            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results returned by the last scroll
            scroll_size = len(page['hits']['hits'])
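
    # Illustrative iteration (the default query sorts on an "updated_at"
    # field, which must exist in the mapping):
    #   for hits in ELASTICSEARCH.scrollIter(pagesize=500):
    #       for hit in hits:
    #           print(hit["_id"])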

ELASTICSEARCH = ESConnection()
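

# A minimal smoke-test sketch, assuming a reachable cluster configured in
# rag.settings and an existing index; the query field "title" is hypothetical.
if __name__ == "__main__":
    from elasticsearch_dsl import Q

    print(ELASTICSEARCH.health())
    res = ELASTICSEARCH.search(Q("match", title="hello"), src=True)
    print(ELASTICSEARCH.getTotal(res), ELASTICSEARCH.getDocIds(res))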