
opensearch_coon.py

#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import re
import json
import time
import os
import copy

from opensearchpy import OpenSearch, NotFoundError
from opensearchpy import UpdateByQuery, Q, Search, Index
from opensearchpy import ConnectionTimeout

from rag import settings
from rag.settings import TAG_FLD, PAGERANK_FLD
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
    FusionExpr
from rag.nlp import is_english, rag_tokenizer

ATTEMPT_TIME = 2

logger = logging.getLogger('ragflow.opensearch_conn')
@singleton
class OSConnection(DocStoreConnection):
    def __init__(self):
        self.info = {}
        logger.info(f"Use OpenSearch {settings.OS['hosts']} as the doc engine.")
        for _ in range(ATTEMPT_TIME):
            try:
                self.os = OpenSearch(
                    settings.OS["hosts"].split(","),
                    http_auth=(settings.OS["username"], settings.OS["password"])
                    if "username" in settings.OS and "password" in settings.OS else None,
                    verify_certs=False,
                    timeout=600
                )
                if self.os:
                    self.info = self.os.info()
                    break
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting for OpenSearch {settings.OS['hosts']} to become healthy.")
                time.sleep(5)

        if not self.os.ping():
            msg = f"OpenSearch {settings.OS['hosts']} is still unhealthy after {ATTEMPT_TIME} attempts."
            logger.error(msg)
            raise Exception(msg)

        v = self.info.get("version", {"number": "2.18.0"})
        v = v["number"].split(".")[0]
        if int(v) < 2:
            msg = f"OpenSearch version must be greater than or equal to 2, current version: {v}"
            logger.error(msg)
            raise Exception(msg)

        fp_mapping = os.path.join(get_project_base_directory(), "conf", "os_mapping.json")
        if not os.path.exists(fp_mapping):
            msg = f"OpenSearch mapping file not found at {fp_mapping}"
            logger.error(msg)
            raise Exception(msg)
        self.mapping = json.load(open(fp_mapping, "r"))
        logger.info(f"OpenSearch {settings.OS['hosts']} is healthy.")
    """
    Database operations
    """

    def dbType(self) -> str:
        return "opensearch"

    def health(self) -> dict:
        health_dict = dict(self.os.cluster.health())
        health_dict["type"] = "opensearch"
        return health_dict
    """
    Table operations
    """

    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        if self.indexExist(indexName, knowledgebaseId):
            return True
        try:
            from opensearchpy.client import IndicesClient
            return IndicesClient(self.os).create(index=indexName, body=self.mapping)
        except Exception:
            logger.exception("OSConnection.createIdx error %s" % (indexName))

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        if len(knowledgebaseId) > 0:
            # The index needs to stay alive after any kb deletion, since all kbs under this tenant share one index.
            return
        try:
            self.os.indices.delete(index=indexName, allow_no_indices=True)
        except NotFoundError:
            pass
        except Exception:
            logger.exception("OSConnection.deleteIdx error %s" % (indexName))
    def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
        s = Index(indexName, self.os)
        for i in range(ATTEMPT_TIME):
            try:
                return s.exists()
            except Exception as e:
                logger.exception("OSConnection.indexExist got exception")
                if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
                    continue
                break
        return False
    """
    CRUD operations
    """

    def search(
            self, selectFields: list[str],
            highlightFields: list[str],
            condition: dict,
            matchExprs: list[MatchExpr],
            orderBy: OrderByExpr,
            offset: int,
            limit: int,
            indexNames: str | list[str],
            knowledgebaseIds: list[str],
            aggFields: list[str] = [],
            rank_feature: dict | None = None
    ):
        """
        Refers to https://github.com/opensearch-project/opensearch-py/blob/main/guides/dsl.md
        """
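        # Overall flow: build a bool query from `condition`, fold in the
        # full-text and/or dense-vector match expressions, then attach
        # highlighting, sorting, aggregations and pagination before searching.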
        use_knn = False
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        assert "_id" not in condition
        bqry = Q("bool", must=[])
        condition["kb_id"] = knowledgebaseIds
        for k, v in condition.items():
            if k == "available_int":
                if v == 0:
                    bqry.filter.append(Q("range", available_int={"lt": 1}))
                else:
                    bqry.filter.append(
                        Q("bool", must_not=Q("range", available_int={"lt": 1})))
                continue
            if not v:
                continue
            if isinstance(v, list):
                bqry.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bqry.filter.append(Q("term", **{k: v}))
            else:
                raise Exception(
                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
        s = Search()
        vector_similarity_weight = 0.5
        for m in matchExprs:
            if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
                assert len(matchExprs) == 3 \
                    and isinstance(matchExprs[0], MatchTextExpr) \
                    and isinstance(matchExprs[1], MatchDenseExpr) \
                    and isinstance(matchExprs[2], FusionExpr)
                weights = m.fusion_params["weights"]
                vector_similarity_weight = float(weights.split(",")[1])

        knn_query = {}
        for m in matchExprs:
            if isinstance(m, MatchTextExpr):
                minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    minimum_should_match = str(int(minimum_should_match * 100)) + "%"
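                # e.g. a ratio of 0.3 becomes "30%", the percentage form that
                # the query_string clause expects for minimum_should_match.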
                bqry.must.append(Q("query_string", fields=m.fields,
                                   type="best_fields", query=m.matching_text,
                                   minimum_should_match=minimum_should_match,
                                   boost=1))
                bqry.boost = 1.0 - vector_similarity_weight
            # The Elasticsearch Python SDK wraps KNN search, but the OpenSearch
            # Python SDK offers no such helper, so the branch below builds the
            # KNN search with raw DSL. OpenSearch's KNN query syntax also
            # differs from Elasticsearch's, so some adaptations were needed.
            elif isinstance(m, MatchDenseExpr):
                assert (bqry is not None)
                similarity = 0.0
                if "similarity" in m.extra_options:
                    similarity = m.extra_options["similarity"]
                use_knn = True
                vector_column_name = m.vector_column_name
                knn_query[vector_column_name] = {}
                knn_query[vector_column_name]["vector"] = list(m.embedding_data)
                knn_query[vector_column_name]["k"] = m.topn
                knn_query[vector_column_name]["filter"] = bqry.to_dict()
                knn_query[vector_column_name]["boost"] = similarity
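                # The bool query built above rides along as the k-NN `filter`
                # clause, i.e. metadata filtering on the vector search.
                # Resulting entry (sketch):
                #   {"<vector_field>": {"vector": [...], "k": topn,
                #                       "filter": {...}, "boost": similarity}}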
        if bqry and rank_feature:
            for fld, sc in rank_feature.items():
                if fld != PAGERANK_FLD:
                    fld = f"{TAG_FLD}.{fld}"
                bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
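            # e.g. rank_feature {PAGERANK_FLD: 10} adds a linear rank_feature
            # clause that boosts matching documents by that field; other keys
            # are treated as tag fields nested under TAG_FLD.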
        if bqry:
            s = s.query(bqry)
        for field in highlightFields:
            s = s.highlight(field)

        if orderBy:
            orders = list()
            for field, order in orderBy.fields:
                order = "asc" if order == 0 else "desc"
                if field in ["page_num_int", "top_int"]:
                    order_info = {"order": order, "unmapped_type": "float",
                                  "mode": "avg", "numeric_type": "double"}
                elif field.endswith("_int") or field.endswith("_flt"):
                    order_info = {"order": order, "unmapped_type": "float"}
                else:
                    order_info = {"order": order, "unmapped_type": "text"}
                orders.append({field: order_info})
            s = s.sort(*orders)

        for fld in aggFields:
            s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

        if limit > 0:
            s = s[offset:offset + limit]
        q = s.to_dict()
        logger.debug(f"OSConnection.search {str(indexNames)} query: " + json.dumps(q))

        if use_knn:
            # replace the full-text query with the KNN query
            q["query"] = {"knn": knn_query}
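            # With KNN enabled, the final request body is roughly (sketch):
            #   {"query": {"knn": {...}}, "highlight": ..., "sort": ..., "aggs": ...}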
        for i in range(ATTEMPT_TIME):
            try:
                res = self.os.search(index=indexNames,
                                     body=q,
                                     timeout=600,
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("OpenSearch Timeout.")
                logger.debug(f"OSConnection.search {str(indexNames)} res: " + str(res))
                return res
            except Exception as e:
                logger.exception(f"OSConnection.search {str(indexNames)} query: " + str(q))
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        logger.error(f"OSConnection.search timed out {ATTEMPT_TIME} times!")
        raise Exception("OSConnection.search timeout.")
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        for i in range(ATTEMPT_TIME):
            try:
                res = self.os.get(index=(indexName),
                                  id=chunkId, source=True, )
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("OpenSearch Timeout.")
                chunk = res["_source"]
                chunk["id"] = chunkId
                return chunk
            except NotFoundError:
                return None
            except Exception as e:
                logger.exception(f"OSConnection.get({chunkId}) got exception")
                if str(e).find("Timeout") > 0:
                    continue
                raise e
        logger.error(f"OSConnection.get timed out {ATTEMPT_TIME} times!")
        raise Exception("OSConnection.get timeout.")
    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
        # Refers to https://opensearch.org/docs/latest/api-reference/document-apis/bulk/
        operations = []
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            d_copy = copy.deepcopy(d)
            meta_id = d_copy.pop("id", "")
            operations.append(
                {"index": {"_index": indexName, "_id": meta_id}})
            operations.append(d_copy)
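        # The bulk body alternates action and source lines, e.g. (sketch):
        #   {"index": {"_index": indexName, "_id": "<chunk id>"}}
        #   {<document fields, with "id" popped out>}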
        res = []
        for _ in range(ATTEMPT_TIME):
            try:
                res = []
                r = self.os.bulk(index=(indexName), body=operations,
                                 refresh=False, timeout=60)
                # r["errors"] is False when every bulk action succeeded
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res
                for item in r["items"]:
                    for action in ["create", "delete", "index", "update"]:
                        if action in item and "error" in item[action]:
                            res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                return res
            except Exception as e:
                logger.warning("OSConnection.insert got exception: " + str(e))
                res = []
                if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
                    res.append(str(e))
                    time.sleep(3)
                    continue
        return res
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        doc = copy.deepcopy(newValue)
        doc.pop("id", None)
        if "id" in condition and isinstance(condition["id"], str):
            # update a single, specific document
            chunkId = condition["id"]
            for i in range(ATTEMPT_TIME):
                try:
                    self.os.update(index=indexName, id=chunkId, doc=doc)
                    return True
                except Exception as e:
                    logger.exception(
                        f"OSConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
                    if re.search(r"(timeout|connection)", str(e).lower()):
                        continue
                    break
            return False
        # update possibly-multiple documents selected by the condition
        bqry = Q("bool")
        for k, v in condition.items():
            if not isinstance(k, str) or not v:
                continue
            if k == "exists":
                bqry.filter.append(Q("exists", field=v))
                continue
            if isinstance(v, list):
                bqry.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bqry.filter.append(Q("term", **{k: v}))
            else:
                raise Exception(
                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
        scripts = []
        params = {}
        for k, v in newValue.items():
            if k == "remove":
                if isinstance(v, str):
                    scripts.append(f"ctx._source.remove('{v}');")
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
                        params[f"p_{kk}"] = vv
                continue
            if k == "add":
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                        params[f"pp_{kk}"] = vv.strip()
                continue
            if (not isinstance(k, str) or not v) and k != "available_int":
                continue
            if isinstance(v, str):
                v = re.sub(r"(['\n\r]|\\.)", " ", v)
                params[f"pp_{k}"] = v
                scripts.append(f"ctx._source.{k}=params.pp_{k};")
            elif isinstance(v, int) or isinstance(v, float):
                scripts.append(f"ctx._source.{k}={v};")
            elif isinstance(v, list):
                scripts.append(f"ctx._source.{k}=params.pp_{k};")
                params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
            else:
                raise Exception(
                    f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, float, str or list.")
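        # Example of the generated painless script (sketch): newValue
        # {"available_int": 1} yields "ctx._source.available_int=1;".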
        ubq = UpdateByQuery(index=indexName).using(self.os).query(bqry)
        ubq = ubq.script(source="".join(scripts), params=params)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")

        for _ in range(ATTEMPT_TIME):
            try:
                _ = ubq.execute()
                return True
            except Exception as e:
                logger.error("OSConnection.update got exception: " + str(e) + "\n".join(scripts))
                if re.search(r"(timeout|connection|conflict)", str(e).lower()):
                    continue
                break
        return False
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        qry = None
        assert "_id" not in condition
        if "id" in condition:
            chunk_ids = condition["id"]
            if not isinstance(chunk_ids, list):
                chunk_ids = [chunk_ids]
            qry = Q("ids", values=chunk_ids)
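            # e.g. condition {"id": ["a1", "b2"]} becomes an ids query that
            # matches exactly those chunk ids.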
        else:
            qry = Q("bool")
            for k, v in condition.items():
                if k == "exists":
                    qry.filter.append(Q("exists", field=v))
                elif k == "must_not":
                    if isinstance(v, dict):
                        for kk, vv in v.items():
                            if kk == "exists":
                                qry.must_not.append(Q("exists", field=vv))
                elif isinstance(v, list):
                    qry.must.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    qry.must.append(Q("term", **{k: v}))
                else:
                    raise Exception("Condition value must be int, str or list.")
        logger.debug("OSConnection.delete query: " + json.dumps(qry.to_dict()))
        for _ in range(ATTEMPT_TIME):
            try:
                res = self.os.delete_by_query(
                    index=indexName,
                    body=Search().query(qry).to_dict(),
                    refresh=True)
                return res["deleted"]
            except Exception as e:
                logger.warning("OSConnection.delete got exception: " + str(e))
                if re.search(r"(timeout|connection)", str(e).lower()):
                    time.sleep(3)
                    continue
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return 0
        return 0
    """
    Helper functions for search result
    """

    def getTotal(self, res):
        if isinstance(res["hits"]["total"], dict):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getChunkIds(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def __getSource(self, res):
        rr = []
        for d in res["hits"]["hits"]:
            d["_source"]["id"] = d["_id"]
            d["_source"]["_score"] = d["_score"]
            rr.append(d["_source"])
        return rr
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        res_fields = {}
        if not fields:
            return {}
        for d in self.__getSource(res):
            m = {n: d.get(n) for n in fields if d.get(n) is not None}
            for n, v in m.items():
                if isinstance(v, list):
                    continue
                if not isinstance(v, str):
                    m[n] = str(m[n])
                # if n.find("tks") > 0:
                #     m[n] = rmSpace(m[n])
            if m:
                res_fields[d["id"]] = m
        return res_fields
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
            if not hlts:
                continue
            txt = "...".join([a for a in list(hlts.items())[0][1]])
            if not is_english(txt.split()):
                ans[d["_id"]] = txt
                continue
            txt = d["_source"][fieldnm]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
                               r"\1<em>\2</em>\3", t,
                               flags=re.IGNORECASE | re.MULTILINE)
                if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
                    continue
                txts.append(t)
            ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
        return ans
    def getAggregation(self, res, fieldnm: str):
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()
        bkts = res["aggregations"][agg_field]["buckets"]
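        # e.g. buckets [{"key": "tagA", "doc_count": 3}] -> [("tagA", 3)]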
        return [(b["key"], b["doc_count"]) for b in bkts]
    """
    SQL
    """

    def sql(self, sql: str, fetch_size: int, format: str):
        logger.debug(f"OSConnection.sql get sql: {sql}")
        sql = re.sub(r"[ `]+", " ", sql)
        sql = sql.replace("%", "")
        replaces = []
        for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
            fld, v = r.group(1), r.group(3)
            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
                fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
            replaces.append(
                ("{}{}'{}'".format(
                    r.group(1),
                    r.group(2),
                    r.group(3)),
                 match))

        for p, r in replaces:
            sql = sql.replace(p, r, 1)
        logger.debug(f"OSConnection.sql to os: {sql}")
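        # e.g. "content_ltks = 'example'" is rewritten to
        # "MATCH(content_ltks, '<tokenized example>', 'operator=OR;minimum_should_match=30%')"
        # where the tokenized text comes from rag_tokenizer (sketch).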
        for i in range(ATTEMPT_TIME):
            try:
                res = self.os.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
                                        request_timeout="2s")
                return res
            except ConnectionTimeout:
                logger.exception("OSConnection.sql timeout")
                continue
            except Exception:
                logger.exception("OSConnection.sql got exception")
                return None
        logger.error(f"OSConnection.sql timed out {ATTEMPT_TIME} times!")
        return None