Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

infinity_conn.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. import os
  2. import re
  3. import json
  4. from typing import List, Dict
  5. import infinity
  6. from infinity.common import ConflictType, InfinityException
  7. from infinity.index import IndexInfo, IndexType
  8. from infinity.connection_pool import ConnectionPool
  9. from rag import settings
  10. from api.utils.log_utils import logger
  11. from rag.utils import singleton
  12. import polars as pl
  13. from polars.series.series import Series
  14. from api.utils.file_utils import get_project_base_directory
  15. from rag.utils.doc_store_conn import (
  16. DocStoreConnection,
  17. MatchExpr,
  18. MatchTextExpr,
  19. MatchDenseExpr,
  20. FusionExpr,
  21. OrderByExpr,
  22. )
  23. def equivalent_condition_to_str(condition: dict) -> str:
  24. assert "_id" not in condition
  25. cond = list()
  26. for k, v in condition.items():
  27. if not isinstance(k, str) or not v:
  28. continue
  29. if isinstance(v, list):
  30. inCond = list()
  31. for item in v:
  32. if isinstance(item, str):
  33. inCond.append(f"'{item}'")
  34. else:
  35. inCond.append(str(item))
  36. if inCond:
  37. strInCond = ", ".join(inCond)
  38. strInCond = f"{k} IN ({strInCond})"
  39. cond.append(strInCond)
  40. elif isinstance(v, str):
  41. cond.append(f"{k}='{v}'")
  42. else:
  43. cond.append(f"{k}={str(v)}")
  44. return " AND ".join(cond)
  45. @singleton
  46. class InfinityConnection(DocStoreConnection):
  47. def __init__(self):
  48. self.dbName = settings.INFINITY.get("db_name", "default_db")
  49. infinity_uri = settings.INFINITY["uri"]
  50. if ":" in infinity_uri:
  51. host, port = infinity_uri.split(":")
  52. infinity_uri = infinity.common.NetworkAddress(host, int(port))
  53. self.connPool = ConnectionPool(infinity_uri)
  54. logger.info(f"Connected to infinity {infinity_uri}.")
  55. """
  56. Database operations
  57. """
  58. def dbType(self) -> str:
  59. return "infinity"
  60. def health(self) -> dict:
  61. """
  62. Return the health status of the database.
  63. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  64. """
  65. inf_conn = self.connPool.get_conn()
  66. res = inf_conn.show_current_node()
  67. self.connPool.release_conn(inf_conn)
  68. color = "green" if res.error_code == 0 else "red"
  69. res2 = {
  70. "type": "infinity",
  71. "status": f"{res.role} {color}",
  72. "error": res.error_msg,
  73. }
  74. return res2
  75. """
  76. Table operations
  77. """
  78. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  79. table_name = f"{indexName}_{knowledgebaseId}"
  80. inf_conn = self.connPool.get_conn()
  81. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  82. fp_mapping = os.path.join(
  83. get_project_base_directory(), "conf", "infinity_mapping.json"
  84. )
  85. if not os.path.exists(fp_mapping):
  86. raise Exception(f"Mapping file not found at {fp_mapping}")
  87. schema = json.load(open(fp_mapping))
  88. vector_name = f"q_{vectorSize}_vec"
  89. schema[vector_name] = {"type": f"vector,{vectorSize},float"}
  90. inf_table = inf_db.create_table(
  91. table_name,
  92. schema,
  93. ConflictType.Ignore,
  94. )
  95. inf_table.create_index(
  96. "q_vec_idx",
  97. IndexInfo(
  98. vector_name,
  99. IndexType.Hnsw,
  100. {
  101. "M": "16",
  102. "ef_construction": "50",
  103. "metric": "cosine",
  104. "encode": "lvq",
  105. },
  106. ),
  107. ConflictType.Ignore,
  108. )
  109. text_suffix = ["_tks", "_ltks", "_kwd"]
  110. for field_name, field_info in schema.items():
  111. if field_info["type"] != "varchar":
  112. continue
  113. for suffix in text_suffix:
  114. if field_name.endswith(suffix):
  115. inf_table.create_index(
  116. f"text_idx_{field_name}",
  117. IndexInfo(
  118. field_name, IndexType.FullText, {"ANALYZER": "standard"}
  119. ),
  120. ConflictType.Ignore,
  121. )
  122. break
  123. self.connPool.release_conn(inf_conn)
  124. logger.info(
  125. f"INFINITY created table {table_name}, vector size {vectorSize}"
  126. )
  127. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  128. table_name = f"{indexName}_{knowledgebaseId}"
  129. inf_conn = self.connPool.get_conn()
  130. db_instance = inf_conn.get_database(self.dbName)
  131. db_instance.drop_table(table_name, ConflictType.Ignore)
  132. self.connPool.release_conn(inf_conn)
  133. logger.info(f"INFINITY dropped table {table_name}")
  134. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  135. table_name = f"{indexName}_{knowledgebaseId}"
  136. try:
  137. inf_conn = self.connPool.get_conn()
  138. db_instance = inf_conn.get_database(self.dbName)
  139. _ = db_instance.get_table(table_name)
  140. self.connPool.release_conn(inf_conn)
  141. return True
  142. except Exception:
  143. logger.exception("INFINITY indexExist")
  144. return False
    """
    CRUD operations
    """

    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str|list[str],
        knowledgebaseIds: list[str],
    ) -> list[dict] | pl.DataFrame:
        """
        Run the same query against every `{indexName}_{knowledgebaseId}` table
        and return the concatenated polars DataFrame of all per-table results.

        `highlightFields` is currently ignored.
        TODO: Infinity doesn't provide highlight
        """
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        # "id" is always selected so callers can key results by chunk id.
        if "id" not in selectFields:
            selectFields.append("id")
        # Prepare expressions common to all tables
        filter_cond = ""
        filter_fulltext = ""
        if condition:
            filter_cond = equivalent_condition_to_str(condition)
        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                # Attach the structured filter unless the caller supplied one.
                if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                fields = ",".join(matchExpr.fields)
                filter_fulltext = (
                    f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                )
                if len(filter_cond) != 0:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                # doc_store_logger.info(f"filter_fulltext: {filter_fulltext}")
                # Convert the 0..1 ratio into the "NN%" form Infinity expects.
                minimum_should_match = "0%"
                if "minimum_should_match" in matchExpr.extra_options:
                    minimum_should_match = (
                        str(int(matchExpr.extra_options["minimum_should_match"] * 100))
                        + "%"
                    )
                matchExpr.extra_options.update(
                    {"minimum_should_match": minimum_should_match}
                )
                # The SDK expects all option values as strings.
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
            elif isinstance(matchExpr, MatchDenseExpr):
                # NOTE(review): the condition checks filter_cond but inserts
                # filter_fulltext (the combined cond+fulltext filter built
                # above). Looks intentional — dense search constrained by the
                # full-text hit set — but confirm; if no MatchTextExpr precedes
                # this, filter_fulltext is "" here.
                if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_fulltext})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
        if orderBy.fields:
            order_by_expr_list = list()
            for order_field in orderBy.fields:
                # Second tuple element: True = ascending (orderBy uses 0 for asc).
                order_by_expr_list.append((order_field[0], order_field[1] == 0))
        # Scatter search tables and gather the results
        for indexName in indexNames:
            for knowledgebaseId in knowledgebaseIds:
                table_name = f"{indexName}_{knowledgebaseId}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    # Table may not exist for this kb yet; skip it.
                    continue
                table_list.append(table_name)
                builder = table_instance.output(selectFields)
                for matchExpr in matchExprs:
                    if isinstance(matchExpr, MatchTextExpr):
                        fields = ",".join(matchExpr.fields)
                        builder = builder.match_text(
                            fields,
                            matchExpr.matching_text,
                            matchExpr.topn,
                            matchExpr.extra_options,
                        )
                    elif isinstance(matchExpr, MatchDenseExpr):
                        builder = builder.match_dense(
                            matchExpr.vector_column_name,
                            matchExpr.embedding_data,
                            matchExpr.embedding_data_type,
                            matchExpr.distance_type,
                            matchExpr.topn,
                            matchExpr.extra_options,
                        )
                    elif isinstance(matchExpr, FusionExpr):
                        builder = builder.fusion(
                            matchExpr.method, matchExpr.topn, matchExpr.fusion_params
                        )
                if orderBy.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                kb_res = builder.to_pl()
                df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        # NOTE(review): pl.concat raises on an empty df_list (no table matched)
        # — presumably callers guarantee at least one table exists; confirm.
        res = pl.concat(df_list)
        logger.info("INFINITY search tables: " + str(table_list))
        return res
  251. def get(
  252. self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
  253. ) -> dict | None:
  254. inf_conn = self.connPool.get_conn()
  255. db_instance = inf_conn.get_database(self.dbName)
  256. df_list = list()
  257. assert isinstance(knowledgebaseIds, list)
  258. for knowledgebaseId in knowledgebaseIds:
  259. table_name = f"{indexName}_{knowledgebaseId}"
  260. table_instance = db_instance.get_table(table_name)
  261. kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
  262. df_list.append(kb_res)
  263. self.connPool.release_conn(inf_conn)
  264. res = pl.concat(df_list)
  265. res_fields = self.getFields(res, res.columns)
  266. return res_fields.get(chunkId, None)
  267. def insert(
  268. self, documents: list[dict], indexName: str, knowledgebaseId: str
  269. ) -> list[str]:
  270. inf_conn = self.connPool.get_conn()
  271. db_instance = inf_conn.get_database(self.dbName)
  272. table_name = f"{indexName}_{knowledgebaseId}"
  273. try:
  274. table_instance = db_instance.get_table(table_name)
  275. except InfinityException as e:
  276. # src/common/status.cppm, kTableNotExist = 3022
  277. if e.error_code != 3022:
  278. raise
  279. vector_size = 0
  280. patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
  281. for k in documents[0].keys():
  282. m = patt.match(k)
  283. if m:
  284. vector_size = int(m.group("vector_size"))
  285. break
  286. if vector_size == 0:
  287. raise ValueError("Cannot infer vector size from documents")
  288. self.createIdx(indexName, knowledgebaseId, vector_size)
  289. table_instance = db_instance.get_table(table_name)
  290. for d in documents:
  291. assert "_id" not in d
  292. assert "id" in d
  293. for k, v in d.items():
  294. if k.endswith("_kwd") and isinstance(v, list):
  295. d[k] = " ".join(v)
  296. ids = [f"'{d["id"]}'" for d in documents]
  297. str_ids = ", ".join(ids)
  298. str_filter = f"id IN ({str_ids})"
  299. table_instance.delete(str_filter)
  300. # for doc in documents:
  301. # logger.info(f"insert position_list: {doc['position_list']}")
  302. # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
  303. table_instance.insert(documents)
  304. self.connPool.release_conn(inf_conn)
  305. doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
  306. return []
  307. def update(
  308. self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
  309. ) -> bool:
  310. # if 'position_list' in newValue:
  311. # logger.info(f"upsert position_list: {newValue['position_list']}")
  312. inf_conn = self.connPool.get_conn()
  313. db_instance = inf_conn.get_database(self.dbName)
  314. table_name = f"{indexName}_{knowledgebaseId}"
  315. table_instance = db_instance.get_table(table_name)
  316. filter = equivalent_condition_to_str(condition)
  317. for k, v in newValue.items():
  318. if k.endswith("_kwd") and isinstance(v, list):
  319. newValue[k] = " ".join(v)
  320. table_instance.update(filter, newValue)
  321. self.connPool.release_conn(inf_conn)
  322. return True
  323. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  324. inf_conn = self.connPool.get_conn()
  325. db_instance = inf_conn.get_database(self.dbName)
  326. table_name = f"{indexName}_{knowledgebaseId}"
  327. filter = equivalent_condition_to_str(condition)
  328. try:
  329. table_instance = db_instance.get_table(table_name)
  330. except Exception:
  331. logger.warning(
  332. f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
  333. )
  334. return 0
  335. res = table_instance.delete(filter)
  336. self.connPool.release_conn(inf_conn)
  337. return res.deleted_rows
  338. """
  339. Helper functions for search result
  340. """
  341. def getTotal(self, res):
  342. return len(res)
  343. def getChunkIds(self, res):
  344. return list(res["id"])
  345. def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
  346. res_fields = {}
  347. if not fields:
  348. return {}
  349. num_rows = len(res)
  350. column_id = res["id"]
  351. for i in range(num_rows):
  352. id = column_id[i]
  353. m = {"id": id}
  354. for fieldnm in fields:
  355. if fieldnm not in res:
  356. m[fieldnm] = None
  357. continue
  358. v = res[fieldnm][i]
  359. if isinstance(v, Series):
  360. v = list(v)
  361. elif fieldnm == "important_kwd":
  362. assert isinstance(v, str)
  363. v = v.split(" ")
  364. else:
  365. if not isinstance(v, str):
  366. v = str(v)
  367. # if fieldnm.endswith("_tks"):
  368. # v = rmSpace(v)
  369. m[fieldnm] = v
  370. res_fields[id] = m
  371. return res_fields
  372. def getHighlight(self, res, keywords: List[str], fieldnm: str):
  373. ans = {}
  374. num_rows = len(res)
  375. column_id = res["id"]
  376. for i in range(num_rows):
  377. id = column_id[i]
  378. txt = res[fieldnm][i]
  379. txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
  380. txts = []
  381. for t in re.split(r"[.?!;\n]", txt):
  382. for w in keywords:
  383. t = re.sub(
  384. r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
  385. % re.escape(w),
  386. r"\1<em>\2</em>\3",
  387. t,
  388. flags=re.IGNORECASE | re.MULTILINE,
  389. )
  390. if not re.search(
  391. r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
  392. ):
  393. continue
  394. txts.append(t)
  395. ans[id] = "...".join(txts)
  396. return ans
  397. def getAggregation(self, res, fieldnm: str):
  398. """
  399. TODO: Infinity doesn't provide aggregation
  400. """
  401. return list()
  402. """
  403. SQL
  404. """
  405. def sql(sql: str, fetch_size: int, format: str):
  406. raise NotImplementedError("Not implemented")