# infinity_conn.py
  1. import logging
  2. import os
  3. import re
  4. import json
  5. import time
  6. import infinity
  7. from infinity.common import ConflictType, InfinityException
  8. from infinity.index import IndexInfo, IndexType
  9. from infinity.connection_pool import ConnectionPool
  10. from rag import settings
  11. from rag.utils import singleton
  12. import polars as pl
  13. from polars.series.series import Series
  14. from api.utils.file_utils import get_project_base_directory
  15. from rag.utils.doc_store_conn import (
  16. DocStoreConnection,
  17. MatchExpr,
  18. MatchTextExpr,
  19. MatchDenseExpr,
  20. FusionExpr,
  21. OrderByExpr,
  22. )
  23. def equivalent_condition_to_str(condition: dict) -> str:
  24. assert "_id" not in condition
  25. cond = list()
  26. for k, v in condition.items():
  27. if not isinstance(k, str) or not v:
  28. continue
  29. if isinstance(v, list):
  30. inCond = list()
  31. for item in v:
  32. if isinstance(item, str):
  33. inCond.append(f"'{item}'")
  34. else:
  35. inCond.append(str(item))
  36. if inCond:
  37. strInCond = ", ".join(inCond)
  38. strInCond = f"{k} IN ({strInCond})"
  39. cond.append(strInCond)
  40. elif isinstance(v, str):
  41. cond.append(f"{k}='{v}'")
  42. else:
  43. cond.append(f"{k}={str(v)}")
  44. return " AND ".join(cond)
  45. @singleton
  46. class InfinityConnection(DocStoreConnection):
  47. def __init__(self):
  48. self.dbName = settings.INFINITY.get("db_name", "default_db")
  49. infinity_uri = settings.INFINITY["uri"]
  50. if ":" in infinity_uri:
  51. host, port = infinity_uri.split(":")
  52. infinity_uri = infinity.common.NetworkAddress(host, int(port))
  53. self.connPool = None
  54. logging.info(f"Use Infinity {infinity_uri} as the doc engine.")
  55. for _ in range(24):
  56. try:
  57. connPool = ConnectionPool(infinity_uri)
  58. inf_conn = connPool.get_conn()
  59. _ = inf_conn.show_current_node()
  60. connPool.release_conn(inf_conn)
  61. self.connPool = connPool
  62. break
  63. except Exception as e:
  64. logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
  65. time.sleep(5)
  66. if self.connPool is None:
  67. msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
  68. logging.error(msg)
  69. raise Exception(msg)
  70. logging.info(f"Infinity {infinity_uri} is healthy.")
  71. """
  72. Database operations
  73. """
  74. def dbType(self) -> str:
  75. return "infinity"
  76. def health(self) -> dict:
  77. """
  78. Return the health status of the database.
  79. TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
  80. """
  81. inf_conn = self.connPool.get_conn()
  82. res = inf_conn.show_current_node()
  83. self.connPool.release_conn(inf_conn)
  84. color = "green" if res.error_code == 0 else "red"
  85. res2 = {
  86. "type": "infinity",
  87. "status": f"{res.role} {color}",
  88. "error": res.error_msg,
  89. }
  90. return res2
  91. """
  92. Table operations
  93. """
  94. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  95. table_name = f"{indexName}_{knowledgebaseId}"
  96. inf_conn = self.connPool.get_conn()
  97. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  98. fp_mapping = os.path.join(
  99. get_project_base_directory(), "conf", "infinity_mapping.json"
  100. )
  101. if not os.path.exists(fp_mapping):
  102. raise Exception(f"Mapping file not found at {fp_mapping}")
  103. schema = json.load(open(fp_mapping))
  104. vector_name = f"q_{vectorSize}_vec"
  105. schema[vector_name] = {"type": f"vector,{vectorSize},float"}
  106. inf_table = inf_db.create_table(
  107. table_name,
  108. schema,
  109. ConflictType.Ignore,
  110. )
  111. inf_table.create_index(
  112. "q_vec_idx",
  113. IndexInfo(
  114. vector_name,
  115. IndexType.Hnsw,
  116. {
  117. "M": "16",
  118. "ef_construction": "50",
  119. "metric": "cosine",
  120. "encode": "lvq",
  121. },
  122. ),
  123. ConflictType.Ignore,
  124. )
  125. text_suffix = ["_tks", "_ltks", "_kwd"]
  126. for field_name, field_info in schema.items():
  127. if field_info["type"] != "varchar":
  128. continue
  129. for suffix in text_suffix:
  130. if field_name.endswith(suffix):
  131. inf_table.create_index(
  132. f"text_idx_{field_name}",
  133. IndexInfo(
  134. field_name, IndexType.FullText, {"ANALYZER": "standard"}
  135. ),
  136. ConflictType.Ignore,
  137. )
  138. break
  139. self.connPool.release_conn(inf_conn)
  140. logging.info(
  141. f"INFINITY created table {table_name}, vector size {vectorSize}"
  142. )
  143. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  144. table_name = f"{indexName}_{knowledgebaseId}"
  145. inf_conn = self.connPool.get_conn()
  146. db_instance = inf_conn.get_database(self.dbName)
  147. db_instance.drop_table(table_name, ConflictType.Ignore)
  148. self.connPool.release_conn(inf_conn)
  149. logging.info(f"INFINITY dropped table {table_name}")
  150. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  151. table_name = f"{indexName}_{knowledgebaseId}"
  152. try:
  153. inf_conn = self.connPool.get_conn()
  154. db_instance = inf_conn.get_database(self.dbName)
  155. _ = db_instance.get_table(table_name)
  156. self.connPool.release_conn(inf_conn)
  157. return True
  158. except Exception as e:
  159. logging.warn(f"INFINITY indexExist {str(e)}")
  160. return False
  161. """
  162. CRUD operations
  163. """
    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
    ) -> list[dict] | pl.DataFrame:
        """
        Search every table `{indexName}_{knowledgebaseId}` with the given
        match expressions and concatenate the per-table results into one
        polars DataFrame.

        `highlightFields` is accepted but unused here.
        TODO: Infinity doesn't provide highlight
        """
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        # "id" is always selected so results can be keyed by chunk id later.
        if "id" not in selectFields:
            selectFields.append("id")
        # Prepare expressions common to all tables
        filter_cond = ""
        filter_fulltext = ""
        if condition:
            filter_cond = equivalent_condition_to_str(condition)
        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                fields = ",".join(matchExpr.fields)
                # filter_fulltext restricts a later dense match to rows that
                # also satisfy the full-text query (and the SQL condition).
                filter_fulltext = (
                    f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                )
                if len(filter_cond) != 0:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                logging.debug(f"filter_fulltext: {filter_fulltext}")
                minimum_should_match = "0%"
                if "minimum_should_match" in matchExpr.extra_options:
                    # Convert the 0..1 fraction into the "NN%" string form.
                    minimum_should_match = (
                        str(int(matchExpr.extra_options["minimum_should_match"] * 100))
                        + "%"
                    )
                matchExpr.extra_options.update(
                    {"minimum_should_match": minimum_should_match}
                )
                # Infinity expects every extra option value as a string.
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
            elif isinstance(matchExpr, MatchDenseExpr):
                if len(filter_cond) != 0 and "filter" not in matchExpr.extra_options:
                    # NOTE(review): this installs filter_fulltext (which is ""
                    # when no MatchTextExpr preceded), not filter_cond — looks
                    # deliberate (dense search constrained to full-text hits),
                    # but worth confirming against the caller.
                    matchExpr.extra_options.update({"filter": filter_fulltext})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
        if orderBy.fields:
            order_by_expr_list = list()
            for order_field in orderBy.fields:
                # order_field = (field_name, direction); 0 means ascending.
                order_by_expr_list.append((order_field[0], order_field[1] == 0))
        # Scatter search tables and gather the results
        for indexName in indexNames:
            for knowledgebaseId in knowledgebaseIds:
                table_name = f"{indexName}_{knowledgebaseId}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    # Skip knowledge bases whose table was never created.
                    continue
                table_list.append(table_name)
                builder = table_instance.output(selectFields)
                for matchExpr in matchExprs:
                    if isinstance(matchExpr, MatchTextExpr):
                        fields = ",".join(matchExpr.fields)
                        builder = builder.match_text(
                            fields,
                            matchExpr.matching_text,
                            matchExpr.topn,
                            matchExpr.extra_options,
                        )
                    elif isinstance(matchExpr, MatchDenseExpr):
                        builder = builder.match_dense(
                            matchExpr.vector_column_name,
                            matchExpr.embedding_data,
                            matchExpr.embedding_data_type,
                            matchExpr.distance_type,
                            matchExpr.topn,
                            matchExpr.extra_options,
                        )
                    elif isinstance(matchExpr, FusionExpr):
                        builder = builder.fusion(
                            matchExpr.method, matchExpr.topn, matchExpr.fusion_params
                        )
                if orderBy.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                kb_res = builder.to_pl()
                df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = pl.concat(df_list)
        logging.debug("INFINITY search tables: " + str(table_list))
        return res
  267. def get(
  268. self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
  269. ) -> dict | None:
  270. inf_conn = self.connPool.get_conn()
  271. db_instance = inf_conn.get_database(self.dbName)
  272. df_list = list()
  273. assert isinstance(knowledgebaseIds, list)
  274. for knowledgebaseId in knowledgebaseIds:
  275. table_name = f"{indexName}_{knowledgebaseId}"
  276. table_instance = db_instance.get_table(table_name)
  277. kb_res = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
  278. df_list.append(kb_res)
  279. self.connPool.release_conn(inf_conn)
  280. res = pl.concat(df_list)
  281. res_fields = self.getFields(res, res.columns)
  282. return res_fields.get(chunkId, None)
    def insert(
        self, documents: list[dict], indexName: str, knowledgebaseId: str
    ) -> list[str]:
        """
        Upsert documents into `{indexName}_{knowledgebaseId}`: rows with the
        same ids are deleted first, then all documents are inserted. Creates
        the table on demand, inferring the vector size from the first
        document's `q_<size>_vec` key.

        Returns an empty list (no per-document errors are reported).

        Raises:
            ValueError: when no `q_<size>_vec` key is present and the table
                must be created.
        """
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != 3022:
                raise
            # Table missing: infer the vector size from the first document
            # and create it, then fetch the freshly created table.
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                # Keyword lists are stored as one space-joined varchar.
                if k.endswith("_kwd") and isinstance(v, list):
                    d[k] = " ".join(v)
        # Delete any existing rows with the same ids (upsert semantics).
        ids = ["'{}'".format(d["id"]) for d in documents]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        # for doc in documents:
        #     logging.info(f"insert position_list: {doc['position_list']}")
        # logging.info(f"InfinityConnection.insert {json.dumps(documents)}")
        table_instance.insert(documents)
        self.connPool.release_conn(inf_conn)
        logging.debug(f"inserted into {table_name} {str_ids}.")
        return []
  323. def update(
  324. self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
  325. ) -> bool:
  326. # if 'position_list' in newValue:
  327. # logging.info(f"upsert position_list: {newValue['position_list']}")
  328. inf_conn = self.connPool.get_conn()
  329. db_instance = inf_conn.get_database(self.dbName)
  330. table_name = f"{indexName}_{knowledgebaseId}"
  331. table_instance = db_instance.get_table(table_name)
  332. filter = equivalent_condition_to_str(condition)
  333. for k, v in newValue.items():
  334. if k.endswith("_kwd") and isinstance(v, list):
  335. newValue[k] = " ".join(v)
  336. table_instance.update(filter, newValue)
  337. self.connPool.release_conn(inf_conn)
  338. return True
  339. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  340. inf_conn = self.connPool.get_conn()
  341. db_instance = inf_conn.get_database(self.dbName)
  342. table_name = f"{indexName}_{knowledgebaseId}"
  343. filter = equivalent_condition_to_str(condition)
  344. try:
  345. table_instance = db_instance.get_table(table_name)
  346. except Exception:
  347. logging.warning(
  348. f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
  349. )
  350. return 0
  351. res = table_instance.delete(filter)
  352. self.connPool.release_conn(inf_conn)
  353. return res.deleted_rows
  354. """
  355. Helper functions for search result
  356. """
  357. def getTotal(self, res):
  358. return len(res)
  359. def getChunkIds(self, res):
  360. return list(res["id"])
  361. def getFields(self, res, fields: list[str]) -> list[str, dict]:
  362. res_fields = {}
  363. if not fields:
  364. return {}
  365. num_rows = len(res)
  366. column_id = res["id"]
  367. for i in range(num_rows):
  368. id = column_id[i]
  369. m = {"id": id}
  370. for fieldnm in fields:
  371. if fieldnm not in res:
  372. m[fieldnm] = None
  373. continue
  374. v = res[fieldnm][i]
  375. if isinstance(v, Series):
  376. v = list(v)
  377. elif fieldnm == "important_kwd":
  378. assert isinstance(v, str)
  379. v = v.split(" ")
  380. else:
  381. if not isinstance(v, str):
  382. v = str(v)
  383. # if fieldnm.endswith("_tks"):
  384. # v = rmSpace(v)
  385. m[fieldnm] = v
  386. res_fields[id] = m
  387. return res_fields
  388. def getHighlight(self, res, keywords: list[str], fieldnm: str):
  389. ans = {}
  390. num_rows = len(res)
  391. column_id = res["id"]
  392. for i in range(num_rows):
  393. id = column_id[i]
  394. txt = res[fieldnm][i]
  395. txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
  396. txts = []
  397. for t in re.split(r"[.?!;\n]", txt):
  398. for w in keywords:
  399. t = re.sub(
  400. r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
  401. % re.escape(w),
  402. r"\1<em>\2</em>\3",
  403. t,
  404. flags=re.IGNORECASE | re.MULTILINE,
  405. )
  406. if not re.search(
  407. r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
  408. ):
  409. continue
  410. txts.append(t)
  411. ans[id] = "...".join(txts)
  412. return ans
  413. def getAggregation(self, res, fieldnm: str):
  414. """
  415. TODO: Infinity doesn't provide aggregation
  416. """
  417. return list()
  418. """
  419. SQL
  420. """
  421. def sql(sql: str, fetch_size: int, format: str):
  422. raise NotImplementedError("Not implemented")