
es_conn.py

import logging
import re
import json
import time
import os
import copy
from typing import List, Dict

import polars as pl
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout

from rag import settings
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
    FusionExpr
from rag.nlp import is_english, rag_tokenizer

@singleton
class ESConnection(DocStoreConnection):
    def __init__(self):
        self.info = {}
        logging.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
        for _ in range(24):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(settings.ES["username"], settings.ES["password"])
                    if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    break
            except Exception as e:
                logging.warning(f"{str(e)}. Waiting for Elasticsearch {settings.ES['hosts']} to become healthy.")
                time.sleep(5)
        if not self.es.ping():
            msg = f"Elasticsearch {settings.ES['hosts']} didn't become healthy in 120s."
            logging.error(msg)
            raise Exception(msg)
        v = self.info.get("version", {"number": "8.11.3"})
        v = v["number"].split(".")[0]
        if int(v) < 8:
            msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
            logging.error(msg)
            raise Exception(msg)
        fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
        if not os.path.exists(fp_mapping):
            msg = f"Elasticsearch mapping file not found at {fp_mapping}"
            logging.error(msg)
            raise Exception(msg)
        with open(fp_mapping, "r") as f:
            self.mapping = json.load(f)
        logging.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
  55. """
  56. Database operations
  57. """
  58. def dbType(self) -> str:
  59. return "elasticsearch"
  60. def health(self) -> dict:
  61. return dict(self.es.cluster.health()) + {"type": "elasticsearch"}
  62. """
  63. Table operations
  64. """
  65. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  66. if self.indexExist(indexName, knowledgebaseId):
  67. return True
  68. try:
  69. from elasticsearch.client import IndicesClient
  70. return IndicesClient(self.es).create(index=indexName,
  71. settings=self.mapping["settings"],
  72. mappings=self.mapping["mappings"])
  73. except Exception:
  74. logging.exception("ES create index error %s" % (indexName))
  75. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  76. try:
  77. return self.es.indices.delete(indexName, allow_no_indices=True)
  78. except Exception:
  79. logging.exception("ES delete index error %s" % (indexName))
  80. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  81. s = Index(indexName, self.es)
  82. for i in range(3):
  83. try:
  84. return s.exists()
  85. except Exception as e:
  86. logging.exception("ES indexExist")
  87. if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
  88. continue
  89. return False
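
    # Index lifecycle sketch (hypothetical index/kb names). Note that
    # vectorSize is unused by createIdx above, presumably because
    # conf/mapping.json already declares the vector fields:
    #
    #   conn.createIdx("ragflow_tenant0", "kb0", 1024)
    #   conn.indexExist("ragflow_tenant0", "kb0")   # -> True
    #   conn.deleteIdx("ragflow_tenant0", "kb0")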
  90. """
  91. CRUD operations
  92. """
  93. def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
  94. orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
  95. knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
  96. """
  97. Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
  98. """
  99. if isinstance(indexNames, str):
  100. indexNames = indexNames.split(",")
  101. assert isinstance(indexNames, list) and len(indexNames) > 0
  102. assert "_id" not in condition
  103. s = Search()
  104. bqry = None
  105. vector_similarity_weight = 0.5
  106. for m in matchExprs:
  107. if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
  108. assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
  109. MatchDenseExpr) and isinstance(
  110. matchExprs[2], FusionExpr)
  111. weights = m.fusion_params["weights"]
  112. vector_similarity_weight = float(weights.split(",")[1])
  113. for m in matchExprs:
  114. if isinstance(m, MatchTextExpr):
  115. minimum_should_match = "0%"
  116. if "minimum_should_match" in m.extra_options:
  117. minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
  118. bqry = Q("bool",
  119. must=Q("query_string", fields=m.fields,
  120. type="best_fields", query=m.matching_text,
  121. minimum_should_match=minimum_should_match,
  122. boost=1),
  123. boost=1.0 - vector_similarity_weight,
  124. )
  125. elif isinstance(m, MatchDenseExpr):
  126. assert (bqry is not None)
  127. similarity = 0.0
  128. if "similarity" in m.extra_options:
  129. similarity = m.extra_options["similarity"]
  130. s = s.knn(m.vector_column_name,
  131. m.topn,
  132. m.topn * 2,
  133. query_vector=list(m.embedding_data),
  134. filter=bqry.to_dict(),
  135. similarity=similarity,
  136. )
  137. if condition:
  138. if not bqry:
  139. bqry = Q("bool", must=[])
  140. for k, v in condition.items():
  141. if not isinstance(k, str) or not v:
  142. continue
  143. if isinstance(v, list):
  144. bqry.filter.append(Q("terms", **{k: v}))
  145. elif isinstance(v, str) or isinstance(v, int):
  146. bqry.filter.append(Q("term", **{k: v}))
  147. else:
  148. raise Exception(
  149. f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
  150. if bqry:
  151. s = s.query(bqry)
  152. for field in highlightFields:
  153. s = s.highlight(field)
  154. if orderBy:
  155. orders = list()
  156. for field, order in orderBy.fields:
  157. order = "asc" if order == 0 else "desc"
  158. orders.append({field: {"order": order, "unmapped_type": "float",
  159. "mode": "avg", "numeric_type": "double"}})
  160. s = s.sort(*orders)
  161. if limit > 0:
  162. s = s[offset:limit]
  163. q = s.to_dict()
  164. print(json.dumps(q), flush=True)
  165. logging.debug("ESConnection.search [Q]: " + json.dumps(q))
  166. for i in range(3):
  167. try:
  168. res = self.es.search(index=indexNames,
  169. body=q,
  170. timeout="600s",
  171. # search_type="dfs_query_then_fetch",
  172. track_total_hits=True,
  173. _source=True)
  174. if str(res.get("timed_out", "")).lower() == "true":
  175. raise Exception("Es Timeout.")
  176. logging.debug("ESConnection.search res: " + str(res))
  177. return res
  178. except Exception as e:
  179. logging.exception("ES search [Q]: " + str(q))
  180. if str(e).find("Timeout") > 0:
  181. continue
  182. raise e
  183. logging.error("ES search timeout for 3 times!")
  184. raise Exception("ES search timeout.")

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        for i in range(3):
            try:
                res = self.es.get(index=indexName, id=chunkId, source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                if not res.get("found"):
                    return None
                chunk = res["_source"]
                chunk["id"] = chunkId
                return chunk
            except Exception as e:
                logging.exception(f"ES get({chunkId}) got exception")
                if str(e).find("Timeout") >= 0:
                    continue
                raise e
        logging.error("ES get timeout for 3 times!")
        raise Exception("ES get timeout.")

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
        # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
        operations = []
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            d_copy = copy.deepcopy(d)
            meta_id = d_copy["id"]
            del d_copy["id"]
            operations.append(
                {"index": {"_index": indexName, "_id": meta_id}})
            operations.append(d_copy)
        res = []
        for _ in range(100):
            try:
                r = self.es.bulk(index=indexName, operations=operations,
                                 refresh=False, timeout="600s")
                if not r["errors"]:
                    return res
                for item in r["items"]:
                    for action in ["create", "delete", "index", "update"]:
                        if action in item and "error" in item[action]:
                            res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                return res
            except Exception as e:
                logging.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
        return res
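
    # The bulk payload alternates action lines and document bodies. A sketch
    # of what `operations` looks like for two chunks (hypothetical fields):
    #
    #   [{"index": {"_index": "ragflow_tenant0", "_id": "chunk-1"}},
    #    {"content_with_weight": "...", "kb_id": "kb0"},
    #    {"index": {"_index": "ragflow_tenant0", "_id": "chunk-2"}},
    #    {"content_with_weight": "...", "kb_id": "kb0"}]
    #
    # The returned list holds "id:error" strings for failed items, so an
    # empty list means every chunk was indexed.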

    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        doc = copy.deepcopy(newValue)
        del doc['id']
        if "id" in condition and isinstance(condition["id"], str):
            # update a specific single document
            chunkId = condition["id"]
            for i in range(3):
                try:
                    self.es.update(index=indexName, id=chunkId, doc=doc)
                    return True
                except Exception as e:
                    logging.exception(
                        f"ES failed to update(index={indexName}, id={chunkId}, doc={json.dumps(doc, ensure_ascii=False)})")
                    if str(e).find("Timeout") >= 0:
                        continue
        else:
            # update an unspecified, possibly multiple, set of documents
            bqry = Q("bool")
            for k, v in condition.items():
                if not isinstance(k, str) or not v:
                    continue
                if isinstance(v, list):
                    bqry.filter.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    bqry.filter.append(Q("term", **{k: v}))
                else:
                    raise Exception(
                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
            scripts = []
            for k, v in newValue.items():
                if not isinstance(k, str) or not v:
                    continue
                if isinstance(v, str):
                    scripts.append(f"ctx._source.{k} = '{v}'")
                elif isinstance(v, int):
                    scripts.append(f"ctx._source.{k} = {v}")
                else:
                    raise Exception(
                        f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
            ubq = UpdateByQuery(
                index=indexName).using(
                self.es).query(bqry)
            ubq = ubq.script(source="; ".join(scripts))
            ubq = ubq.params(refresh=True)
            ubq = ubq.params(slices=5)
            ubq = ubq.params(conflicts="proceed")
            for i in range(3):
                try:
                    _ = ubq.execute()
                    return True
                except Exception as e:
                    logging.error("ES update exception: " + str(e) + "[Q]:" + str(bqry.to_dict()))
                    if str(e).find("Timeout") >= 0 or str(e).find("Conflict") >= 0:
                        continue
        return False
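
    # Two update paths, sketched with hypothetical values. An id condition
    # becomes a plain document update; any other condition becomes an
    # update-by-query running a Painless script built from newValue:
    #
    #   conn.update({"id": "chunk-1"},
    #               {"id": "chunk-1", "available_int": 0},
    #               "ragflow_tenant0", "kb0")
    #   conn.update({"kb_id": "kb0"},
    #               {"id": "", "status": "1"},   # -> ctx._source.status = '1'
    #               "ragflow_tenant0", "kb0")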

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        qry = None
        assert "_id" not in condition
        if "id" in condition:
            chunk_ids = condition["id"]
            if not isinstance(chunk_ids, list):
                chunk_ids = [chunk_ids]
            qry = Q("ids", values=chunk_ids)
        else:
            qry = Q("bool")
            for k, v in condition.items():
                if isinstance(v, list):
                    qry.must.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    qry.must.append(Q("term", **{k: v}))
                else:
                    raise Exception("Condition value must be int, str or list.")
        logging.debug("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
        for _ in range(10):
            try:
                res = self.es.delete_by_query(
                    index=indexName,
                    body=Search().query(qry).to_dict(),
                    refresh=True)
                return res["deleted"]
            except Exception as e:
                logging.warning("Fail to delete: " + str(condition) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return 0
        return 0
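
    # Delete sketch (hypothetical values): by chunk ids, or by filter fields,
    # returning the number of deleted documents:
    #
    #   conn.delete({"id": ["chunk-1", "chunk-2"]}, "ragflow_tenant0", "kb0")
    #   conn.delete({"kb_id": "kb0"}, "ragflow_tenant0", "kb0")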
  322. """
  323. Helper functions for search result
  324. """
  325. def getTotal(self, res):
  326. if isinstance(res["hits"]["total"], type({})):
  327. return res["hits"]["total"]["value"]
  328. return res["hits"]["total"]
  329. def getChunkIds(self, res):
  330. return [d["_id"] for d in res["hits"]["hits"]]
  331. def __getSource(self, res):
  332. rr = []
  333. for d in res["hits"]["hits"]:
  334. d["_source"]["id"] = d["_id"]
  335. d["_source"]["_score"] = d["_score"]
  336. rr.append(d["_source"])
  337. return rr
  338. def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
  339. res_fields = {}
  340. if not fields:
  341. return {}
  342. for d in self.__getSource(res):
  343. m = {n: d.get(n) for n in fields if d.get(n) is not None}
  344. for n, v in m.items():
  345. if isinstance(v, list):
  346. m[n] = v
  347. continue
  348. if not isinstance(v, str):
  349. m[n] = str(m[n])
  350. # if n.find("tks") > 0:
  351. # m[n] = rmSpace(m[n])
  352. if m:
  353. res_fields[d["id"]] = m
  354. return res_fields

    def getHighlight(self, res, keywords: List[str], fieldnm: str):
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            txt = "...".join([a for a in list(hlts.items())[0][1]])
            if not is_english(txt.split(" ")):
                ans[d["_id"]] = txt
                continue
            txt = d["_source"][fieldnm]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
                               r"\1<em>\2</em>\3", t,
                               flags=re.IGNORECASE | re.MULTILINE)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txts.append(t)
            ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
        return ans

    def getAggregation(self, res, fieldnm: str):
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        bkts = res["aggregations"][agg_field]["buckets"]
        return [(b["key"], b["doc_count"]) for b in bkts]
  383. """
  384. SQL
  385. """
  386. def sql(self, sql: str, fetch_size: int, format: str):
  387. logging.debug(f"ESConnection.sql get sql: {sql}")
  388. sql = re.sub(r"[ `]+", " ", sql)
  389. sql = sql.replace("%", "")
  390. replaces = []
  391. for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
  392. fld, v = r.group(1), r.group(3)
  393. match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
  394. fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
  395. replaces.append(
  396. ("{}{}'{}'".format(
  397. r.group(1),
  398. r.group(2),
  399. r.group(3)),
  400. match))
  401. for p, r in replaces:
  402. sql = sql.replace(p, r, 1)
  403. logging.debug(f"ESConnection.sql to es: {sql}")
  404. for i in range(3):
  405. try:
  406. res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
  407. request_timeout="2s")
  408. return res
  409. except ConnectionTimeout:
  410. logging.exception("ESConnection.sql timeout [Q]: " + sql)
  411. continue
  412. except Exception:
  413. logging.exception("ESConnection.sql got exception [Q]: " + sql)
  414. return None
  415. logging.error("ESConnection.sql timeout for 3 times!")
  416. return None
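
    # The rewrite above turns equality/LIKE predicates on tokenized
    # *_tks/*_ltks columns into full-text MATCH calls before the query is
    # sent to ES SQL. A sketch with a hypothetical statement:
    #
    #   in:  SELECT doc_id FROM idx WHERE content_ltks = 'retrieval augmented'
    #   out: SELECT doc_id FROM idx WHERE
    #        MATCH(content_ltks, '<tokenized text>',
    #              'operator=OR;minimum_should_match=30%')
    #
    # where <tokenized text> is produced by rag_tokenizer.tokenize followed
    # by rag_tokenizer.fine_grained_tokenize.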