# infinity_conn.py — Infinity doc-store connector for RAGFlow.
# (Web-scrape residue — site chrome, file size, and concatenated line numbers — removed.)
  1. import logging
  2. import os
  3. import re
  4. import json
  5. import time
  6. import infinity
  7. from infinity.common import ConflictType, InfinityException, SortType
  8. from infinity.index import IndexInfo, IndexType
  9. from infinity.connection_pool import ConnectionPool
  10. from infinity.errors import ErrorCode
  11. from rag import settings
  12. from rag.utils import singleton
  13. import polars as pl
  14. from polars.series.series import Series
  15. from api.utils.file_utils import get_project_base_directory
  16. from rag.utils.doc_store_conn import (
  17. DocStoreConnection,
  18. MatchExpr,
  19. MatchTextExpr,
  20. MatchDenseExpr,
  21. FusionExpr,
  22. OrderByExpr,
  23. )
  24. logger = logging.getLogger('ragflow.infinity_conn')
  25. def equivalent_condition_to_str(condition: dict) -> str:
  26. assert "_id" not in condition
  27. cond = list()
  28. for k, v in condition.items():
  29. if not isinstance(k, str) or not v:
  30. continue
  31. if isinstance(v, list):
  32. inCond = list()
  33. for item in v:
  34. if isinstance(item, str):
  35. inCond.append(f"'{item}'")
  36. else:
  37. inCond.append(str(item))
  38. if inCond:
  39. strInCond = ", ".join(inCond)
  40. strInCond = f"{k} IN ({strInCond})"
  41. cond.append(strInCond)
  42. elif isinstance(v, str):
  43. cond.append(f"{k}='{v}'")
  44. else:
  45. cond.append(f"{k}={str(v)}")
  46. return " AND ".join(cond)
  47. def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
  48. """
  49. Concatenate multiple dataframes into one.
  50. """
  51. if df_list:
  52. return pl.concat(df_list)
  53. schema = dict()
  54. for fieldnm in selectFields:
  55. schema[fieldnm] = str
  56. return pl.DataFrame(schema=schema)
  57. @singleton
  58. class InfinityConnection(DocStoreConnection):
  59. def __init__(self):
  60. self.dbName = settings.INFINITY.get("db_name", "default_db")
  61. infinity_uri = settings.INFINITY["uri"]
  62. if ":" in infinity_uri:
  63. host, port = infinity_uri.split(":")
  64. infinity_uri = infinity.common.NetworkAddress(host, int(port))
  65. self.connPool = None
  66. logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
  67. for _ in range(24):
  68. try:
  69. connPool = ConnectionPool(infinity_uri)
  70. inf_conn = connPool.get_conn()
  71. res = inf_conn.show_current_node()
  72. connPool.release_conn(inf_conn)
  73. self.connPool = connPool
  74. if res.error_code == ErrorCode.OK and res.server_status=="started":
  75. break
  76. logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
  77. time.sleep(5)
  78. except Exception as e:
  79. logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
  80. time.sleep(5)
  81. if self.connPool is None:
  82. msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
  83. logger.error(msg)
  84. raise Exception(msg)
  85. logger.info(f"Infinity {infinity_uri} is healthy.")
  86. """
  87. Database operations
  88. """
  89. def dbType(self) -> str:
  90. return "infinity"
  91. def health(self) -> dict:
  92. """
  93. Return the health status of the database.
  94. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  95. """
  96. inf_conn = self.connPool.get_conn()
  97. res = inf_conn.show_current_node()
  98. self.connPool.release_conn(inf_conn)
  99. res2 = {
  100. "type": "infinity",
  101. "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
  102. "error": res.error_msg,
  103. }
  104. return res2
  105. """
  106. Table operations
  107. """
  108. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  109. table_name = f"{indexName}_{knowledgebaseId}"
  110. inf_conn = self.connPool.get_conn()
  111. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  112. fp_mapping = os.path.join(
  113. get_project_base_directory(), "conf", "infinity_mapping.json"
  114. )
  115. if not os.path.exists(fp_mapping):
  116. raise Exception(f"Mapping file not found at {fp_mapping}")
  117. schema = json.load(open(fp_mapping))
  118. vector_name = f"q_{vectorSize}_vec"
  119. schema[vector_name] = {"type": f"vector,{vectorSize},float"}
  120. inf_table = inf_db.create_table(
  121. table_name,
  122. schema,
  123. ConflictType.Ignore,
  124. )
  125. inf_table.create_index(
  126. "q_vec_idx",
  127. IndexInfo(
  128. vector_name,
  129. IndexType.Hnsw,
  130. {
  131. "M": "16",
  132. "ef_construction": "50",
  133. "metric": "cosine",
  134. "encode": "lvq",
  135. },
  136. ),
  137. ConflictType.Ignore,
  138. )
  139. text_suffix = ["_tks", "_ltks", "_kwd"]
  140. for field_name, field_info in schema.items():
  141. if field_info["type"] != "varchar":
  142. continue
  143. for suffix in text_suffix:
  144. if field_name.endswith(suffix):
  145. inf_table.create_index(
  146. f"text_idx_{field_name}",
  147. IndexInfo(
  148. field_name, IndexType.FullText, {"ANALYZER": "standard"}
  149. ),
  150. ConflictType.Ignore,
  151. )
  152. break
  153. self.connPool.release_conn(inf_conn)
  154. logger.info(
  155. f"INFINITY created table {table_name}, vector size {vectorSize}"
  156. )
  157. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  158. table_name = f"{indexName}_{knowledgebaseId}"
  159. inf_conn = self.connPool.get_conn()
  160. db_instance = inf_conn.get_database(self.dbName)
  161. db_instance.drop_table(table_name, ConflictType.Ignore)
  162. self.connPool.release_conn(inf_conn)
  163. logger.info(f"INFINITY dropped table {table_name}")
  164. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  165. table_name = f"{indexName}_{knowledgebaseId}"
  166. try:
  167. inf_conn = self.connPool.get_conn()
  168. db_instance = inf_conn.get_database(self.dbName)
  169. _ = db_instance.get_table(table_name)
  170. self.connPool.release_conn(inf_conn)
  171. return True
  172. except Exception as e:
  173. logger.warning(f"INFINITY indexExist {str(e)}")
  174. return False
  175. """
  176. CRUD operations
  177. """
  178. def search(
  179. self,
  180. selectFields: list[str],
  181. highlightFields: list[str],
  182. condition: dict,
  183. matchExprs: list[MatchExpr],
  184. orderBy: OrderByExpr,
  185. offset: int,
  186. limit: int,
  187. indexNames: str | list[str],
  188. knowledgebaseIds: list[str],
  189. ) -> list[dict] | pl.DataFrame:
  190. """
  191. TODO: Infinity doesn't provide highlight
  192. """
  193. if isinstance(indexNames, str):
  194. indexNames = indexNames.split(",")
  195. assert isinstance(indexNames, list) and len(indexNames) > 0
  196. inf_conn = self.connPool.get_conn()
  197. db_instance = inf_conn.get_database(self.dbName)
  198. df_list = list()
  199. table_list = list()
  200. if "id" not in selectFields:
  201. selectFields.append("id")
  202. # Prepare expressions common to all tables
  203. filter_cond = ""
  204. filter_fulltext = ""
  205. if condition:
  206. filter_cond = equivalent_condition_to_str(condition)
  207. for matchExpr in matchExprs:
  208. if isinstance(matchExpr, MatchTextExpr):
  209. if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
  210. matchExpr.extra_options.update({"filter": filter_cond})
  211. fields = ",".join(matchExpr.fields)
  212. filter_fulltext = (
  213. f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
  214. )
  215. if len(filter_cond) != 0:
  216. filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
  217. logger.debug(f"filter_fulltext: {filter_fulltext}")
  218. minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
  219. if isinstance(minimum_should_match, float):
  220. str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
  221. matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
  222. for k, v in matchExpr.extra_options.items():
  223. if not isinstance(v, str):
  224. matchExpr.extra_options[k] = str(v)
  225. elif isinstance(matchExpr, MatchDenseExpr):
  226. if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
  227. matchExpr.extra_options.update({"filter": filter_fulltext})
  228. for k, v in matchExpr.extra_options.items():
  229. if not isinstance(v, str):
  230. matchExpr.extra_options[k] = str(v)
  231. order_by_expr_list = list()
  232. if orderBy.fields:
  233. for order_field in orderBy.fields:
  234. if order_field[1] == 0:
  235. order_by_expr_list.append((order_field[0], SortType.Asc))
  236. else:
  237. order_by_expr_list.append((order_field[0], SortType.Desc))
  238. # Scatter search tables and gather the results
  239. for indexName in indexNames:
  240. for knowledgebaseId in knowledgebaseIds:
  241. table_name = f"{indexName}_{knowledgebaseId}"
  242. try:
  243. table_instance = db_instance.get_table(table_name)
  244. except Exception:
  245. continue
  246. table_list.append(table_name)
  247. builder = table_instance.output(selectFields)
  248. if len(matchExprs) > 0:
  249. for matchExpr in matchExprs:
  250. if isinstance(matchExpr, MatchTextExpr):
  251. fields = ",".join(matchExpr.fields)
  252. builder = builder.match_text(
  253. fields,
  254. matchExpr.matching_text,
  255. matchExpr.topn,
  256. matchExpr.extra_options,
  257. )
  258. elif isinstance(matchExpr, MatchDenseExpr):
  259. builder = builder.match_dense(
  260. matchExpr.vector_column_name,
  261. matchExpr.embedding_data,
  262. matchExpr.embedding_data_type,
  263. matchExpr.distance_type,
  264. matchExpr.topn,
  265. matchExpr.extra_options,
  266. )
  267. elif isinstance(matchExpr, FusionExpr):
  268. builder = builder.fusion(
  269. matchExpr.method, matchExpr.topn, matchExpr.fusion_params
  270. )
  271. else:
  272. if len(filter_cond) > 0:
  273. builder.filter(filter_cond)
  274. if orderBy.fields:
  275. builder.sort(order_by_expr_list)
  276. builder.offset(offset).limit(limit)
  277. kb_res = builder.to_pl()
  278. df_list.append(kb_res)
  279. self.connPool.release_conn(inf_conn)
  280. res = concat_dataframes(df_list, selectFields)
  281. logger.debug(f"INFINITY search tables: {str(table_list)}, result: {str(res)}")
  282. return res
  283. def get(
  284. self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
  285. ) -> dict | None:
  286. inf_conn = self.connPool.get_conn()
  287. db_instance = inf_conn.get_database(self.dbName)
  288. df_list = list()
  289. assert isinstance(knowledgebaseIds, list)
  290. table_list = list()
  291. for knowledgebaseId in knowledgebaseIds:
  292. table_name = f"{indexName}_{knowledgebaseId}"
  293. table_list.append(table_name)
  294. table_instance = db_instance.get_table(table_name)
  295. kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
  296. if len(kb_res) != 0 and kb_res.shape[0] > 0:
  297. df_list.append(kb_res)
  298. self.connPool.release_conn(inf_conn)
  299. res = concat_dataframes(df_list, ["id"])
  300. logger.debug(f"INFINITY get tables: {str(table_list)}, result: {str(res)}")
  301. res_fields = self.getFields(res, res.columns)
  302. return res_fields.get(chunkId, None)
  303. def insert(
  304. self, documents: list[dict], indexName: str, knowledgebaseId: str
  305. ) -> list[str]:
  306. inf_conn = self.connPool.get_conn()
  307. db_instance = inf_conn.get_database(self.dbName)
  308. table_name = f"{indexName}_{knowledgebaseId}"
  309. try:
  310. table_instance = db_instance.get_table(table_name)
  311. except InfinityException as e:
  312. # src/common/status.cppm, kTableNotExist = 3022
  313. if e.error_code != ErrorCode.TABLE_NOT_EXIST:
  314. raise
  315. vector_size = 0
  316. patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
  317. for k in documents[0].keys():
  318. m = patt.match(k)
  319. if m:
  320. vector_size = int(m.group("vector_size"))
  321. break
  322. if vector_size == 0:
  323. raise ValueError("Cannot infer vector size from documents")
  324. self.createIdx(indexName, knowledgebaseId, vector_size)
  325. table_instance = db_instance.get_table(table_name)
  326. for d in documents:
  327. assert "_id" not in d
  328. assert "id" in d
  329. for k, v in d.items():
  330. if k in ["important_kwd", "question_kwd", "entities_kwd"]:
  331. assert isinstance(v, list)
  332. d[k] = "###".join(v)
  333. elif k == 'kb_id':
  334. if isinstance(d[k], list):
  335. d[k] = d[k][0] # since d[k] is a list, but we need a str
  336. elif k == "position_int":
  337. assert isinstance(v, list)
  338. arr = [num for row in v for num in row]
  339. d[k] = "_".join(f"{num:08x}" for num in arr)
  340. elif k in ["page_num_int", "top_int", "position_int"]:
  341. assert isinstance(v, list)
  342. d[k] = "_".join(f"{num:08x}" for num in v)
  343. ids = ["'{}'".format(d["id"]) for d in documents]
  344. str_ids = ", ".join(ids)
  345. str_filter = f"id IN ({str_ids})"
  346. table_instance.delete(str_filter)
  347. # for doc in documents:
  348. # logger.info(f"insert position_int: {doc['position_int']}")
  349. # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
  350. table_instance.insert(documents)
  351. self.connPool.release_conn(inf_conn)
  352. logger.debug(f"inserted into {table_name} {str_ids}.")
  353. return []
  354. def update(
  355. self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
  356. ) -> bool:
  357. # if 'position_int' in newValue:
  358. # logger.info(f"update position_int: {newValue['position_int']}")
  359. inf_conn = self.connPool.get_conn()
  360. db_instance = inf_conn.get_database(self.dbName)
  361. table_name = f"{indexName}_{knowledgebaseId}"
  362. table_instance = db_instance.get_table(table_name)
  363. filter = equivalent_condition_to_str(condition)
  364. for k, v in newValue.items():
  365. if k.endswith("_kwd") and isinstance(v, list):
  366. newValue[k] = " ".join(v)
  367. elif k == 'kb_id':
  368. if isinstance(newValue[k], list):
  369. newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
  370. elif k == "position_int":
  371. assert isinstance(v, list)
  372. arr = [num for row in v for num in row]
  373. newValue[k] = "_".join(f"{num:08x}" for num in arr)
  374. elif k in ["page_num_int", "top_int"]:
  375. assert isinstance(v, list)
  376. newValue[k] = "_".join(f"{num:08x}" for num in v)
  377. table_instance.update(filter, newValue)
  378. self.connPool.release_conn(inf_conn)
  379. return True
  380. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  381. inf_conn = self.connPool.get_conn()
  382. db_instance = inf_conn.get_database(self.dbName)
  383. table_name = f"{indexName}_{knowledgebaseId}"
  384. filter = equivalent_condition_to_str(condition)
  385. try:
  386. table_instance = db_instance.get_table(table_name)
  387. except Exception:
  388. logger.warning(
  389. f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
  390. )
  391. return 0
  392. res = table_instance.delete(filter)
  393. self.connPool.release_conn(inf_conn)
  394. return res.deleted_rows
  395. """
  396. Helper functions for search result
  397. """
  398. def getTotal(self, res):
  399. return len(res)
  400. def getChunkIds(self, res):
  401. return list(res["id"])
  402. def getFields(self, res, fields: list[str]) -> list[str, dict]:
  403. res_fields = {}
  404. if not fields:
  405. return {}
  406. num_rows = len(res)
  407. column_id = res["id"]
  408. for i in range(num_rows):
  409. id = column_id[i]
  410. m = {"id": id}
  411. for fieldnm in fields:
  412. if fieldnm not in res:
  413. m[fieldnm] = None
  414. continue
  415. v = res[fieldnm][i]
  416. if isinstance(v, Series):
  417. v = list(v)
  418. elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd"]:
  419. assert isinstance(v, str)
  420. v = [kwd for kwd in v.split("###") if kwd]
  421. elif fieldnm == "position_int":
  422. assert isinstance(v, str)
  423. if v:
  424. arr = [int(hex_val, 16) for hex_val in v.split('_')]
  425. v = [arr[i:i + 4] for i in range(0, len(arr), 4)]
  426. else:
  427. v = []
  428. elif fieldnm in ["page_num_int", "top_int"]:
  429. assert isinstance(v, str)
  430. if v:
  431. v = [int(hex_val, 16) for hex_val in v.split('_')]
  432. else:
  433. v = []
  434. else:
  435. if not isinstance(v, str):
  436. v = str(v)
  437. # if fieldnm.endswith("_tks"):
  438. # v = rmSpace(v)
  439. m[fieldnm] = v
  440. res_fields[id] = m
  441. return res_fields
  442. def getHighlight(self, res, keywords: list[str], fieldnm: str):
  443. ans = {}
  444. num_rows = len(res)
  445. column_id = res["id"]
  446. for i in range(num_rows):
  447. id = column_id[i]
  448. txt = res[fieldnm][i]
  449. txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
  450. txts = []
  451. for t in re.split(r"[.?!;\n]", txt):
  452. for w in keywords:
  453. t = re.sub(
  454. r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
  455. % re.escape(w),
  456. r"\1<em>\2</em>\3",
  457. t,
  458. flags=re.IGNORECASE | re.MULTILINE,
  459. )
  460. if not re.search(
  461. r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
  462. ):
  463. continue
  464. txts.append(t)
  465. ans[id] = "...".join(txts)
  466. return ans
  467. def getAggregation(self, res, fieldnm: str):
  468. """
  469. TODO: Infinity doesn't provide aggregation
  470. """
  471. return list()
  472. """
  473. SQL
  474. """
  475. def sql(sql: str, fetch_size: int, format: str):
  476. raise NotImplementedError("Not implemented")