
infinity_conn.py

import logging
import os
import re
import json
import time
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag import settings
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import (
    DocStoreConnection,
    MatchExpr,
    MatchTextExpr,
    MatchDenseExpr,
    FusionExpr,
    OrderByExpr,
)

logger = logging.getLogger('ragflow.infinity_conn')


def equivalent_condition_to_str(condition: dict) -> str | None:
    assert "_id" not in condition
    cond = list()
    for k, v in condition.items():
        if not isinstance(k, str) or k in ["kb_id"] or not v:
            continue
        if isinstance(v, list):
            inCond = list()
            for item in v:
                if isinstance(item, str):
                    inCond.append(f"'{item}'")
                else:
                    inCond.append(str(item))
            if inCond:
                strInCond = ", ".join(inCond)
                strInCond = f"{k} IN ({strInCond})"
                cond.append(strInCond)
        elif isinstance(v, str):
            cond.append(f"{k}='{v}'")
        else:
            cond.append(f"{k}={str(v)}")
    return " AND ".join(cond) if cond else None
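
# Example of the translation above (hypothetical values; "kb_id" keys and
# falsy values are skipped by design):
#   equivalent_condition_to_str({"doc_id": "doc1", "available_int": [0, 1]})
#   returns "doc_id='doc1' AND available_int IN (0, 1)"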


def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
    """
    Concatenate multiple dataframes into one.
    """
    if df_list:
        return pl.concat(df_list)
    schema = dict()
    for fieldnm in selectFields:
        schema[fieldnm] = str
    return pl.DataFrame(schema=schema)
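
# Example (hypothetical): with df_list empty and selectFields ["id", "docnm_kwd"],
# this returns an empty DataFrame whose columns are typed as strings, so
# downstream callers can rely on the columns existing even when no table matched.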


@singleton
class InfinityConnection(DocStoreConnection):
    def __init__(self):
        self.dbName = settings.INFINITY.get("db_name", "default_db")
        infinity_uri = settings.INFINITY["uri"]
        if ":" in infinity_uri:
            host, port = infinity_uri.split(":")
            infinity_uri = infinity.common.NetworkAddress(host, int(port))
        self.connPool = None
        logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
        # Retry for up to 24 * 5 = 120 seconds before giving up.
        for _ in range(24):
            try:
                connPool = ConnectionPool(infinity_uri)
                inf_conn = connPool.get_conn()
                res = inf_conn.show_current_node()
                if res.error_code == ErrorCode.OK and res.server_status == "started":
                    self._migrate_db(inf_conn)
                    self.connPool = connPool
                    connPool.release_conn(inf_conn)
                    break
                connPool.release_conn(inf_conn)
                logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
                time.sleep(5)
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
                time.sleep(5)
        if self.connPool is None:
            msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
            logger.error(msg)
            raise Exception(msg)
        logger.info(f"Infinity {infinity_uri} is healthy.")
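
    # Because the class is wrapped in @singleton, repeated construction is
    # expected to hand back one shared instance (a sketch of intended use,
    # assuming the decorator from rag.utils caches the first instance):
    #   conn = InfinityConnection()
    #   assert conn is InfinityConnection()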

    def _migrate_db(self, inf_conn):
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        fp_mapping = os.path.join(
            get_project_base_directory(), "conf", "infinity_mapping.json"
        )
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        schema = json.load(open(fp_mapping))
        table_names = inf_db.list_tables().table_names
        for table_name in table_names:
            inf_table = inf_db.get_table(table_name)
            index_names = inf_table.list_indexes().index_names
            if "q_vec_idx" not in index_names:
                # Skip tables that were not created by this connector.
                continue
            column_names = inf_table.show_columns()["name"]
            column_names = set(column_names)
            for field_name, field_info in schema.items():
                if field_name in column_names:
                    continue
                res = inf_table.add_columns({field_name: field_info})
                assert res.error_code == ErrorCode.OK
                logger.info(
                    f"INFINITY added following column to table {table_name}: {field_name} {field_info}"
                )
                if field_info["type"] != "varchar" or "analyzer" not in field_info:
                    continue
                inf_table.create_index(
                    f"text_idx_{field_name}",
                    IndexInfo(
                        field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
                    ),
                    ConflictType.Ignore,
                )
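
    # The mapping file drives both migration and table creation. A
    # hypothetical entry, shaped to match the checks above ("type" required,
    # "analyzer" optional and only meaningful for varchar columns):
    #   "content_with_weight": {"type": "varchar", "analyzer": "standard"}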
  122. """
  123. Database operations
  124. """
  125. def dbType(self) -> str:
  126. return "infinity"
  127. def health(self) -> dict:
  128. """
  129. Return the health status of the database.
  130. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  131. """
  132. inf_conn = self.connPool.get_conn()
  133. res = inf_conn.show_current_node()
  134. self.connPool.release_conn(inf_conn)
  135. res2 = {
  136. "type": "infinity",
  137. "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
  138. "error": res.error_msg,
  139. }
  140. return res2
  141. """
  142. Table operations
  143. """

    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        fp_mapping = os.path.join(
            get_project_base_directory(), "conf", "infinity_mapping.json"
        )
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        schema = json.load(open(fp_mapping))
        vector_name = f"q_{vectorSize}_vec"
        schema[vector_name] = {"type": f"vector,{vectorSize},float"}
        inf_table = inf_db.create_table(
            table_name,
            schema,
            ConflictType.Ignore,
        )
        inf_table.create_index(
            "q_vec_idx",
            IndexInfo(
                vector_name,
                IndexType.Hnsw,
                {
                    "M": "16",
                    "ef_construction": "50",
                    "metric": "cosine",
                    "encode": "lvq",
                },
            ),
            ConflictType.Ignore,
        )
        for field_name, field_info in schema.items():
            if field_info["type"] != "varchar" or "analyzer" not in field_info:
                continue
            inf_table.create_index(
                f"text_idx_{field_name}",
                IndexInfo(
                    field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
                ),
                ConflictType.Ignore,
            )
        self.connPool.release_conn(inf_conn)
        logger.info(
            f"INFINITY created table {table_name}, vector size {vectorSize}"
        )
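
    # For reference, the HNSW parameters used above: M bounds the number of
    # graph neighbors kept per node, ef_construction sizes the candidate
    # list while the index is built, "cosine" matches the distance type used
    # on the query side, and "lvq" asks Infinity to quantize stored vectors
    # to save memory (a trade-off against exact recall).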

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        db_instance.drop_table(table_name, ConflictType.Ignore)
        self.connPool.release_conn(inf_conn)
        logger.info(f"INFINITY dropped table {table_name}")

    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            inf_conn = self.connPool.get_conn()
            db_instance = inf_conn.get_database(self.dbName)
            _ = db_instance.get_table(table_name)
            self.connPool.release_conn(inf_conn)
            return True
        except Exception as e:
            logger.warning(f"INFINITY indexExist {str(e)}")
        return False

    """
    CRUD operations
    """

    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
    ) -> tuple[pl.DataFrame, int]:
        """
        TODO: Infinity doesn't provide highlight
        """
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        if "id" not in selectFields:
            selectFields.append("id")

        # Prepare expressions common to all tables
        filter_cond = None
        filter_fulltext = ""
        if condition:
            filter_cond = equivalent_condition_to_str(condition)
        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                if filter_cond and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                fields = ",".join(matchExpr.fields)
                filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                if filter_cond:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                    matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, MatchDenseExpr):
                # Prefer the combined full-text filter when a text expression
                # built one; otherwise fall back to the plain condition.
                dense_filter = filter_fulltext or filter_cond
                if dense_filter and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": dense_filter})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, FusionExpr):
                logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
        order_by_expr_list = list()
        if orderBy.fields:
            for order_field in orderBy.fields:
                if order_field[1] == 0:
                    order_by_expr_list.append((order_field[0], SortType.Asc))
                else:
                    order_by_expr_list.append((order_field[0], SortType.Desc))

        total_hits_count = 0
        # Scatter the search across per-knowledgebase tables, then gather the results
        for indexName in indexNames:
            for knowledgebaseId in knowledgebaseIds:
                table_name = f"{indexName}_{knowledgebaseId}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    continue
                table_list.append(table_name)
                builder = table_instance.output(selectFields)
                if len(matchExprs) > 0:
                    for matchExpr in matchExprs:
                        if isinstance(matchExpr, MatchTextExpr):
                            fields = ",".join(matchExpr.fields)
                            builder = builder.match_text(
                                fields,
                                matchExpr.matching_text,
                                matchExpr.topn,
                                matchExpr.extra_options,
                            )
                        elif isinstance(matchExpr, MatchDenseExpr):
                            builder = builder.match_dense(
                                matchExpr.vector_column_name,
                                matchExpr.embedding_data,
                                matchExpr.embedding_data_type,
                                matchExpr.distance_type,
                                matchExpr.topn,
                                matchExpr.extra_options,
                            )
                        elif isinstance(matchExpr, FusionExpr):
                            builder = builder.fusion(
                                matchExpr.method, matchExpr.topn, matchExpr.fusion_params
                            )
                else:
                    # filter_cond may be None when no condition was given.
                    if filter_cond:
                        builder.filter(filter_cond)
                if orderBy.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
                if extra_result:
                    total_hits_count += int(extra_result["total_hits_count"])
                df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, selectFields)
        logger.debug(f"INFINITY search tables: {str(table_list)}, result: {str(res)}")
        return res, total_hits_count
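
    # A sketch of a hybrid query through search(), assuming the expression
    # constructors in rag.utils.doc_store_conn accept the attributes used
    # above (fields/matching_text/topn, vector column/embedding/distance,
    # fusion method/params); all names and weights are illustrative only:
    #   res, total = conn.search(
    #       selectFields=["docnm_kwd", "content_with_weight"],
    #       highlightFields=[], condition={"doc_id": "doc1"},
    #       matchExprs=[
    #           MatchTextExpr(["content_ltks"], "what is infinity", 10),
    #           MatchDenseExpr("q_768_vec", query_vec, "float", "cosine", 10),
    #           FusionExpr("weighted_sum", 10, {"weights": "0.05,0.95"}),
    #       ],
    #       orderBy=OrderByExpr(), offset=0, limit=10,
    #       indexNames="ragflow_tenant1", knowledgebaseIds=["kb1"],
    #   )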

    def get(
        self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
    ) -> dict | None:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        assert isinstance(knowledgebaseIds, list)
        table_list = list()
        for knowledgebaseId in knowledgebaseIds:
            table_name = f"{indexName}_{knowledgebaseId}"
            table_list.append(table_name)
            table_instance = db_instance.get_table(table_name)
            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
            if kb_res.shape[0] > 0:
                df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, ["id"])
        logger.debug(f"INFINITY get tables: {str(table_list)}, result: {str(res)}")
        res_fields = self.getFields(res, res.columns)
        return res_fields.get(chunkId, None)

    def insert(
        self, documents: list[dict], indexName: str, knowledgebaseId: str
    ) -> list[str]:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            # Create the table on demand, inferring the vector size from the
            # first document's q_<size>_vec column.
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                if k in ["important_kwd", "question_kwd", "entities_kwd"]:
                    assert isinstance(v, list)
                    d[k] = "###".join(v)
                elif k == 'kb_id':
                    if isinstance(d[k], list):
                        d[k] = d[k][0]  # the column is a str, so keep the first kb_id
                elif k == "position_int":
                    assert isinstance(v, list)
                    arr = [num for row in v for num in row]
                    d[k] = "_".join(f"{num:08x}" for num in arr)
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)
        # Upsert: delete any rows with the same ids, then insert fresh copies.
        ids = ["'{}'".format(d["id"]) for d in documents]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        # for doc in documents:
        #     logger.info(f"insert position_int: {doc['position_int']}")
        # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
        table_instance.insert(documents)
        self.connPool.release_conn(inf_conn)
        logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
        return []
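
    # Worked example of the position_int packing in insert(): each row of
    # coordinates is flattened and hex-encoded, so
    #   [[1, 2, 3, 4]]  ->  "00000001_00000002_00000003_00000004"
    # getFields() reverses this, regrouping the integers four at a time.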

    def update(
        self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
    ) -> bool:
        # if 'position_int' in newValue:
        #     logger.info(f"update position_int: {newValue['position_int']}")
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)
        filter = equivalent_condition_to_str(condition)
        for k, v in newValue.items():
            if k.endswith("_kwd") and isinstance(v, list):
                newValue[k] = " ".join(v)
            elif k == 'kb_id':
                if isinstance(newValue[k], list):
                    newValue[k] = newValue[k][0]  # the column is a str, so keep the first kb_id
            elif k == "position_int":
                assert isinstance(v, list)
                arr = [num for row in v for num in row]
                newValue[k] = "_".join(f"{num:08x}" for num in arr)
            elif k in ["page_num_int", "top_int"]:
                assert isinstance(v, list)
                newValue[k] = "_".join(f"{num:08x}" for num in v)
        logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
        table_instance.update(filter, newValue)
        self.connPool.release_conn(inf_conn)
        return True

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        filter = equivalent_condition_to_str(condition)
        try:
            table_instance = db_instance.get_table(table_name)
        except Exception:
            logger.warning(
                f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
            )
            return 0
        logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
        res = table_instance.delete(filter)
        self.connPool.release_conn(inf_conn)
        return res.deleted_rows

    """
    Helper functions for search result
    """

    def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
        if isinstance(res, tuple):
            return res[1]
        return len(res)

    def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
        if isinstance(res, tuple):
            res = res[0]
        return list(res["id"])

    def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> dict[str, dict]:
        if isinstance(res, tuple):
            res = res[0]
        res_fields = {}
        if not fields:
            return {}
        num_rows = len(res)
        column_id = res["id"]
        for i in range(num_rows):
            id = column_id[i]
            m = {"id": id}
            for fieldnm in fields:
                if fieldnm not in res:
                    m[fieldnm] = None
                    continue
                v = res[fieldnm][i]
                if isinstance(v, Series):
                    v = list(v)
                elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd"]:
                    assert isinstance(v, str)
                    v = [kwd for kwd in v.split("###") if kwd]
                elif fieldnm == "position_int":
                    assert isinstance(v, str)
                    if v:
                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
                        v = [arr[i:i + 4] for i in range(0, len(arr), 4)]
                    else:
                        v = []
                elif fieldnm in ["page_num_int", "top_int"]:
                    assert isinstance(v, str)
                    if v:
                        v = [int(hex_val, 16) for hex_val in v.split('_')]
                    else:
                        v = []
                else:
                    if not isinstance(v, str):
                        v = str(v)
                    # if fieldnm.endswith("_tks"):
                    #     v = rmSpace(v)
                m[fieldnm] = v
            res_fields[id] = m
        return res_fields
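
    # Round-trip examples for the decoders above (hypothetical row values):
    #   "important_kwd" stored as "alpha###beta" comes back as ["alpha", "beta"]
    #   "page_num_int" stored as "00000001_00000003" comes back as [1, 3]
    #   "position_int" stored as "00000001_00000002_00000003_00000004"
    #   comes back as [[1, 2, 3, 4]]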

    def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
        if isinstance(res, tuple):
            res = res[0]
        ans = {}
        num_rows = len(res)
        column_id = res["id"]
        for i in range(num_rows):
            id = column_id[i]
            txt = res[fieldnm][i]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
                        % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                if not re.search(
                    r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
                ):
                    continue
                txts.append(t)
            ans[id] = "...".join(txts)
        return ans

    def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
        """
        TODO: Infinity doesn't provide aggregation
        """
        return list()

    """
    SQL
    """

    def sql(self, sql: str, fetch_size: int, format: str):
        raise NotImplementedError("Not implemented")