es_conn.py

import re
import json
import time
import os
import copy
from typing import List, Dict

import elasticsearch
from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
import polars as pl

from api.utils.log_utils import logger
from api.utils.file_utils import get_project_base_directory
from rag import settings
from rag.utils import singleton
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer

logger.info("Elasticsearch sdk version: " + str(elasticsearch.__version__))

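# NOTE (added for readability): the class below is the project's
# DocStoreConnection implementation backed by Elasticsearch 8.x. It is created
# once (@singleton), configured from settings.ES, and loads its index mapping
# from conf/mapping.json.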
@singleton
class ESConnection(DocStoreConnection):
    def __init__(self):
        self.info = {}
        for _ in range(10):
            try:
                self.es = Elasticsearch(
                    settings.ES["hosts"].split(","),
                    basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.es:
                    self.info = self.es.info()
                    logger.info("Connect to es.")
                    break
            except Exception:
                logger.exception("Fail to connect to es")
                time.sleep(1)
        if not self.es.ping():
            raise Exception("Can't connect to ES cluster")
        v = self.info.get("version", {"number": "5.6"})
        v = v["number"].split(".")[0]
        if int(v) < 8:
            raise Exception(f"ES version must be greater than or equal to 8, current version: {v}")
        fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping, "r") as f:
            self.mapping = json.load(f)

  48. """
  49. Database operations
  50. """
  51. def dbType(self) -> str:
  52. return "elasticsearch"
    def health(self) -> dict:
        # dict + dict is a TypeError; merge the cluster health with the type tag instead.
        health_dict = dict(self.es.cluster.health())
        health_dict["type"] = "elasticsearch"
        return health_dict

  55. """
  56. Table operations
  57. """
  58. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  59. if self.indexExist(indexName, knowledgebaseId):
  60. return True
  61. try:
  62. from elasticsearch.client import IndicesClient
  63. return IndicesClient(self.es).create(index=indexName,
  64. settings=self.mapping["settings"],
  65. mappings=self.mapping["mappings"])
  66. except Exception:
  67. logger.exception("ES create index error %s" % (indexName))
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        try:
            # The 8.x client requires keyword arguments.
            return self.es.indices.delete(index=indexName, allow_no_indices=True)
        except Exception:
            logger.exception("ES delete index error %s" % (indexName))

    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        s = Index(indexName, self.es)
        for _ in range(3):
            try:
                return s.exists()
            except Exception as e:
                logger.exception("ES indexExist")
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
        return False

    """
    CRUD operations
    """

    def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
               orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
               knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
        """
        Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
        """
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        assert "_id" not in condition
        s = Search()
        bqry = None
        vector_similarity_weight = 0.5
        for m in matchExprs:
            if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
                assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(
                    matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr)
                weights = m.fusion_params["weights"]
                vector_similarity_weight = float(weights.split(",")[1])
        for m in matchExprs:
            if isinstance(m, MatchTextExpr):
                minimum_should_match = "0%"
                if "minimum_should_match" in m.extra_options:
                    minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%"
                bqry = Q("bool",
                         must=Q("query_string", fields=m.fields,
                                type="best_fields", query=m.matching_text,
                                minimum_should_match=minimum_should_match,
                                boost=1),
                         boost=1.0 - vector_similarity_weight,
                         )
                if condition:
                    for k, v in condition.items():
                        if not isinstance(k, str) or not v:
                            continue
                        if isinstance(v, list):
                            bqry.filter.append(Q("terms", **{k: v}))
                        elif isinstance(v, str) or isinstance(v, int):
                            bqry.filter.append(Q("term", **{k: v}))
                        else:
                            raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
            elif isinstance(m, MatchDenseExpr):
                assert bqry is not None
                similarity = 0.0
                if "similarity" in m.extra_options:
                    similarity = m.extra_options["similarity"]
                s = s.knn(m.vector_column_name,
                          m.topn,
                          m.topn * 2,
                          query_vector=list(m.embedding_data),
                          filter=bqry.to_dict(),
                          similarity=similarity,
                          )
        if matchExprs:
            s.query = bqry
        for field in highlightFields:
            s = s.highlight(field)
        if orderBy:
            orders = list()
            for field, order in orderBy.fields:
                order = "asc" if order == 0 else "desc"
                orders.append({field: {"order": order, "unmapped_type": "float",
                                       "mode": "avg", "numeric_type": "double"}})
            s = s.sort(*orders)
        if limit > 0:
            s = s[offset:limit]
        q = s.to_dict()
        # logger.info("ESConnection.search [Q]: " + json.dumps(q))
        for _ in range(3):
            try:
                res = self.es.search(index=indexNames,
                                     body=q,
                                     timeout="600s",
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                logger.info("ESConnection.search res: " + str(res))
                return res
            except Exception as e:
                logger.exception("ES search [Q]: " + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        logger.error("ES search timeout for 3 times!")
        raise Exception("ES search timeout.")

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        for _ in range(3):
            try:
                res = self.es.get(index=indexName,
                                  id=chunkId, source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                if not res.get("found"):
                    return None
                chunk = res["_source"]
                chunk["id"] = chunkId
                return chunk
            except Exception as e:
                logger.exception(f"ES get({chunkId}) got exception")
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        logger.error("ES get timeout for 3 times!")
        raise Exception("ES get timeout.")

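    # NOTE (added for readability): `insert` builds one bulk request of paired
    # {"index": ...} action and document bodies, retries on timeouts, and
    # returns a list of "<id>:<error>" strings (an empty list means success).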
    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
        # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
        operations = []
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            d_copy = copy.deepcopy(d)
            meta_id = d_copy["id"]
            del d_copy["id"]
            operations.append(
                {"index": {"_index": indexName, "_id": meta_id}})
            operations.append(d_copy)

        res = []
        for _ in range(100):
            try:
                r = self.es.bulk(index=indexName, operations=operations,
                                 refresh=False, timeout="600s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for item in r["items"]:
                    for action in ["create", "delete", "index", "update"]:
                        if action in item and "error" in item[action]:
                            res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                return res
            except Exception as e:
                logger.warning("Fail to bulk: " + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
        return res

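    # NOTE (added for readability): `update` has two paths: a direct document
    # update when `condition` pins a single chunk id, otherwise an
    # update-by-query whose Painless script is assembled from `newValue`.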
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        doc = copy.deepcopy(newValue)
        del doc['id']
        if "id" in condition and isinstance(condition["id"], str):
            # update specific single document
            chunkId = condition["id"]
            for _ in range(3):
                try:
                    self.es.update(index=indexName, id=chunkId, doc=doc)
                    return True
                except Exception as e:
                    # log the chunk id and the document actually sent, not the `id` builtin
                    logger.exception(
                        f"ES failed to update(index={indexName}, id={chunkId}, doc={json.dumps(doc, ensure_ascii=False)})")
                    if str(e).find("Timeout") > 0:
                        continue
        else:
            # update unspecific maybe-multiple documents
            bqry = Q("bool")
            for k, v in condition.items():
                if not isinstance(k, str) or not v:
                    continue
                if isinstance(v, list):
                    bqry.filter.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    bqry.filter.append(Q("term", **{k: v}))
                else:
                    raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
            scripts = []
            for k, v in newValue.items():
                if not isinstance(k, str) or not v:
                    continue
                if isinstance(v, str):
                    scripts.append(f"ctx._source.{k} = '{v}'")
                elif isinstance(v, int):
                    scripts.append(f"ctx._source.{k} = {v}")
                else:
                    raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
            ubq = UpdateByQuery(
                index=indexName).using(
                self.es).query(bqry)
            ubq = ubq.script(source="; ".join(scripts))
            ubq = ubq.params(refresh=True)
            ubq = ubq.params(slices=5)
            ubq = ubq.params(conflicts="proceed")
            for _ in range(3):
                try:
                    _ = ubq.execute()
                    return True
                except Exception as e:
                    logger.error("ES update exception: " + str(e) + "[Q]:" + str(bqry.to_dict()))
                    if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                        continue
        return False

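    # NOTE (added for readability): `delete` turns `condition` into either an
    # ids query or a bool filter, then issues delete_by_query and reports how
    # many documents were removed.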
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        qry = None
        assert "_id" not in condition
        if "id" in condition:
            chunk_ids = condition["id"]
            if not isinstance(chunk_ids, list):
                chunk_ids = [chunk_ids]
            qry = Q("ids", values=chunk_ids)
        else:
            qry = Q("bool")
            for k, v in condition.items():
                if isinstance(v, list):
                    qry.must.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    qry.must.append(Q("term", **{k: v}))
                else:
                    raise Exception("Condition value must be int, str or list.")
        logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
        for _ in range(10):
            try:
                res = self.es.delete_by_query(
                    index=indexName,
                    body=Search().query(qry).to_dict(),
                    refresh=True)
                return res["deleted"]
            except Exception as e:
                # log the failing condition, not the `filter` builtin
                logger.warning("Fail to delete: " + str(condition) + str(e))
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return 0
        return 0

  304. """
  305. Helper functions for search result
  306. """
  307. def getTotal(self, res):
  308. if isinstance(res["hits"]["total"], type({})):
  309. return res["hits"]["total"]["value"]
  310. return res["hits"]["total"]
  311. def getChunkIds(self, res):
  312. return [d["_id"] for d in res["hits"]["hits"]]
  313. def __getSource(self, res):
  314. rr = []
  315. for d in res["hits"]["hits"]:
  316. d["_source"]["id"] = d["_id"]
  317. d["_source"]["_score"] = d["_score"]
  318. rr.append(d["_source"])
  319. return rr
  320. def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
  321. res_fields = {}
  322. if not fields:
  323. return {}
  324. for d in self.__getSource(res):
  325. m = {n: d.get(n) for n in fields if d.get(n) is not None}
  326. for n, v in m.items():
  327. if isinstance(v, list):
  328. m[n] = v
  329. continue
  330. if not isinstance(v, str):
  331. m[n] = str(m[n])
  332. # if n.find("tks") > 0:
  333. # m[n] = rmSpace(m[n])
  334. if m:
  335. res_fields[d["id"]] = m
  336. return res_fields
    def getHighlight(self, res, keywords: List[str], fieldnm: str):
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            txt = "...".join([a for a in list(hlts.items())[0][1]])
            if not is_english(txt.split(" ")):
                ans[d["_id"]] = txt
                continue
            txt = d["_source"][fieldnm]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
                               flags=re.IGNORECASE | re.MULTILINE)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txts.append(t)
            ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
        return ans

    def getAggregation(self, res, fieldnm: str):
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        bkts = res["aggregations"][agg_field]["buckets"]
        return [(b["key"], b["doc_count"]) for b in bkts]

  364. """
  365. SQL
  366. """
    def sql(self, sql: str, fetch_size: int, format: str):
        logger.info(f"ESConnection.sql get sql: {sql}")
        sql = re.sub(r"[ `]+", " ", sql)
        sql = sql.replace("%", "")
        replaces = []
        for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
            fld, v = r.group(1), r.group(3)
            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
                fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
            replaces.append(
                ("{}{}'{}'".format(
                    r.group(1),
                    r.group(2),
                    r.group(3)),
                 match))

        for p, r in replaces:
            sql = sql.replace(p, r, 1)
        logger.info(f"ESConnection.sql to es: {sql}")

        for _ in range(3):
            try:
                res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                        request_timeout="2s")
                return res
            except ConnectionTimeout:
                logger.exception("ESConnection.sql timeout [Q]: " + sql)
                continue
            except Exception:
                logger.exception("ESConnection.sql got exception [Q]: " + sql)
                return None
        logger.error("ESConnection.sql timeout for 3 times!")
        return None
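
# --- Usage sketch (added; not part of the original module) -------------------
# A minimal illustration of how this connection might be exercised, assuming
# the rag/api packages are importable, settings.ES points at a reachable 8.x
# cluster, and conf/mapping.json exists. The index name, kb id and document
# fields below are hypothetical.
if __name__ == "__main__":
    conn = ESConnection()
    logger.info(f"Cluster health: {conn.health()}")
    # idx, kb = "ragflow_demo", "kb-0000"          # hypothetical names
    # conn.createIdx(idx, kb, vectorSize=1024)
    # errors = conn.insert([{"id": "chunk-1", "content_with_weight": "hello"}], idx, kb)
    # assert not errors, errors
    # print(conn.get("chunk-1", idx, [kb]))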