
infinity_conn.py 17KB

import os
import re
import json
import time
from typing import List, Dict

import infinity
from infinity.common import ConflictType, InfinityException
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
import polars as pl
from polars.series.series import Series

from api.utils.log_utils import logger
from api.utils.file_utils import get_project_base_directory
from rag import settings
from rag.utils import singleton
from rag.utils.doc_store_conn import (
    DocStoreConnection,
    MatchExpr,
    MatchTextExpr,
    MatchDenseExpr,
    FusionExpr,
    OrderByExpr,
)
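
# Build an Infinity filter string from a condition dict. For example (hypothetical
# values), {"kb_id": ["kb1", "kb2"], "available_int": 1} becomes
# "kb_id IN ('kb1', 'kb2') AND available_int=1". Keys with falsy values are skipped.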
def equivalent_condition_to_str(condition: dict) -> str:
    assert "_id" not in condition
    cond = list()
    for k, v in condition.items():
        if not isinstance(k, str) or not v:
            continue
        if isinstance(v, list):
            inCond = list()
            for item in v:
                if isinstance(item, str):
                    inCond.append(f"'{item}'")
                else:
                    inCond.append(str(item))
            if inCond:
                strInCond = ", ".join(inCond)
                strInCond = f"{k} IN ({strInCond})"
                cond.append(strInCond)
        elif isinstance(v, str):
            cond.append(f"{k}='{v}'")
        else:
            cond.append(f"{k}={str(v)}")
    return " AND ".join(cond)
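

# DocStoreConnection implementation backed by Infinity. Each (index, knowledge base)
# pair maps to one table named "<indexName>_<knowledgebaseId>"; the table schema comes
# from conf/infinity_mapping.json plus a "q_<size>_vec" embedding column added per table.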
@singleton
class InfinityConnection(DocStoreConnection):
    def __init__(self):
        self.dbName = settings.INFINITY.get("db_name", "default_db")
        infinity_uri = settings.INFINITY["uri"]
        if ":" in infinity_uri:
            host, port = infinity_uri.split(":")
            infinity_uri = infinity.common.NetworkAddress(host, int(port))
        self.connPool = None
        logger.info(f"Using Infinity {infinity_uri} as the doc engine.")
        for _ in range(24):
            try:
                connPool = ConnectionPool(infinity_uri)
                inf_conn = connPool.get_conn()
                _ = inf_conn.show_current_node()
                connPool.release_conn(inf_conn)
                self.connPool = connPool
                break
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting for Infinity {infinity_uri} to become healthy.")
                time.sleep(5)
        if self.connPool is None:
            msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
            logger.error(msg)
            raise Exception(msg)
        logger.info(f"Infinity {infinity_uri} is healthy.")
  72. """
  73. Database operations
  74. """
  75. def dbType(self) -> str:
  76. return "infinity"
  77. def health(self) -> dict:
  78. """
  79. Return the health status of the database.
  80. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  81. """
  82. inf_conn = self.connPool.get_conn()
  83. res = inf_conn.show_current_node()
  84. self.connPool.release_conn(inf_conn)
  85. color = "green" if res.error_code == 0 else "red"
  86. res2 = {
  87. "type": "infinity",
  88. "status": f"{res.role} {color}",
  89. "error": res.error_msg,
  90. }
  91. return res2
  92. """
  93. Table operations
  94. """
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)

        fp_mapping = os.path.join(
            get_project_base_directory(), "conf", "infinity_mapping.json"
        )
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        with open(fp_mapping) as f:
            schema = json.load(f)
        vector_name = f"q_{vectorSize}_vec"
        schema[vector_name] = {"type": f"vector,{vectorSize},float"}
        inf_table = inf_db.create_table(
            table_name,
            schema,
            ConflictType.Ignore,
        )
        inf_table.create_index(
            "q_vec_idx",
            IndexInfo(
                vector_name,
                IndexType.Hnsw,
                {
                    "M": "16",
                    "ef_construction": "50",
                    "metric": "cosine",
                    "encode": "lvq",
                },
            ),
            ConflictType.Ignore,
        )
        text_suffix = ["_tks", "_ltks", "_kwd"]
        for field_name, field_info in schema.items():
            if field_info["type"] != "varchar":
                continue
            for suffix in text_suffix:
                if field_name.endswith(suffix):
                    inf_table.create_index(
                        f"text_idx_{field_name}",
                        IndexInfo(
                            field_name, IndexType.FullText, {"ANALYZER": "standard"}
                        ),
                        ConflictType.Ignore,
                    )
                    break
        self.connPool.release_conn(inf_conn)
        logger.info(
            f"INFINITY created table {table_name}, vector size {vectorSize}"
        )

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        db_instance.drop_table(table_name, ConflictType.Ignore)
        self.connPool.release_conn(inf_conn)
        logger.info(f"INFINITY dropped table {table_name}")

    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            inf_conn = self.connPool.get_conn()
            db_instance = inf_conn.get_database(self.dbName)
            _ = db_instance.get_table(table_name)
            self.connPool.release_conn(inf_conn)
            return True
        except Exception as e:
            logger.warning(f"INFINITY indexExist {str(e)}")
        return False
  162. """
  163. CRUD operations
  164. """
    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
    ) -> list[dict] | pl.DataFrame:
        """
        TODO: Infinity doesn't provide highlight
        """
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        if "id" not in selectFields:
            selectFields.append("id")

        # Prepare expressions common to all tables
        filter_cond = ""
        filter_fulltext = ""
        if condition:
            filter_cond = equivalent_condition_to_str(condition)
        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                fields = ",".join(matchExpr.fields)
                filter_fulltext = (
                    f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                )
                if len(filter_cond) != 0:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                # logger.info(f"filter_fulltext: {filter_fulltext}")
                minimum_should_match = "0%"
                if "minimum_should_match" in matchExpr.extra_options:
                    minimum_should_match = (
                        str(int(matchExpr.extra_options["minimum_should_match"] * 100))
                        + "%"
                    )
                matchExpr.extra_options.update(
                    {"minimum_should_match": minimum_should_match}
                )
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
            elif isinstance(matchExpr, MatchDenseExpr):
                if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_fulltext})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
        if orderBy.fields:
            order_by_expr_list = list()
            for order_field in orderBy.fields:
                order_by_expr_list.append((order_field[0], order_field[1] == 0))

        # Scatter search tables and gather the results
        for indexName in indexNames:
            for knowledgebaseId in knowledgebaseIds:
                table_name = f"{indexName}_{knowledgebaseId}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    continue
                table_list.append(table_name)
                builder = table_instance.output(selectFields)
                for matchExpr in matchExprs:
                    if isinstance(matchExpr, MatchTextExpr):
                        fields = ",".join(matchExpr.fields)
                        builder = builder.match_text(
                            fields,
                            matchExpr.matching_text,
                            matchExpr.topn,
                            matchExpr.extra_options,
                        )
                    elif isinstance(matchExpr, MatchDenseExpr):
                        builder = builder.match_dense(
                            matchExpr.vector_column_name,
                            matchExpr.embedding_data,
                            matchExpr.embedding_data_type,
                            matchExpr.distance_type,
                            matchExpr.topn,
                            matchExpr.extra_options,
                        )
                    elif isinstance(matchExpr, FusionExpr):
                        builder = builder.fusion(
                            matchExpr.method, matchExpr.topn, matchExpr.fusion_params
                        )
                if orderBy.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                kb_res = builder.to_pl()
                df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = pl.concat(df_list)
        logger.info("INFINITY search tables: " + str(table_list))
        return res

    def get(
        self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
    ) -> dict | None:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        assert isinstance(knowledgebaseIds, list)
        for knowledgebaseId in knowledgebaseIds:
            table_name = f"{indexName}_{knowledgebaseId}"
            table_instance = db_instance.get_table(table_name)
            kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
            df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = pl.concat(df_list)
        res_fields = self.getFields(res, res.columns)
        return res_fields.get(chunkId, None)

    def insert(
        self, documents: list[dict], indexName: str, knowledgebaseId: str
    ) -> list[str]:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != 3022:
                raise
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)

        for d in documents:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                if k.endswith("_kwd") and isinstance(v, list):
                    d[k] = " ".join(v)
        ids = ["'{}'".format(d["id"]) for d in documents]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        # for doc in documents:
        #     logger.info(f"insert position_list: {doc['position_list']}")
        # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
        table_instance.insert(documents)
        self.connPool.release_conn(inf_conn)
        logger.info(f"inserted into {table_name} {str_ids}.")
        return []

    def update(
        self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
    ) -> bool:
        # if 'position_list' in newValue:
        #     logger.info(f"upsert position_list: {newValue['position_list']}")
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)
        filter = equivalent_condition_to_str(condition)
        for k, v in newValue.items():
            if k.endswith("_kwd") and isinstance(v, list):
                newValue[k] = " ".join(v)
        table_instance.update(filter, newValue)
        self.connPool.release_conn(inf_conn)
        return True

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        filter = equivalent_condition_to_str(condition)
        try:
            table_instance = db_instance.get_table(table_name)
        except Exception:
            logger.warning(
                f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
            )
            return 0
        res = table_instance.delete(filter)
        self.connPool.release_conn(inf_conn)
        return res.deleted_rows
  355. """
  356. Helper functions for search result
  357. """
    def getTotal(self, res):
        return len(res)

    def getChunkIds(self, res):
        return list(res["id"])

    def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
        res_fields = {}
        if not fields:
            return {}
        num_rows = len(res)
        column_id = res["id"]
        for i in range(num_rows):
            id = column_id[i]
            m = {"id": id}
            for fieldnm in fields:
                if fieldnm not in res:
                    m[fieldnm] = None
                    continue
                v = res[fieldnm][i]
                if isinstance(v, Series):
                    v = list(v)
                elif fieldnm == "important_kwd":
                    assert isinstance(v, str)
                    v = v.split(" ")
                else:
                    if not isinstance(v, str):
                        v = str(v)
                    # if fieldnm.endswith("_tks"):
                    #     v = rmSpace(v)
                m[fieldnm] = v
            res_fields[id] = m
        return res_fields

    def getHighlight(self, res, keywords: List[str], fieldnm: str):
        ans = {}
        num_rows = len(res)
        column_id = res["id"]
        for i in range(num_rows):
            id = column_id[i]
            txt = res[fieldnm][i]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
                        % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                if not re.search(
                    r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
                ):
                    continue
                txts.append(t)
            ans[id] = "...".join(txts)
        return ans

    def getAggregation(self, res, fieldnm: str):
        """
        TODO: Infinity doesn't provide aggregation
        """
        return list()

    """
    SQL
    """

    def sql(self, sql: str, fetch_size: int, format: str):
        raise NotImplementedError("Not implemented")
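

# A minimal usage sketch (illustrative only, not part of the original module). It assumes
# settings.INFINITY points at a reachable Infinity server and that
# conf/infinity_mapping.json declares the "content_ltks" field used below; the index and
# chunk names are made up for the example.
if __name__ == "__main__":
    conn = InfinityConnection()
    print(conn.health())
    conn.createIdx("ragflow", "kb_demo", 1024)
    conn.insert(
        [{"id": "chunk_0", "content_ltks": "hello infinity", "q_1024_vec": [0.0] * 1024}],
        "ragflow",
        "kb_demo",
    )
    print(conn.get("chunk_0", "ragflow", ["kb_demo"]))
    conn.delete({"id": "chunk_0"}, "ragflow", "kb_demo")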