
es_conn.py

import re
import json
import time
import copy
import elasticsearch
from elastic_transport import ConnectionTimeout
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index
from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

es_logger.info("Elasticsearch version: " + str(elasticsearch.__version__))


@singleton
class ESConnection:
    def __init__(self):
        self.info = {}
        self.conn()
        self.idxnm = settings.ES.get("index_name", "")
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        # Retry the initial connection up to 10 times before giving up.
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(settings.ES["username"], settings.ES["password"])
                    if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connect to es.")
                    break
            except Exception as e:
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

    def version(self):
        # Despite the name, this returns a boolean: True when the cluster's
        # major version is 7 or newer, which selects the doc_type-free APIs.
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def health(self):
        return dict(self.es.cluster.health())
  43. def upsert(self, df, idxnm=""):
  44. res = []
  45. for d in df:
  46. id = d["id"]
  47. del d["id"]
  48. d = {"doc": d, "doc_as_upsert": "true"}
  49. T = False
  50. for _ in range(10):
  51. try:
  52. if not self.version():
  53. r = self.es.update(
  54. index=(
  55. self.idxnm if not idxnm else idxnm),
  56. body=d,
  57. id=id,
  58. doc_type="doc",
  59. refresh=True,
  60. retry_on_conflict=100)
  61. else:
  62. r = self.es.update(
  63. index=(
  64. self.idxnm if not idxnm else idxnm),
  65. body=d,
  66. id=id,
  67. refresh=True,
  68. retry_on_conflict=100)
  69. es_logger.info("Successfully upsert: %s" % id)
  70. T = True
  71. break
  72. except Exception as e:
  73. es_logger.warning("Fail to index: " +
  74. json.dumps(d, ensure_ascii=False) + str(e))
  75. if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
  76. time.sleep(3)
  77. continue
  78. self.conn()
  79. T = False
  80. if not T:
  81. res.append(d)
  82. es_logger.error(
  83. "Fail to index: " +
  84. re.sub(
  85. "[\r\n]",
  86. "",
  87. json.dumps(
  88. d,
  89. ensure_ascii=False)))
  90. d["id"] = id
  91. d["_index"] = self.idxnm
  92. if not res:
  93. return True
  94. return False
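
    # A minimal usage sketch for upsert(), assuming the module-level
    # ELASTICSEARCH singleton defined at the bottom of this file; the field
    # names are illustrative only, not required by this module:
    #
    #   docs = [{"id": "doc_0", "title": "hello", "content_ltks": "hello world"}]
    #   if not ELASTICSEARCH.upsert(docs):
    #       es_logger.error("Some documents failed to upsert.")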

    def bulk(self, df, idx_nm=None):
        # Build the bulk body as alternating action/partial-document lines.
        ids, acts = {}, []
        for d in df:
            id = d["id"] if "id" in d else d["_id"]
            ids[id] = copy.deepcopy(d)
            ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
            if "id" in d:
                del d["id"]
            if "_id" in d:
                del d["_id"]
            acts.append(
                {"update": {"_id": id, "_index": ids[id]["_index"]},
                 "retry_on_conflict": 100})
            acts.append({"doc": d, "doc_as_upsert": "true"})

        res = []
        for _ in range(100):
            try:
                if elasticsearch.__version__[0] < 8:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        body=acts,
                        refresh=False,
                        timeout="600s")
                else:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        operations=acts,
                        refresh=False,
                        timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]) +
                                   ":" + str(it["update"]["error"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()
        return res
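
    # Sketch of bulk() usage, again assuming the ELASTICSEARCH singleton;
    # each returned entry is an "<_id>:<error>" string, and the field names
    # are illustrative:
    #
    #   errs = ELASTICSEARCH.bulk([
    #       {"id": "doc_0", "title": "first"},
    #       {"id": "doc_1", "title": "second"},
    #   ])
    #   for err in errs:
    #       es_logger.error(err)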

    def bulk4script(self, df):
        # Like bulk(), but each entry carries a scripted update ("script")
        # instead of a partial document.
        ids, acts = {}, []
        for d in df:
            id = d["id"]
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s",
                        doc_type="doc")
                else:
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]))
                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()
        return res

    def rm(self, d):
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        doc_type="doc",
                        refresh=True)
                else:
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        refresh=True,
                        doc_type="_doc")
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                es_logger.warning("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    # The document is already gone; treat as success.
                    return True
                self.conn()
        es_logger.error("Fail to delete: " + str(d))
        return False

    def search(self, q, idxnms=None, src=False, timeout="2s"):
        if not isinstance(q, dict):
            q = Search().query(q).to_dict()
        if isinstance(idxnms, str):
            idxnms = idxnms.split(",")
        for i in range(3):
            try:
                res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
                                     body=q,
                                     timeout=timeout,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=src)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES search exception: " + str(e) + "【Q】:" + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")
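
    # search() accepts either a raw body dict or an elasticsearch_dsl query
    # object. A hedged sketch, where the "content_ltks" field is an
    # assumption made for illustration:
    #
    #   from elasticsearch_dsl import Q
    #   res = ELASTICSEARCH.search(Q("match", content_ltks="hello"),
    #                              src=True, timeout="5s")
    #   print(ELASTICSEARCH.getTotal(res), ELASTICSEARCH.getDocIds(res))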

    def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
        for i in range(3):
            try:
                res = self.es.sql.query(
                    body={"query": sql, "fetch_size": fetch_size},
                    format=format,
                    request_timeout=timeout)
                return res
            except ConnectionTimeout:
                es_logger.error("Timeout【Q】:" + sql)
                continue
            except Exception as e:
                raise e
        es_logger.error("ES search timeout for 3 times!")
        raise ConnectionTimeout()
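
    # ES SQL treats each index as a table. A sketch; "my_index" and the
    # "doc_id" column stand in for whatever the deployment actually uses:
    #
    #   rows = ELASTICSEARCH.sql("SELECT doc_id FROM my_index LIMIT 10")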

    def get(self, doc_id, idxnm=None):
        for i in range(3):
            try:
                res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
                                  id=doc_id)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES get exception: " + str(e) + "【Q】:" + doc_id)
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")

    def updateByQuery(self, q, d):
        ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
        scripts = ""
        for k, v in d.items():
            scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
        ubq = ubq.script(source=scripts, params=d)
        ubq = ubq.params(refresh=False)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()
        return False
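
    # updateByQuery() turns each key of d into one painless assignment
    # ("ctx._source.k = params.k;"). A sketch with illustrative names:
    #
    #   from elasticsearch_dsl import Q
    #   ELASTICSEARCH.updateByQuery(Q("term", doc_id="doc_0"), {"status": 1})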

    def updateScriptByQuery(self, q, scripts, idxnm=None):
        ubq = UpdateByQuery(
            index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
        ubq = ubq.script(source=scripts)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateScriptByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                self.conn()
        return False

    def deleteByQuery(self, query, idxnm=""):
        for i in range(3):
            try:
                r = self.es.delete_by_query(
                    index=idxnm if idxnm else self.idxnm,
                    refresh=True,
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                es_logger.error("ES deleteByQuery exception: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("NotFoundError") > 0:
                    return True
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False
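
    # deleteByQuery() wraps the query argument in a Search() body, so it
    # should be an elasticsearch_dsl query object. Sketch, with "kb_id" as
    # an illustrative field:
    #
    #   from elasticsearch_dsl import Q
    #   ELASTICSEARCH.deleteByQuery(Q("term", kb_id="kb_0"))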

    def update(self, id, script, routing=None):
        for i in range(3):
            try:
                if not self.version():
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        doc_type="doc",
                        routing=routing,
                        refresh=False)
                else:
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        routing=routing,
                        refresh=False)  # doc_type="_doc" is implicit on 7+
                return True
            except Exception as e:
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) +
                    ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") > 0:
                    continue
        return False

    def indexExist(self, idxnm):
        s = Index(idxnm if idxnm else self.idxnm, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                es_logger.error("ES indexExist exception: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def docExist(self, docid, idxnm=None):
        for i in range(3):
            try:
                return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                      id=docid)
            except Exception as e:
                es_logger.error("ES Doc Exist: " + str(e))
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    def createIdx(self, idxnm, mapping):
        try:
            if elasticsearch.__version__[0] < 8:
                return self.es.indices.create(idxnm, body=mapping)
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=idxnm,
                                                 settings=mapping["settings"],
                                                 mappings=mapping["mappings"])
        except Exception as e:
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))
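
    # createIdx() expects mapping to carry "settings" and "mappings" keys on
    # an 8.x client; earlier clients take the whole dict as the body. Sketch
    # with an illustrative index name and field:
    #
    #   mapping = {"settings": {"number_of_shards": 2},
    #              "mappings": {"properties": {"title": {"type": "text"}}}}
    #   ELASTICSEARCH.createIdx("my_index", mapping)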

    def deleteIdx(self, idxnm):
        try:
            return self.es.indices.delete(idxnm, allow_no_indices=True)
        except Exception as e:
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        # hits.total is a dict on ES >= 7 and a bare integer on older versions.
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getDocIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def scrollIter(self, pagesize=100, scroll_time='2m', q={
            "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
        for _ in range(100):
            try:
                page = self.es.search(
                    index=self.idxnm,
                    scroll=scroll_time,
                    size=pagesize,
                    body=q,
                    _source=None
                )
                break
            except Exception as e:
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
            for _ in range(100):
                try:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)
            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results returned by the last scroll
            scroll_size = len(page['hits']['hits'])


ELASTICSEARCH = ESConnection()
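
# A hedged end-to-end sketch of the scroll helper; process() stands in for
# whatever the caller does with each hit:
#
#   for hits in ELASTICSEARCH.scrollIter(pagesize=500):
#       for h in hits:
#           process(h["_source"])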