You can't select more than 25 topics. Topics must start with a letter or number, can include hyphens (-) and must be no more than 35 characters long.

infinity_conn.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. import logging
  2. import os
  3. import re
  4. import json
  5. import time
  6. import infinity
  7. from infinity.common import ConflictType, InfinityException, SortType
  8. from infinity.index import IndexInfo, IndexType
  9. from infinity.connection_pool import ConnectionPool
  10. from rag import settings
  11. from rag.utils import singleton
  12. import polars as pl
  13. from polars.series.series import Series
  14. from api.utils.file_utils import get_project_base_directory
  15. from rag.utils.doc_store_conn import (
  16. DocStoreConnection,
  17. MatchExpr,
  18. MatchTextExpr,
  19. MatchDenseExpr,
  20. FusionExpr,
  21. OrderByExpr,
  22. )
  23. def equivalent_condition_to_str(condition: dict) -> str:
  24. assert "_id" not in condition
  25. cond = list()
  26. for k, v in condition.items():
  27. if not isinstance(k, str) or not v:
  28. continue
  29. if isinstance(v, list):
  30. inCond = list()
  31. for item in v:
  32. if isinstance(item, str):
  33. inCond.append(f"'{item}'")
  34. else:
  35. inCond.append(str(item))
  36. if inCond:
  37. strInCond = ", ".join(inCond)
  38. strInCond = f"{k} IN ({strInCond})"
  39. cond.append(strInCond)
  40. elif isinstance(v, str):
  41. cond.append(f"{k}='{v}'")
  42. else:
  43. cond.append(f"{k}={str(v)}")
  44. return " AND ".join(cond)
  45. @singleton
  46. class InfinityConnection(DocStoreConnection):
  47. def __init__(self):
  48. self.dbName = settings.INFINITY.get("db_name", "default_db")
  49. infinity_uri = settings.INFINITY["uri"]
  50. if ":" in infinity_uri:
  51. host, port = infinity_uri.split(":")
  52. infinity_uri = infinity.common.NetworkAddress(host, int(port))
  53. self.connPool = None
  54. logging.info(f"Use Infinity {infinity_uri} as the doc engine.")
  55. for _ in range(24):
  56. try:
  57. connPool = ConnectionPool(infinity_uri)
  58. inf_conn = connPool.get_conn()
  59. _ = inf_conn.show_current_node()
  60. connPool.release_conn(inf_conn)
  61. self.connPool = connPool
  62. break
  63. except Exception as e:
  64. logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
  65. time.sleep(5)
  66. if self.connPool is None:
  67. msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
  68. logging.error(msg)
  69. raise Exception(msg)
  70. logging.info(f"Infinity {infinity_uri} is healthy.")
  71. """
  72. Database operations
  73. """
  74. def dbType(self) -> str:
  75. return "infinity"
  76. def health(self) -> dict:
  77. """
  78. Return the health status of the database.
  79. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  80. """
  81. inf_conn = self.connPool.get_conn()
  82. res = inf_conn.show_current_node()
  83. self.connPool.release_conn(inf_conn)
  84. color = "green" if res.error_code == 0 else "red"
  85. res2 = {
  86. "type": "infinity",
  87. "status": f"{res.role} {color}",
  88. "error": res.error_msg,
  89. }
  90. return res2
  91. """
  92. Table operations
  93. """
  94. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  95. table_name = f"{indexName}_{knowledgebaseId}"
  96. inf_conn = self.connPool.get_conn()
  97. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  98. fp_mapping = os.path.join(
  99. get_project_base_directory(), "conf", "infinity_mapping.json"
  100. )
  101. if not os.path.exists(fp_mapping):
  102. raise Exception(f"Mapping file not found at {fp_mapping}")
  103. schema = json.load(open(fp_mapping))
  104. vector_name = f"q_{vectorSize}_vec"
  105. schema[vector_name] = {"type": f"vector,{vectorSize},float"}
  106. inf_table = inf_db.create_table(
  107. table_name,
  108. schema,
  109. ConflictType.Ignore,
  110. )
  111. inf_table.create_index(
  112. "q_vec_idx",
  113. IndexInfo(
  114. vector_name,
  115. IndexType.Hnsw,
  116. {
  117. "M": "16",
  118. "ef_construction": "50",
  119. "metric": "cosine",
  120. "encode": "lvq",
  121. },
  122. ),
  123. ConflictType.Ignore,
  124. )
  125. text_suffix = ["_tks", "_ltks", "_kwd"]
  126. for field_name, field_info in schema.items():
  127. if field_info["type"] != "varchar":
  128. continue
  129. for suffix in text_suffix:
  130. if field_name.endswith(suffix):
  131. inf_table.create_index(
  132. f"text_idx_{field_name}",
  133. IndexInfo(
  134. field_name, IndexType.FullText, {"ANALYZER": "standard"}
  135. ),
  136. ConflictType.Ignore,
  137. )
  138. break
  139. self.connPool.release_conn(inf_conn)
  140. logging.info(
  141. f"INFINITY created table {table_name}, vector size {vectorSize}"
  142. )
  143. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  144. table_name = f"{indexName}_{knowledgebaseId}"
  145. inf_conn = self.connPool.get_conn()
  146. db_instance = inf_conn.get_database(self.dbName)
  147. db_instance.drop_table(table_name, ConflictType.Ignore)
  148. self.connPool.release_conn(inf_conn)
  149. logging.info(f"INFINITY dropped table {table_name}")
  150. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  151. table_name = f"{indexName}_{knowledgebaseId}"
  152. try:
  153. inf_conn = self.connPool.get_conn()
  154. db_instance = inf_conn.get_database(self.dbName)
  155. _ = db_instance.get_table(table_name)
  156. self.connPool.release_conn(inf_conn)
  157. return True
  158. except Exception as e:
  159. logging.warning(f"INFINITY indexExist {str(e)}")
  160. return False
  161. """
  162. CRUD operations
  163. """
  164. def search(
  165. self,
  166. selectFields: list[str],
  167. highlightFields: list[str],
  168. condition: dict,
  169. matchExprs: list[MatchExpr],
  170. orderBy: OrderByExpr,
  171. offset: int,
  172. limit: int,
  173. indexNames: str | list[str],
  174. knowledgebaseIds: list[str],
  175. ) -> list[dict] | pl.DataFrame:
  176. """
  177. TODO: Infinity doesn't provide highlight
  178. """
  179. if isinstance(indexNames, str):
  180. indexNames = indexNames.split(",")
  181. assert isinstance(indexNames, list) and len(indexNames) > 0
  182. inf_conn = self.connPool.get_conn()
  183. db_instance = inf_conn.get_database(self.dbName)
  184. df_list = list()
  185. table_list = list()
  186. if "id" not in selectFields:
  187. selectFields.append("id")
  188. # Prepare expressions common to all tables
  189. filter_cond = ""
  190. filter_fulltext = ""
  191. if condition:
  192. filter_cond = equivalent_condition_to_str(condition)
  193. for matchExpr in matchExprs:
  194. if isinstance(matchExpr, MatchTextExpr):
  195. if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
  196. matchExpr.extra_options.update({"filter": filter_cond})
  197. fields = ",".join(matchExpr.fields)
  198. filter_fulltext = (
  199. f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
  200. )
  201. if len(filter_cond) != 0:
  202. filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
  203. logging.debug(f"filter_fulltext: {filter_fulltext}")
  204. minimum_should_match = "0%"
  205. if "minimum_should_match" in matchExpr.extra_options:
  206. minimum_should_match = (
  207. str(int(matchExpr.extra_options["minimum_should_match"] * 100))
  208. + "%"
  209. )
  210. matchExpr.extra_options.update(
  211. {"minimum_should_match": minimum_should_match}
  212. )
  213. for k, v in matchExpr.extra_options.items():
  214. if not isinstance(v, str):
  215. matchExpr.extra_options[k] = str(v)
  216. elif isinstance(matchExpr, MatchDenseExpr):
  217. if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
  218. matchExpr.extra_options.update({"filter": filter_fulltext})
  219. for k, v in matchExpr.extra_options.items():
  220. if not isinstance(v, str):
  221. matchExpr.extra_options[k] = str(v)
  222. order_by_expr_list = list()
  223. if orderBy.fields:
  224. for order_field in orderBy.fields:
  225. if order_field[1] == 0:
  226. order_by_expr_list.append((order_field[0], SortType.Asc))
  227. else:
  228. order_by_expr_list.append((order_field[0], SortType.Desc))
  229. # Scatter search tables and gather the results
  230. for indexName in indexNames:
  231. for knowledgebaseId in knowledgebaseIds:
  232. table_name = f"{indexName}_{knowledgebaseId}"
  233. try:
  234. table_instance = db_instance.get_table(table_name)
  235. except Exception:
  236. continue
  237. table_list.append(table_name)
  238. builder = table_instance.output(selectFields)
  239. if len(matchExprs) > 0:
  240. for matchExpr in matchExprs:
  241. if isinstance(matchExpr, MatchTextExpr):
  242. fields = ",".join(matchExpr.fields)
  243. builder = builder.match_text(
  244. fields,
  245. matchExpr.matching_text,
  246. matchExpr.topn,
  247. matchExpr.extra_options,
  248. )
  249. elif isinstance(matchExpr, MatchDenseExpr):
  250. builder = builder.match_dense(
  251. matchExpr.vector_column_name,
  252. matchExpr.embedding_data,
  253. matchExpr.embedding_data_type,
  254. matchExpr.distance_type,
  255. matchExpr.topn,
  256. matchExpr.extra_options,
  257. )
  258. elif isinstance(matchExpr, FusionExpr):
  259. builder = builder.fusion(
  260. matchExpr.method, matchExpr.topn, matchExpr.fusion_params
  261. )
  262. else:
  263. if len(filter_cond) > 0:
  264. builder.filter(filter_cond)
  265. if orderBy.fields:
  266. builder.sort(order_by_expr_list)
  267. builder.offset(offset).limit(limit)
  268. kb_res = builder.to_pl()
  269. df_list.append(kb_res)
  270. self.connPool.release_conn(inf_conn)
  271. res = pl.concat(df_list)
  272. logging.debug("INFINITY search tables: " + str(table_list))
  273. return res
  274. def get(
  275. self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
  276. ) -> dict | None:
  277. inf_conn = self.connPool.get_conn()
  278. db_instance = inf_conn.get_database(self.dbName)
  279. df_list = list()
  280. assert isinstance(knowledgebaseIds, list)
  281. for knowledgebaseId in knowledgebaseIds:
  282. table_name = f"{indexName}_{knowledgebaseId}"
  283. table_instance = db_instance.get_table(table_name)
  284. kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
  285. df_list.append(kb_res)
  286. self.connPool.release_conn(inf_conn)
  287. res = pl.concat(df_list)
  288. res_fields = self.getFields(res, res.columns)
  289. return res_fields.get(chunkId, None)
  290. def insert(
  291. self, documents: list[dict], indexName: str, knowledgebaseId: str
  292. ) -> list[str]:
  293. inf_conn = self.connPool.get_conn()
  294. db_instance = inf_conn.get_database(self.dbName)
  295. table_name = f"{indexName}_{knowledgebaseId}"
  296. try:
  297. table_instance = db_instance.get_table(table_name)
  298. except InfinityException as e:
  299. # src/common/status.cppm, kTableNotExist = 3022
  300. if e.error_code != 3022:
  301. raise
  302. vector_size = 0
  303. patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
  304. for k in documents[0].keys():
  305. m = patt.match(k)
  306. if m:
  307. vector_size = int(m.group("vector_size"))
  308. break
  309. if vector_size == 0:
  310. raise ValueError("Cannot infer vector size from documents")
  311. self.createIdx(indexName, knowledgebaseId, vector_size)
  312. table_instance = db_instance.get_table(table_name)
  313. for d in documents:
  314. assert "_id" not in d
  315. assert "id" in d
  316. for k, v in d.items():
  317. if k.endswith("_kwd") and isinstance(v, list):
  318. d[k] = " ".join(v)
  319. ids = ["'{}'".format(d["id"]) for d in documents]
  320. str_ids = ", ".join(ids)
  321. str_filter = f"id IN ({str_ids})"
  322. table_instance.delete(str_filter)
  323. # for doc in documents:
  324. # logging.info(f"insert position_list: {doc['position_list']}")
  325. # logging.info(f"InfinityConnection.insert {json.dumps(documents)}")
  326. table_instance.insert(documents)
  327. self.connPool.release_conn(inf_conn)
  328. logging.debug(f"inserted into {table_name} {str_ids}.")
  329. return []
  330. def update(
  331. self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
  332. ) -> bool:
  333. # if 'position_list' in newValue:
  334. # logging.info(f"upsert position_list: {newValue['position_list']}")
  335. inf_conn = self.connPool.get_conn()
  336. db_instance = inf_conn.get_database(self.dbName)
  337. table_name = f"{indexName}_{knowledgebaseId}"
  338. table_instance = db_instance.get_table(table_name)
  339. filter = equivalent_condition_to_str(condition)
  340. for k, v in newValue.items():
  341. if k.endswith("_kwd") and isinstance(v, list):
  342. newValue[k] = " ".join(v)
  343. table_instance.update(filter, newValue)
  344. self.connPool.release_conn(inf_conn)
  345. return True
  346. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  347. inf_conn = self.connPool.get_conn()
  348. db_instance = inf_conn.get_database(self.dbName)
  349. table_name = f"{indexName}_{knowledgebaseId}"
  350. filter = equivalent_condition_to_str(condition)
  351. try:
  352. table_instance = db_instance.get_table(table_name)
  353. except Exception:
  354. logging.warning(
  355. f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
  356. )
  357. return 0
  358. res = table_instance.delete(filter)
  359. self.connPool.release_conn(inf_conn)
  360. return res.deleted_rows
  361. """
  362. Helper functions for search result
  363. """
  364. def getTotal(self, res):
  365. return len(res)
  366. def getChunkIds(self, res):
  367. return list(res["id"])
  368. def getFields(self, res, fields: list[str]) -> list[str, dict]:
  369. res_fields = {}
  370. if not fields:
  371. return {}
  372. num_rows = len(res)
  373. column_id = res["id"]
  374. for i in range(num_rows):
  375. id = column_id[i]
  376. m = {"id": id}
  377. for fieldnm in fields:
  378. if fieldnm not in res:
  379. m[fieldnm] = None
  380. continue
  381. v = res[fieldnm][i]
  382. if isinstance(v, Series):
  383. v = list(v)
  384. elif fieldnm == "important_kwd":
  385. assert isinstance(v, str)
  386. v = v.split(" ")
  387. else:
  388. if not isinstance(v, str):
  389. v = str(v)
  390. # if fieldnm.endswith("_tks"):
  391. # v = rmSpace(v)
  392. m[fieldnm] = v
  393. res_fields[id] = m
  394. return res_fields
  395. def getHighlight(self, res, keywords: list[str], fieldnm: str):
  396. ans = {}
  397. num_rows = len(res)
  398. column_id = res["id"]
  399. for i in range(num_rows):
  400. id = column_id[i]
  401. txt = res[fieldnm][i]
  402. txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
  403. txts = []
  404. for t in re.split(r"[.?!;\n]", txt):
  405. for w in keywords:
  406. t = re.sub(
  407. r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
  408. % re.escape(w),
  409. r"\1<em>\2</em>\3",
  410. t,
  411. flags=re.IGNORECASE | re.MULTILINE,
  412. )
  413. if not re.search(
  414. r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
  415. ):
  416. continue
  417. txts.append(t)
  418. ans[id] = "...".join(txts)
  419. return ans
  420. def getAggregation(self, res, fieldnm: str):
  421. """
  422. TODO: Infinity doesn't provide aggregation
  423. """
  424. return list()
  425. """
  426. SQL
  427. """
  428. def sql(sql: str, fetch_size: int, format: str):
  429. raise NotImplementedError("Not implemented")