
infinity_conn.py

import logging
import os
import re
import json
import time
import copy

import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
import polars as pl
from polars.series.series import Series

from rag import settings
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import (
    DocStoreConnection,
    MatchExpr,
    MatchTextExpr,
    MatchDenseExpr,
    FusionExpr,
    OrderByExpr,
)

logger = logging.getLogger('ragflow.infinity_conn')


def equivalent_condition_to_str(condition: dict) -> str | None:
    assert "_id" not in condition
    cond = list()
    for k, v in condition.items():
        if not isinstance(k, str) or k in ["kb_id"] or not v:
            continue
        if isinstance(v, list):
            inCond = list()
            for item in v:
                if isinstance(item, str):
                    inCond.append(f"'{item}'")
                else:
                    inCond.append(str(item))
            if inCond:
                strInCond = ", ".join(inCond)
                strInCond = f"{k} IN ({strInCond})"
                cond.append(strInCond)
        elif isinstance(v, str):
            cond.append(f"{k}='{v}'")
        else:
            cond.append(f"{k}={str(v)}")
    return " AND ".join(cond) if cond else "1=1"
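
# Illustrative example (not from the original module): a condition like
#   {"doc_id": ["d1", "d2"], "available_int": 1}
# becomes the filter string
#   "doc_id IN ('d1', 'd2') AND available_int=1"
# while "kb_id" keys and falsy values are skipped, and an empty condition
# falls back to the always-true "1=1".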


def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
    """
    Concatenate multiple dataframes into one.
    """
    df_list = [df for df in df_list if not df.is_empty()]
    if df_list:
        return pl.concat(df_list)
    schema = dict()
    for field_name in selectFields:
        if field_name == 'score()':  # Workaround: Infinity returns the score() pseudo-column under the name SCORE
            schema['SCORE'] = str
        else:
            schema[field_name] = str
    return pl.DataFrame(schema=schema)
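
# Example (illustrative): when every per-table result is empty,
# concat_dataframes([], ["id", "score()"]) returns an empty DataFrame whose
# schema still has columns ["id", "SCORE"], so downstream column access is safe.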


@singleton
class InfinityConnection(DocStoreConnection):
    def __init__(self):
        self.dbName = settings.INFINITY.get("db_name", "default_db")
        infinity_uri = settings.INFINITY["uri"]
        if ":" in infinity_uri:
            host, port = infinity_uri.split(":")
            infinity_uri = infinity.common.NetworkAddress(host, int(port))
        self.connPool = None
        logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
        # Retry for up to 24 * 5s = 120s while the Infinity server starts up.
        for _ in range(24):
            try:
                connPool = ConnectionPool(infinity_uri)
                inf_conn = connPool.get_conn()
                res = inf_conn.show_current_node()
                if res.error_code == ErrorCode.OK and res.server_status == "started":
                    self._migrate_db(inf_conn)
                    self.connPool = connPool
                    connPool.release_conn(inf_conn)
                    break
                connPool.release_conn(inf_conn)
                logger.warning(f"Infinity status: {res.server_status}. Waiting for Infinity {infinity_uri} to become healthy.")
                time.sleep(5)
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting for Infinity {infinity_uri} to become healthy.")
                time.sleep(5)
        if self.connPool is None:
            msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
            logger.error(msg)
            raise Exception(msg)
        logger.info(f"Infinity {infinity_uri} is healthy.")

    def _migrate_db(self, inf_conn):
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        fp_mapping = os.path.join(
            get_project_base_directory(), "conf", "infinity_mapping.json"
        )
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping) as f:
            schema = json.load(f)
        table_names = inf_db.list_tables().table_names
        for table_name in table_names:
            inf_table = inf_db.get_table(table_name)
            index_names = inf_table.list_indexes().index_names
            if "q_vec_idx" not in index_names:
                # Skip tables this connector didn't create
                continue
            column_names = set(inf_table.show_columns()["name"])
            for field_name, field_info in schema.items():
                if field_name in column_names:
                    continue
                res = inf_table.add_columns({field_name: field_info})
                assert res.error_code == ErrorCode.OK
                logger.info(
                    f"INFINITY added column {field_name} {field_info} to table {table_name}"
                )
                if field_info["type"] != "varchar" or "analyzer" not in field_info:
                    continue
                inf_table.create_index(
                    f"text_idx_{field_name}",
                    IndexInfo(
                        field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
                    ),
                    ConflictType.Ignore,
                )
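
    # The contents of conf/infinity_mapping.json are not shown here; judging by
    # how _migrate_db() and createIdx() consume it, each entry presumably maps a
    # field name to a column spec, e.g. (hypothetical):
    #   { "title_tks": {"type": "varchar", "analyzer": "whitespace"} }
    # where an "analyzer" key on a varchar column triggers a full-text index.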
  127. """
  128. Database operations
  129. """
  130. def dbType(self) -> str:
  131. return "infinity"
  132. def health(self) -> dict:
  133. """
  134. Return the health status of the database.
  135. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  136. """
  137. inf_conn = self.connPool.get_conn()
  138. res = inf_conn.show_current_node()
  139. self.connPool.release_conn(inf_conn)
  140. res2 = {
  141. "type": "infinity",
  142. "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
  143. "error": res.error_msg,
  144. }
  145. return res2
  146. """
  147. Table operations
  148. """

    """
    Table operations
    """

    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        fp_mapping = os.path.join(
            get_project_base_directory(), "conf", "infinity_mapping.json"
        )
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping) as f:
            schema = json.load(f)
        vector_name = f"q_{vectorSize}_vec"
        schema[vector_name] = {"type": f"vector,{vectorSize},float"}
        inf_table = inf_db.create_table(
            table_name,
            schema,
            ConflictType.Ignore,
        )
        inf_table.create_index(
            "q_vec_idx",
            IndexInfo(
                vector_name,
                IndexType.Hnsw,
                {
                    "M": "16",
                    "ef_construction": "50",
                    "metric": "cosine",
                    "encode": "lvq",
                },
            ),
            ConflictType.Ignore,
        )
        for field_name, field_info in schema.items():
            if field_info["type"] != "varchar" or "analyzer" not in field_info:
                continue
            inf_table.create_index(
                f"text_idx_{field_name}",
                IndexInfo(
                    field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
                ),
                ConflictType.Ignore,
            )
        self.connPool.release_conn(inf_conn)
        logger.info(
            f"INFINITY created table {table_name}, vector size {vectorSize}"
        )

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        db_instance.drop_table(table_name, ConflictType.Ignore)
        self.connPool.release_conn(inf_conn)
        logger.info(f"INFINITY dropped table {table_name}")

    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            inf_conn = self.connPool.get_conn()
            db_instance = inf_conn.get_database(self.dbName)
            _ = db_instance.get_table(table_name)
            self.connPool.release_conn(inf_conn)
            return True
        except Exception as e:
            logger.warning(f"INFINITY indexExist {str(e)}")
        return False
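
    # Note: each (indexName, knowledgebaseId) pair maps to its own Infinity
    # table named f"{indexName}_{knowledgebaseId}", so an "index" in this
    # interface is really a family of tables. This is why search() below
    # scatters the query across every matching table and concatenates the
    # per-table results.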
  212. """
  213. CRUD operations
  214. """
  215. def search(
  216. self,
  217. selectFields: list[str],
  218. highlightFields: list[str],
  219. condition: dict,
  220. matchExprs: list[MatchExpr],
  221. orderBy: OrderByExpr,
  222. offset: int,
  223. limit: int,
  224. indexNames: str | list[str],
  225. knowledgebaseIds: list[str],
  226. ) -> tuple[pl.DataFrame, int]:
  227. """
  228. TODO: Infinity doesn't provide highlight
  229. """
  230. if isinstance(indexNames, str):
  231. indexNames = indexNames.split(",")
  232. assert isinstance(indexNames, list) and len(indexNames) > 0
  233. inf_conn = self.connPool.get_conn()
  234. db_instance = inf_conn.get_database(self.dbName)
  235. df_list = list()
  236. table_list = list()
  237. for essential_field in ["id"]:
  238. if essential_field not in selectFields:
  239. selectFields.append(essential_field)
  240. if matchExprs:
  241. for essential_field in ["score()", "pagerank_fea"]:
  242. selectFields.append(essential_field)
  243. # Prepare expressions common to all tables
  244. filter_cond = None
  245. filter_fulltext = ""
  246. if condition:
  247. filter_cond = equivalent_condition_to_str(condition)
  248. for matchExpr in matchExprs:
  249. if isinstance(matchExpr, MatchTextExpr):
  250. if filter_cond and "filter" not in matchExpr.extra_options:
  251. matchExpr.extra_options.update({"filter": filter_cond})
  252. fields = ",".join(matchExpr.fields)
  253. filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
  254. if filter_cond:
  255. filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
  256. minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
  257. if isinstance(minimum_should_match, float):
  258. str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
  259. matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
  260. for k, v in matchExpr.extra_options.items():
  261. if not isinstance(v, str):
  262. matchExpr.extra_options[k] = str(v)
  263. logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
  264. elif isinstance(matchExpr, MatchDenseExpr):
  265. if filter_cond and "filter" not in matchExpr.extra_options:
  266. matchExpr.extra_options.update({"filter": filter_fulltext})
  267. for k, v in matchExpr.extra_options.items():
  268. if not isinstance(v, str):
  269. matchExpr.extra_options[k] = str(v)
  270. logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
  271. elif isinstance(matchExpr, FusionExpr):
  272. logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
  273. order_by_expr_list = list()
  274. if orderBy.fields:
  275. for order_field in orderBy.fields:
  276. if order_field[1] == 0:
  277. order_by_expr_list.append((order_field[0], SortType.Asc))
  278. else:
  279. order_by_expr_list.append((order_field[0], SortType.Desc))
  280. total_hits_count = 0
  281. # Scatter search tables and gather the results
  282. for indexName in indexNames:
  283. for knowledgebaseId in knowledgebaseIds:
  284. table_name = f"{indexName}_{knowledgebaseId}"
  285. try:
  286. table_instance = db_instance.get_table(table_name)
  287. except Exception:
  288. continue
  289. table_list.append(table_name)
  290. builder = table_instance.output(selectFields)
  291. if len(matchExprs) > 0:
  292. for matchExpr in matchExprs:
  293. if isinstance(matchExpr, MatchTextExpr):
  294. fields = ",".join(matchExpr.fields)
  295. builder = builder.match_text(
  296. fields,
  297. matchExpr.matching_text,
  298. matchExpr.topn,
  299. matchExpr.extra_options,
  300. )
  301. elif isinstance(matchExpr, MatchDenseExpr):
  302. builder = builder.match_dense(
  303. matchExpr.vector_column_name,
  304. matchExpr.embedding_data,
  305. matchExpr.embedding_data_type,
  306. matchExpr.distance_type,
  307. matchExpr.topn,
  308. matchExpr.extra_options,
  309. )
  310. elif isinstance(matchExpr, FusionExpr):
  311. builder = builder.fusion(
  312. matchExpr.method, matchExpr.topn, matchExpr.fusion_params
  313. )
  314. else:
  315. if len(filter_cond) > 0:
  316. builder.filter(filter_cond)
  317. if orderBy.fields:
  318. builder.sort(order_by_expr_list)
  319. builder.offset(offset).limit(limit)
  320. kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
  321. if extra_result:
  322. total_hits_count += int(extra_result["total_hits_count"])
  323. logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
  324. df_list.append(kb_res)
  325. self.connPool.release_conn(inf_conn)
  326. res = concat_dataframes(df_list, selectFields)
  327. if matchExprs:
  328. res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
  329. res = res.limit(limit)
  330. logger.debug(f"INFINITY search final result: {str(res)}")
  331. return res, total_hits_count
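
    # A hedged usage sketch (constructor signatures for the Match*Expr classes
    # are assumed from the attribute access above, not verified against
    # rag.utils.doc_store_conn):
    #
    #   text_expr = MatchTextExpr(["title_tks", "content_ltks"], "what is RAG",
    #                             100, {"minimum_should_match": 0.3})
    #   dense_expr = MatchDenseExpr("q_768_vec", query_vector, "float", "cosine",
    #                               100, {})
    #   fusion = FusionExpr("weighted_sum", 100, {"weights": "0.05, 0.95"})
    #   res, total = conn.search(["id", "content_with_weight"], [], {"doc_id": doc_ids},
    #                            [text_expr, dense_expr, fusion], OrderByExpr(),
    #                            0, 10, "ragflow_tenant1", ["kb1"])
    #
    # When match expressions are present, "score()" and "pagerank_fea" are
    # appended to selectFields, so the final cross-table sort on
    # SCORE + pagerank_fea always has its columns available.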

    def get(
        self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
    ) -> dict | None:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        assert isinstance(knowledgebaseIds, list)
        table_list = list()
        for knowledgebaseId in knowledgebaseIds:
            table_name = f"{indexName}_{knowledgebaseId}"
            table_list.append(table_name)
            table_instance = db_instance.get_table(table_name)
            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
            logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
            df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, ["id"])
        res_fields = self.getFields(res, res.columns)
        return res_fields.get(chunkId, None)

    def insert(
        self, documents: list[dict], indexName: str, knowledgebaseId: str
    ) -> list[str]:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            # Table doesn't exist yet: infer the vector size from the documents
            # and create it on the fly.
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)
        docs = copy.deepcopy(documents)
        for d in docs:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                if k in ["important_kwd", "question_kwd", "entities_kwd"]:
                    assert isinstance(v, list)
                    d[k] = "###".join(v)
                elif k == 'kb_id':
                    if isinstance(d[k], list):
                        d[k] = d[k][0]  # since d[k] is a list, but we need a str
                elif k == "position_int":
                    assert isinstance(v, list)
                    arr = [num for row in v for num in row]
                    d[k] = "_".join(f"{num:08x}" for num in arr)
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)
        # Delete-then-insert makes the operation an upsert.
        ids = ["'{}'".format(d["id"]) for d in docs]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        # for doc in documents:
        #     logger.info(f"insert position_int: {doc['position_int']}")
        # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
        table_instance.insert(docs)
        self.connPool.release_conn(inf_conn)
        logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
        return []
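
    # Illustrative round-trip for the encodings above: a position_int value of
    # [[1, 2, 3, 4, 5]] is flattened and hex-packed into
    # "00000001_00000002_00000003_00000004_00000005" on insert; getFields()
    # below reverses this by splitting on "_" and regrouping five ints at a time.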

    def update(
        self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
    ) -> bool:
        # if 'position_int' in newValue:
        #     logger.info(f"update position_int: {newValue['position_int']}")
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)
        if "exist" in condition:
            del condition["exist"]
        filter = equivalent_condition_to_str(condition)
        for k, v in list(newValue.items()):
            if k.endswith("_kwd") and isinstance(v, list):
                newValue[k] = " ".join(v)
            elif k == 'kb_id':
                if isinstance(newValue[k], list):
                    newValue[k] = newValue[k][0]  # since newValue[k] is a list, but we need a str
            elif k == "position_int":
                assert isinstance(v, list)
                arr = [num for row in v for num in row]
                newValue[k] = "_".join(f"{num:08x}" for num in arr)
            elif k in ["page_num_int", "top_int"]:
                assert isinstance(v, list)
                newValue[k] = "_".join(f"{num:08x}" for num in v)
            elif k == "remove" and v in ["pagerank_fea"]:
                # "remove" requests resetting the named column rather than
                # writing a literal value.
                del newValue[k]
                newValue[v] = 0
        logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
        table_instance.update(filter, newValue)
        self.connPool.release_conn(inf_conn)
        return True

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        filter = equivalent_condition_to_str(condition)
        try:
            table_instance = db_instance.get_table(table_name)
        except Exception:
            logger.warning(
                f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
            )
            return 0
        logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
        res = table_instance.delete(filter)
        self.connPool.release_conn(inf_conn)
        return res.deleted_rows
  451. """
  452. Helper functions for search result
  453. """
  454. def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
  455. if isinstance(res, tuple):
  456. return res[1]
  457. return len(res)
  458. def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
  459. if isinstance(res, tuple):
  460. res = res[0]
  461. return list(res["id"])
  462. def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]:
  463. if isinstance(res, tuple):
  464. res = res[0]
  465. res_fields = {}
  466. if not fields:
  467. return {}
  468. num_rows = len(res)
  469. column_id = res["id"]
  470. for i in range(num_rows):
  471. id = column_id[i]
  472. m = {"id": id}
  473. for fieldnm in fields:
  474. if fieldnm not in res:
  475. m[fieldnm] = None
  476. continue
  477. v = res[fieldnm][i]
  478. if isinstance(v, Series):
  479. v = list(v)
  480. elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd"]:
  481. assert isinstance(v, str)
  482. v = [kwd for kwd in v.split("###") if kwd]
  483. elif fieldnm == "position_int":
  484. assert isinstance(v, str)
  485. if v:
  486. arr = [int(hex_val, 16) for hex_val in v.split('_')]
  487. v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
  488. else:
  489. v = []
  490. elif fieldnm in ["page_num_int", "top_int"]:
  491. assert isinstance(v, str)
  492. if v:
  493. v = [int(hex_val, 16) for hex_val in v.split('_')]
  494. else:
  495. v = []
  496. else:
  497. if not isinstance(v, str):
  498. v = str(v)
  499. # if fieldnm.endswith("_tks"):
  500. # v = rmSpace(v)
  501. m[fieldnm] = v
  502. res_fields[id] = m
  503. return res_fields

    def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
        if isinstance(res, tuple):
            res = res[0]
        ans = {}
        num_rows = len(res)
        column_id = res["id"]
        for i in range(num_rows):
            id = column_id[i]
            txt = res[fieldnm][i]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
                        % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                if not re.search(
                    r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
                ):
                    continue
                txts.append(t)
            ans[id] = "...".join(txts)
        return ans
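
    # Example (illustrative): with keywords=["rag"], a stored text
    # "RAG is retrieval. Nothing else." yields {id: "<em>RAG</em> is retrieval"}:
    # the matching sentence gets its keyword wrapped, sentences without a
    # wrapped keyword are dropped, and the survivors are joined with "...".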

    def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
        """
        TODO: Infinity doesn't provide aggregation
        """
        return list()
  536. """
  537. SQL
  538. """
  539. def sql(sql: str, fetch_size: int, format: str):
  540. raise NotImplementedError("Not implemented")
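
# Hedged end-to-end sketch (assumes rag.settings carries an INFINITY section
# such as {"uri": "localhost:23817", "db_name": "default_db"}; the names below
# are illustrative, not from this module):
#
#   conn = InfinityConnection()   # @singleton: later calls reuse this instance
#   conn.createIdx("ragflow_t1", "kb1", 768)
#   conn.insert([{"id": "c1", "q_768_vec": embedding, "content_with_weight": "..."}],
#               "ragflow_t1", "kb1")
#   chunk = conn.get("c1", "ragflow_t1", ["kb1"])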