
# es_conn.py

import re
import json
import time
import os
import copy
from typing import List, Dict

import elasticsearch
import polars as pl
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout

from rag import settings
from rag.settings import doc_store_logger
from rag.utils import singleton
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer
from api.utils.file_utils import get_project_base_directory

doc_store_logger.info("Elasticsearch sdk version: " + str(elasticsearch.__version__))
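
# ESConnection is decorated with @singleton below, so repeated construction is
# expected to hand back the same cached instance; callers can therefore use it
# directly. A minimal sketch, assuming the decorator caches per process:
#
#     conn = ESConnection()
#     assert conn is ESConnection()  # same object every time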

@singleton
class ESConnection(DocStoreConnection):
    def __init__(self):
        self.info = {}
        self.es = None
        # retry the initial connection for up to 10 attempts, one second apart
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    doc_store_logger.info("Connect to es.")
                    break
            except Exception as e:
                doc_store_logger.error("Fail to connect to es: " + str(e))
                time.sleep(1)
        # guard against self.es never having been assigned
        if self.es is None or not self.es.ping():
            raise Exception("Can't connect to ES cluster")
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        if int(v) < 8:
            raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
        fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping, "r") as f:
            self.mapping = json.load(f)
  48. """
  49. Database operations
  50. """
  51. def dbType(self) -> str:
  52. return "elasticsearch"
  53. def health(self) -> dict:
  54. return dict(self.es.cluster.health()) + {"type": "elasticsearch"}
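
    # A sketch of the dict health() returns, assuming a typical
    # cluster.health() response (exact keys vary by ES version):
    #
    #     {"cluster_name": "docker-cluster", "status": "green",
    #      "number_of_nodes": 1, ..., "type": "elasticsearch"}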
  55. """
  56. Table operations
  57. """
  58. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  59. if self.indexExist(indexName, knowledgebaseId):
  60. return True
  61. try:
  62. from elasticsearch.client import IndicesClient
  63. return IndicesClient(self.es).create(index=indexName,
  64. settings=self.mapping["settings"],
  65. mappings=self.mapping["mappings"])
  66. except Exception as e:
  67. doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
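
    # createIdx passes the index settings and field mappings verbatim from
    # conf/mapping.json, so the resulting request is roughly the following
    # sketch (the body comes entirely from the mapping file):
    #
    #     PUT /<indexName>
    #     {"settings": {...}, "mappings": {...}}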

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        try:
            # the ES 8 client accepts keyword arguments only
            return self.es.indices.delete(index=indexName, allow_no_indices=True)
        except Exception as e:
            doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))

    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        s = Index(indexName, self.es)
        for i in range(3):
            try:
                return s.exists()
            except Exception as e:
                doc_store_logger.error("ES indexExist: " + str(e))
                # str.find() returns 0 for a match at position 0, so test membership instead
                if "Timeout" in str(e) or "Conflict" in str(e):
                    continue
        return False
  83. """
  84. CRUD operations
  85. """
  86. def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
  87. """
  88. Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
  89. """
  90. if isinstance(indexNames, str):
  91. indexNames = indexNames.split(",")
  92. assert isinstance(indexNames, list) and len(indexNames) > 0
  93. assert "_id" not in condition
  94. s = Search()
  95. bqry = None
  96. vector_similarity_weight = 0.5
  97. for m in matchExprs:
  98. if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params:
  99. assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
  100. weights = m.fusion_params["weights"]
  101. vector_similarity_weight = float(weights.split(",")[1])
  102. for m in matchExprs:
  103. if isinstance(m, MatchTextExpr):
  104. minimum_should_match = "0%"
  105. if "minimum_should_match" in m.extra_options:
  106. minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
  107. bqry = Q("bool",
  108. must=Q("query_string", fields=m.fields,
  109. type="best_fields", query=m.matching_text,
  110. minimum_should_match = minimum_should_match,
  111. boost=1),
  112. boost = 1.0 - vector_similarity_weight,
  113. )
  114. if condition:
  115. for k, v in condition.items():
  116. if not isinstance(k, str) or not v:
  117. continue
  118. if isinstance(v, list):
  119. bqry.filter.append(Q("terms", **{k: v}))
  120. elif isinstance(v, str) or isinstance(v, int):
  121. bqry.filter.append(Q("term", **{k: v}))
  122. else:
  123. raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
  124. elif isinstance(m, MatchDenseExpr):
  125. assert(bqry is not None)
  126. similarity = 0.0
  127. if "similarity" in m.extra_options:
  128. similarity = m.extra_options["similarity"]
  129. s = s.knn(m.vector_column_name,
  130. m.topn,
  131. m.topn * 2,
  132. query_vector = list(m.embedding_data),
  133. filter = bqry.to_dict(),
  134. similarity = similarity,
  135. )
  136. if matchExprs:
  137. s.query = bqry
  138. for field in highlightFields:
  139. s = s.highlight(field)
  140. if orderBy:
  141. orders = list()
  142. for field, order in orderBy.fields:
  143. order = "asc" if order == 0 else "desc"
  144. orders.append({field: {"order": order, "unmapped_type": "float",
  145. "mode": "avg", "numeric_type": "double"}})
  146. s = s.sort(*orders)
  147. if limit > 0:
  148. s = s[offset:limit]
  149. q = s.to_dict()
  150. doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
  151. for i in range(3):
  152. try:
  153. res = self.es.search(index=indexNames,
  154. body=q,
  155. timeout="600s",
  156. # search_type="dfs_query_then_fetch",
  157. track_total_hits=True,
  158. _source=True)
  159. if str(res.get("timed_out", "")).lower() == "true":
  160. raise Exception("Es Timeout.")
  161. doc_store_logger.info("ESConnection.search res: " + str(res))
  162. return res
  163. except Exception as e:
  164. doc_store_logger.error(
  165. "ES search exception: " +
  166. str(e) +
  167. "\n[Q]: " +
  168. str(q))
  169. if str(e).find("Timeout") > 0:
  170. continue
  171. raise e
  172. doc_store_logger.error("ES search timeout for 3 times!")
  173. raise Exception("ES search timeout.")
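
    # For a hybrid (text + vector) search, the body assembled above is roughly
    # the sketch below: a boosted full-text bool query plus a knn clause that
    # reuses the bool query as its filter. The vector field name "q_1024_vec"
    # and the numbers are illustrative only:
    #
    #     {"query": {"bool": {"must": [{"query_string": {...}}], "boost": 0.95}},
    #      "knn": {"field": "q_1024_vec", "k": 64, "num_candidates": 128,
    #              "query_vector": [...], "filter": {"bool": {...}},
    #              "similarity": 0.0}}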

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        for i in range(3):
            try:
                res = self.es.get(index=indexName, id=chunkId, source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                if not res.get("found"):
                    return None
                chunk = res["_source"]
                chunk["id"] = chunkId
                return chunk
            except Exception as e:
                doc_store_logger.error(
                    "ES get exception: " +
                    str(e) +
                    "[Q]: " +
                    chunkId)
                if "Timeout" in str(e):
                    continue
                raise e
        doc_store_logger.error("ES get timed out 3 times!")
        raise Exception("ES get timeout.")

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
        # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
        operations = []
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            d_copy = copy.deepcopy(d)
            meta_id = d_copy["id"]
            del d_copy["id"]
            operations.append(
                {"index": {"_index": indexName, "_id": meta_id}})
            operations.append(d_copy)
        res = []
        for _ in range(100):
            try:
                r = self.es.bulk(index=indexName, operations=operations,
                                 refresh=False, timeout="600s")
                # r["errors"] is False when every action succeeded
                if not r["errors"]:
                    return res
                for item in r["items"]:
                    for action in ["create", "delete", "index", "update"]:
                        if action in item and "error" in item[action]:
                            res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                return res
            except Exception as e:
                doc_store_logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
        return res
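
    # The `operations` list built above follows the bulk API's alternating
    # layout: an action line, then the document source. Index and field names
    # here are illustrative only:
    #
    #     [{"index": {"_index": "my_index", "_id": "chunk-0"}},
    #      {"content_with_weight": "...", "kb_id": "kb-1"},
    #      ...]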

    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        doc = copy.deepcopy(newValue)
        doc.pop("id", None)  # a plain del would raise KeyError when "id" is absent
        if "id" in condition and isinstance(condition["id"], str):
            # update one specific document
            chunkId = condition["id"]
            for i in range(3):
                try:
                    self.es.update(index=indexName, id=chunkId, doc=doc)
                    return True
                except Exception as e:
                    doc_store_logger.error(
                        "ES update exception: " + str(e) + " id:" + str(chunkId) +
                        json.dumps(newValue, ensure_ascii=False))
                    if "Timeout" in str(e):
                        continue
        else:
            # update an unspecified, possibly multiple, set of documents
            bqry = Q("bool")
            for k, v in condition.items():
                if not isinstance(k, str) or not v:
                    continue
                if isinstance(v, list):
                    bqry.filter.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    bqry.filter.append(Q("term", **{k: v}))
                else:
                    raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
            scripts = []
            for k, v in newValue.items():
                if not isinstance(k, str) or not v:
                    continue
                if isinstance(v, str):
                    # json.dumps quotes and escapes the value for the painless script
                    scripts.append(f"ctx._source.{k} = {json.dumps(v)}")
                elif isinstance(v, int):
                    scripts.append(f"ctx._source.{k} = {v}")
                else:
                    raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
            ubq = UpdateByQuery(
                index=indexName).using(
                self.es).query(bqry)
            ubq = ubq.script(source="; ".join(scripts))
            ubq = ubq.params(refresh=True)
            ubq = ubq.params(slices=5)
            ubq = ubq.params(conflicts="proceed")
            for i in range(3):
                try:
                    _ = ubq.execute()
                    return True
                except Exception as e:
                    doc_store_logger.error("ES update exception: " +
                                           str(e) + "[Q]:" + str(bqry.to_dict()))
                    if "Timeout" in str(e) or "Conflict" in str(e):
                        continue
        return False
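
    # The update-by-query branch joins the per-field assignments into one
    # painless script. For a hypothetical newValue of
    # {"available_int": 1, "status": "ok"}, the generated source would be:
    #
    #     ctx._source.available_int = 1; ctx._source.status = "ok"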

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        qry = None
        assert "_id" not in condition
        if "id" in condition:
            chunk_ids = condition["id"]
            if not isinstance(chunk_ids, list):
                chunk_ids = [chunk_ids]
            qry = Q("ids", values=chunk_ids)
        else:
            qry = Q("bool")
            for k, v in condition.items():
                if isinstance(v, list):
                    qry.must.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    qry.must.append(Q("term", **{k: v}))
                else:
                    raise Exception("Condition value must be int, str or list.")
        doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
        for _ in range(10):
            try:
                res = self.es.delete_by_query(
                    index=indexName,
                    body=Search().query(qry).to_dict(),
                    refresh=True)
                return res["deleted"]
            except Exception as e:
                # log the condition, not the `filter` builtin
                doc_store_logger.warning("Fail to delete: " + str(condition) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return 0
        return 0
  315. """
  316. Helper functions for search result
  317. """
  318. def getTotal(self, res):
  319. if isinstance(res["hits"]["total"], type({})):
  320. return res["hits"]["total"]["value"]
  321. return res["hits"]["total"]
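
    # Because search() sets track_total_hits=True, res["hits"]["total"] is the
    # object form {"value": <n>, "relation": "eq"}; the bare-integer branch
    # above is kept as a fallback for other response shapes.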

    def getChunkIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def __getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr

    def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
        res_fields = {}
        if not fields:
            return {}
        for d in self.__getSource(res):
            m = {n: d.get(n) for n in fields if d.get(n) is not None}
            for n, v in m.items():
                if isinstance(v, list):
                    continue
                if not isinstance(v, str):
                    m[n] = str(m[n])
                # if n.find("tks") > 0:
                #     m[n] = rmSpace(m[n])
            if m:
                res_fields[d["id"]] = m
        return res_fields

    def getHighlight(self, res, keywords: List[str], fieldnm: str):
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            txt = "...".join(list(hlts.items())[0][1])
            if not is_english(txt.split(" ")):
                ans[d["_id"]] = txt
                continue
            txt = d["_source"][fieldnm]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE | re.MULTILINE)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txts.append(t)
            ans[d["_id"]] = "...".join(txts) if txts else "...".join(list(hlts.items())[0][1])
        return ans

    def getAggregation(self, res, fieldnm: str):
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        bkts = res["aggregations"][agg_field]["buckets"]
        return [(b["key"], b["doc_count"]) for b in bkts]
  375. """
  376. SQL
  377. """
  378. def sql(self, sql: str, fetch_size: int, format: str):
  379. doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
  380. sql = re.sub(r"[ `]+", " ", sql)
  381. sql = sql.replace("%", "")
  382. replaces = []
  383. for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
  384. fld, v = r.group(1), r.group(3)
  385. match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
  386. fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
  387. replaces.append(
  388. ("{}{}'{}'".format(
  389. r.group(1),
  390. r.group(2),
  391. r.group(3)),
  392. match))
  393. for p, r in replaces:
  394. sql = sql.replace(p, r, 1)
  395. doc_store_logger.info(f"ESConnection.sql to es: {sql}")
  396. for i in range(3):
  397. try:
  398. res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
  399. return res
  400. except ConnectionTimeout:
  401. doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
  402. continue
  403. except Exception as e:
  404. doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
  405. return None
  406. doc_store_logger.error("ESConnection.sql timeout for 3 times!")
  407. return None
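
# A sketch of the rewrite sql() performs, with a hypothetical column and query
# text (the `%` wildcards are stripped first; tokenizer output abbreviated):
#
#     in:  SELECT doc_id FROM t WHERE content_ltks like '%deep learning%'
#     out: SELECT doc_id FROM t WHERE  MATCH(content_ltks,
#          '<tokenized terms>', 'operator=OR;minimum_should_match=30%')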