import re
import json
import time
import copy

import elasticsearch
from elastic_transport import ConnectionTimeout
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Search, Index
from rag.settings import es_logger
from rag import settings
from rag.utils import singleton

es_logger.info("Elasticsearch version: " + str(elasticsearch.__version__))


@singleton
class ESConnection:
    def __init__(self):
        self.info = {}
        self.es = None
        self.conn()
        self.idxnm = settings.ES.get("index_name", "")
        # conn() may exhaust its retries without ever creating a client,
        # so guard against self.es being None before pinging.
        if not self.es or not self.es.ping():
            raise Exception("Can't connect to ES cluster")

    def conn(self):
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    es_logger.info("Connect to es.")
                    break
            except Exception as e:
                es_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)

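    # A hedged sketch of the configuration this class assumes; the keys are
    # inferred from the reads above and the values are hypothetical:
    #   settings.ES = {
    #       "hosts": "http://es01:9200,http://es02:9200",  # comma-separated
    #       "username": "elastic",    # optional, used together with "password"
    #       "password": "changeme",   # optional
    #       "index_name": "docs",     # default index for this connection
    #   }
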
    def version(self):
        """Return True if the cluster major version is 7 or above."""
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def health(self):
        return dict(self.es.cluster.health())

    def upsert(self, df, idxnm=""):
        """Upsert every document in df; return True only if all succeed."""
        res = []
        for d in df:
            id = d["id"]
            del d["id"]
            d = {"doc": d, "doc_as_upsert": "true"}
            T = False
            for _ in range(10):
                try:
                    if not self.version():
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            doc_type="doc",
                            refresh=True,
                            retry_on_conflict=100)
                    else:
                        r = self.es.update(
                            index=(self.idxnm if not idxnm else idxnm),
                            body=d,
                            id=id,
                            refresh=True,
                            retry_on_conflict=100)
                    es_logger.info("Successfully upsert: %s" % id)
                    T = True
                    break
                except Exception as e:
                    es_logger.warning("Fail to index: " +
                                      json.dumps(d, ensure_ascii=False) + str(e))
                    if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                        time.sleep(3)
                        continue
                    self.conn()
                    T = False

            if not T:
                res.append(d)
                es_logger.error(
                    "Fail to index: " +
                    re.sub(r"[\r\n]", "", json.dumps(d, ensure_ascii=False)))
                d["id"] = id
                d["_index"] = self.idxnm

        return not res

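    # Hedged usage sketch for upsert (every field besides "id" is a
    # hypothetical example):
    #   ok = ELASTICSEARCH.upsert([{"id": "doc-1", "title": "hello world"}])
    #   # ok is True only when every document was indexed
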
    def bulk(self, df, idx_nm=None):
        ids, acts = {}, []
        for d in df:
            id = d["id"] if "id" in d else d["_id"]
            ids[id] = copy.deepcopy(d)
            ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm
            if "id" in d:
                del d["id"]
            if "_id" in d:
                del d["_id"]
            # The bulk API expects pairs of entries: an action header
            # followed by the document body.
            acts.append(
                {"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100})
            acts.append({"doc": d, "doc_as_upsert": "true"})

        res = []
        for _ in range(100):
            try:
                if elasticsearch.__version__[0] < 8:
                    r = self.es.bulk(
                        index=(self.idxnm if not idx_nm else idx_nm),
                        body=acts,
                        refresh=False,
                        timeout="600s")
                else:
                    r = self.es.bulk(index=(self.idxnm if not idx_nm else idx_nm),
                                     operations=acts,
                                     refresh=False, timeout="600s")
                if not r["errors"]:
                    return res

                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]) +
                                   ":" + str(it["update"]["error"]))

                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()

        return res

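    # Hedged usage sketch for bulk (documents are hypothetical; the return
    # value is a list of "<id>:<error>" strings, empty on full success):
    #   errs = ELASTICSEARCH.bulk([{"id": "a", "txt": "x"},
    #                              {"id": "b", "txt": "y"}])
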
    def bulk4script(self, df):
        ids, acts = {}, []
        for d in df:
            id = d["id"]
            ids[id] = copy.deepcopy(d["raw"])
            acts.append({"update": {"_id": id, "_index": self.idxnm}})
            acts.append(d["script"])
            es_logger.info("bulk upsert: %s" % id)

        res = []
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s",
                        doc_type="doc")
                else:
                    r = self.es.bulk(
                        index=self.idxnm,
                        body=acts,
                        refresh=False,
                        timeout="600s")
                if not r["errors"]:
                    return res

                for it in r["items"]:
                    if "error" in it["update"]:
                        res.append(str(it["update"]["_id"]))

                return res
            except Exception as e:
                es_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                self.conn()

        return res

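    # Hedged sketch of the element shape bulk4script expects, inferred from
    # the reads above (the script body is a hypothetical example):
    #   {"id": "doc-1",
    #    "raw": {...},  # original document, kept only for bookkeeping
    #    "script": {"script": {"source": "ctx._source.cnt += 1"}}}
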
    def rm(self, d):
        for _ in range(10):
            try:
                if not self.version():
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        doc_type="doc",
                        refresh=True)
                else:
                    # The 7+/8+ clients no longer accept doc_type on delete.
                    r = self.es.delete(
                        index=self.idxnm,
                        id=d["id"],
                        refresh=True)
                es_logger.info("Remove %s" % d["id"])
                return True
            except Exception as e:
                es_logger.warning("Fail to delete: " + str(d) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return True
                self.conn()

        es_logger.error("Fail to delete: " + str(d))

        return False

    def search(self, q, idxnms=None, src=False, timeout="2s"):
        if not isinstance(q, dict):
            q = Search().query(q).to_dict()
        if isinstance(idxnms, str):
            idxnms = idxnms.split(",")
        for i in range(3):
            try:
                res = self.es.search(index=(self.idxnm if not idxnms else idxnms),
                                     body=q,
                                     timeout=timeout,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=src)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES search exception: " + str(e) + "【Q】:" + str(q))
                # str.find returns -1 when absent, so test >= 0.
                if str(e).find("Timeout") >= 0:
                    continue
                raise e
        es_logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")

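    # Hedged usage sketch: search accepts a DSL query object or a raw body
    # dict ("content" is a hypothetical text field):
    #   from elasticsearch_dsl import Q
    #   res = ELASTICSEARCH.search(Q("match", content="hello"), src=True)
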
    def sql(self, sql, fetch_size=128, format="json", timeout="2s"):
        for i in range(3):
            try:
                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout)
                return res
            except ConnectionTimeout:
                es_logger.error("Timeout【Q】:" + sql)
                continue
        es_logger.error("ES SQL timeout for 3 times!")
        raise ConnectionTimeout("ES SQL query timed out.")

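    # Hedged usage sketch (the index/table name is hypothetical):
    #   res = ELASTICSEARCH.sql("SELECT doc_id FROM docs LIMIT 10")
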
    def get(self, doc_id, idxnm=None):
        for i in range(3):
            try:
                res = self.es.get(index=(self.idxnm if not idxnm else idxnm),
                                  id=doc_id)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                return res
            except Exception as e:
                es_logger.error(
                    "ES get exception: " + str(e) + "【Q】:" + doc_id)
                if str(e).find("Timeout") >= 0:
                    continue
                raise e
        es_logger.error("ES get timeout for 3 times!")
        raise Exception("ES get timeout.")

    def updateByQuery(self, q, d):
        ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q)
        # Build a painless script that copies each field of d into _source.
        scripts = ""
        for k, v in d.items():
            scripts += "ctx._source.%s = params.%s;" % (str(k), str(k))
        ubq = ubq.script(source=scripts, params=d)
        ubq = ubq.params(refresh=False)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") >= 0 or str(e).find("Conflict") >= 0:
                    continue
                self.conn()

        return False

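    # Hedged example: given d = {"status": 1}, the loop above generates the
    # painless script "ctx._source.status = params.status;" with params
    # {"status": 1}, applied to every document matching q.
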
    def updateScriptByQuery(self, q, scripts, idxnm=None):
        ubq = UpdateByQuery(
            index=self.idxnm if not idxnm else idxnm).using(
            self.es).query(q)
        ubq = ubq.script(source=scripts)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for i in range(3):
            try:
                r = ubq.execute()
                return True
            except Exception as e:
                es_logger.error("ES updateScriptByQuery exception: " +
                                str(e) + "【Q】:" + str(q.to_dict()))
                if str(e).find("Timeout") >= 0 or str(e).find("Conflict") >= 0:
                    continue
                self.conn()

        return False

    def deleteByQuery(self, query, idxnm=""):
        for i in range(3):
            try:
                r = self.es.delete_by_query(
                    index=idxnm if idxnm else self.idxnm,
                    refresh=True,
                    body=Search().query(query).to_dict())
                return True
            except Exception as e:
                es_logger.error("ES deleteByQuery exception: " +
                                str(e) + "【Q】:" + str(query.to_dict()))
                if str(e).find("NotFoundError") >= 0:
                    return True
                if str(e).find("Timeout") >= 0 or str(e).find("Conflict") >= 0:
                    continue

        return False

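    # Hedged usage sketch (the field name is hypothetical):
    #   from elasticsearch_dsl import Q
    #   ELASTICSEARCH.deleteByQuery(Q("term", doc_id="doc-1"))
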
    def update(self, id, script, routing=None):
        for i in range(3):
            try:
                if not self.version():
                    r = self.es.update(
                        index=self.idxnm,
                        id=id,
                        body=json.dumps(script, ensure_ascii=False),
                        doc_type="doc",
                        routing=routing,
                        refresh=False)
                else:
                    r = self.es.update(index=self.idxnm, id=id,
                                       body=json.dumps(script, ensure_ascii=False),
                                       routing=routing, refresh=False)
                return True
            except Exception as e:
                es_logger.error(
                    "ES update exception: " + str(e) + " id:" + str(id) +
                    ", version:" + str(self.version()) +
                    json.dumps(script, ensure_ascii=False))
                if str(e).find("Timeout") >= 0:
                    continue

        return False

    def indexExist(self, idxnm):
        s = Index(idxnm if idxnm else self.idxnm, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                es_logger.error("ES indexExist exception: " + str(e))
                if str(e).find("Timeout") >= 0 or str(e).find("Conflict") >= 0:
                    continue

        return False

    def docExist(self, docid, idxnm=None):
        for i in range(3):
            try:
                return self.es.exists(index=(idxnm if idxnm else self.idxnm),
                                      id=docid)
            except Exception as e:
                es_logger.error("ES docExist exception: " + str(e))
                if str(e).find("Timeout") >= 0 or str(e).find("Conflict") >= 0:
                    continue
        return False

    def createIdx(self, idxnm, mapping):
        try:
            if elasticsearch.__version__[0] < 8:
                return self.es.indices.create(idxnm, body=mapping)
            # The 8.x client takes settings/mappings as separate keyword args.
            from elasticsearch.client import IndicesClient
            return IndicesClient(self.es).create(index=idxnm,
                                                 settings=mapping["settings"],
                                                 mappings=mapping["mappings"])
        except Exception as e:
            es_logger.error("ES create index error %s ----%s" % (idxnm, str(e)))

    def deleteIdx(self, idxnm):
        try:
            return self.es.indices.delete(index=idxnm, allow_no_indices=True)
        except Exception as e:
            es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e)))

    def getTotal(self, res):
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getDocIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def scrollIter(self, pagesize=100, scroll_time='2m', q=None):
        # Avoid a mutable default argument: build the default query per call.
        if q is None:
            q = {"query": {"match_all": {}},
                 "sort": [{"updated_at": {"order": "desc"}}]}
        page = None
        for _ in range(100):
            try:
                page = self.es.search(
                    index=self.idxnm,
                    scroll=scroll_time,
                    size=pagesize,
                    body=q,
                    _source=None
                )
                break
            except Exception as e:
                es_logger.error("ES scrolling fail. " + str(e))
                time.sleep(3)
        if page is None:
            raise Exception("ES scroll failed after retries.")

        sid = page['_scroll_id']
        scroll_size = page['hits']['total']["value"]
        es_logger.info("[TOTAL]%d" % scroll_size)
        # Start scrolling
        while scroll_size > 0:
            yield page["hits"]["hits"]
            for _ in range(100):
                try:
                    page = self.es.scroll(scroll_id=sid, scroll=scroll_time)
                    break
                except Exception as e:
                    es_logger.error("ES scrolling fail. " + str(e))
                    time.sleep(3)

            # Update the scroll ID
            sid = page['_scroll_id']
            # Get the number of results that we returned in the last scroll
            scroll_size = len(page['hits']['hits'])

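    # Hedged usage sketch (assumes documents carry the "updated_at" field the
    # default sort relies on; process() is hypothetical):
    #   for hits in ELASTICSEARCH.scrollIter(pagesize=500):
    #       for h in hits:
    #           process(h["_source"])
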

ELASTICSEARCH = ESConnection()
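
# Hedged usage sketch: the @singleton decorator (from rag.utils) is assumed
# to make every ESConnection() call return this shared instance, so callers
# simply import ELASTICSEARCH (the module path below is an assumption):
#   from rag.utils.es_conn import ELASTICSEARCH
#   if ELASTICSEARCH.indexExist("docs"):
#       print(ELASTICSEARCH.health())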
 
 