
# infinity_conn.py

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
import json
import time
import copy
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode
from rag import settings
from rag.settings import PAGERANK_FLD, TAG_FLD
from rag.utils import singleton
import pandas as pd
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import (
    DocStoreConnection,
    MatchExpr,
    MatchTextExpr,
    MatchDenseExpr,
    FusionExpr,
    OrderByExpr,
)

logger = logging.getLogger('ragflow.infinity_conn')

def field_keyword(field_name: str):
    # The "docnm_kwd" field is always a string, never a list.
    if field_name == "source_id" or (field_name.endswith("_kwd") and field_name != "docnm_kwd" and field_name != "knowledge_graph_kwd"):
        return True
    return False
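
# Illustrative behaviour of field_keyword() under the rules above (these calls are
# examples, not part of the module):
#   field_keyword("important_kwd")        -> True   (stored as a "###"-joined string)
#   field_keyword("source_id")            -> True
#   field_keyword("docnm_kwd")            -> False  (always a plain string)
#   field_keyword("knowledge_graph_kwd")  -> False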

def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
    assert "_id" not in condition
    clmns = {}
    if table_instance:
        for n, ty, de, _ in table_instance.show_columns().rows():
            clmns[n] = (ty, de)

    def exists(cln):
        nonlocal clmns
        assert cln in clmns, f"'{cln}' should be in '{clmns}'."
        ty, de = clmns[cln]
        # char-typed columns compare against their (quoted string) default
        if "cha" in ty.lower():
            if not de:
                de = ""
            return f" {cln}!='{de}' "
        return f"{cln}!={de}"

    cond = list()
    for k, v in condition.items():
        if not isinstance(k, str) or k in ["kb_id"] or not v:
            continue
        if field_keyword(k):
            if isinstance(v, list):
                inCond = list()
                for item in v:
                    if isinstance(item, str):
                        item = item.replace("'", "''")
                    inCond.append(f"filter_fulltext('{k}', '{item}')")
                if inCond:
                    strInCond = " or ".join(inCond)
                    strInCond = f"({strInCond})"
                    cond.append(strInCond)
            else:
                cond.append(f"filter_fulltext('{k}', '{v}')")
        elif isinstance(v, list):
            inCond = list()
            for item in v:
                if isinstance(item, str):
                    item = item.replace("'", "''")
                    inCond.append(f"'{item}'")
                else:
                    inCond.append(str(item))
            if inCond:
                strInCond = ", ".join(inCond)
                strInCond = f"{k} IN ({strInCond})"
                cond.append(strInCond)
        elif k == "must_not":
            if isinstance(v, dict):
                for kk, vv in v.items():
                    if kk == "exists":
                        cond.append("NOT (%s)" % exists(vv))
        elif isinstance(v, str):
            cond.append(f"{k}='{v}'")
        elif k == "exists":
            cond.append(exists(v))
        else:
            cond.append(f"{k}={str(v)}")
    return " AND ".join(cond) if cond else "1=1"
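
# Hedged example of the translation above, with made-up field values shaped like
# the condition dicts RAGFlow passes in:
#   equivalent_condition_to_str({"doc_id": ["d1", "d2"], "available_int": 1})
#       -> "doc_id IN ('d1', 'd2') AND available_int=1"
# An empty or all-skipped condition yields the always-true filter "1=1".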

def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
    df_list2 = [df for df in df_list if not df.empty]
    if df_list2:
        return pd.concat(df_list2, axis=0).reset_index(drop=True)

    schema = []
    for field_name in selectFields:
        if field_name == 'score()':  # Workaround: the result schema renames this column to SCORE
            schema.append('SCORE')
        elif field_name == 'similarity()':  # Workaround: the result schema renames this column to SIMILARITY
            schema.append('SIMILARITY')
        else:
            schema.append(field_name)
    return pd.DataFrame(columns=schema)
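
# Note on the fallback above: when every per-table DataFrame is empty, an empty
# frame is still returned with the expected column names ('score()' and
# 'similarity()' map to the 'SCORE' and 'SIMILARITY' columns Infinity reports),
# so downstream column selection keeps working.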

@singleton
class InfinityConnection(DocStoreConnection):
    def __init__(self):
        self.dbName = settings.INFINITY.get("db_name", "default_db")
        infinity_uri = settings.INFINITY["uri"]
        if ":" in infinity_uri:
            host, port = infinity_uri.split(":")
            infinity_uri = infinity.common.NetworkAddress(host, int(port))
        self.connPool = None
        logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
        # Retry for up to 24 * 5 = 120 seconds while the Infinity node comes up.
        for _ in range(24):
            try:
                connPool = ConnectionPool(infinity_uri, max_size=32)
                inf_conn = connPool.get_conn()
                res = inf_conn.show_current_node()
                if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
                    self._migrate_db(inf_conn)
                    self.connPool = connPool
                    connPool.release_conn(inf_conn)
                    break
                connPool.release_conn(inf_conn)
                logger.warning(f"Infinity status: {res.server_status}. Waiting for Infinity {infinity_uri} to become healthy.")
                time.sleep(5)
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting for Infinity {infinity_uri} to become healthy.")
                time.sleep(5)
        if self.connPool is None:
            msg = f"Infinity {infinity_uri} didn't become healthy within 120s."
            logger.error(msg)
            raise Exception(msg)
        logger.info(f"Infinity {infinity_uri} is healthy.")

    def _migrate_db(self, inf_conn):
        inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
        fp_mapping = os.path.join(
            get_project_base_directory(), "conf", "infinity_mapping.json"
        )
        if not os.path.exists(fp_mapping):
            raise Exception(f"Mapping file not found at {fp_mapping}")
        schema = json.load(open(fp_mapping))
        table_names = inf_db.list_tables().table_names
        for table_name in table_names:
            inf_table = inf_db.get_table(table_name)
            index_names = inf_table.list_indexes().index_names
            if "q_vec_idx" not in index_names:
                # Skip tables that weren't created by this connector
                continue
            column_names = inf_table.show_columns()["name"]
            column_names = set(column_names)
            for field_name, field_info in schema.items():
                if field_name in column_names:
                    continue
                res = inf_table.add_columns({field_name: field_info})
                assert res.error_code == infinity.ErrorCode.OK
                logger.info(
                    f"INFINITY added column to table {table_name}: {field_name} {field_info}"
                )
                if field_info["type"] != "varchar" or "analyzer" not in field_info:
                    continue
                inf_table.create_index(
                    f"text_idx_{field_name}",
                    IndexInfo(
                        field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
                    ),
                    ConflictType.Ignore,
                )
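
    # Migration sketch: a table is assumed to be owned by this connector iff it has
    # the "q_vec_idx" index. For owned tables, any column listed in
    # conf/infinity_mapping.json but missing from the table is added, and varchar
    # columns declaring an "analyzer" also get a full-text index.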
  180. """
  181. Database operations
  182. """
  183. def dbType(self) -> str:
  184. return "infinity"
  185. def health(self) -> dict:
  186. """
  187. Return the health status of the database.
  188. """
  189. inf_conn = self.connPool.get_conn()
  190. res = inf_conn.show_current_node()
  191. self.connPool.release_conn(inf_conn)
  192. res2 = {
  193. "type": "infinity",
  194. "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
  195. "error": res.error_msg,
  196. }
  197. return res2
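
    # Illustrative health() payload (values are examples): a healthy node yields
    # {"type": "infinity", "status": "green", "error": ""}; anything else reports
    # "status": "red" together with the server's error message.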
  198. """
  199. Table operations
  200. """
  201. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  202. table_name = f"{indexName}_{knowledgebaseId}"
  203. inf_conn = self.connPool.get_conn()
  204. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  205. fp_mapping = os.path.join(
  206. get_project_base_directory(), "conf", "infinity_mapping.json"
  207. )
  208. if not os.path.exists(fp_mapping):
  209. raise Exception(f"Mapping file not found at {fp_mapping}")
  210. schema = json.load(open(fp_mapping))
  211. vector_name = f"q_{vectorSize}_vec"
  212. schema[vector_name] = {"type": f"vector,{vectorSize},float"}
  213. inf_table = inf_db.create_table(
  214. table_name,
  215. schema,
  216. ConflictType.Ignore,
  217. )
  218. inf_table.create_index(
  219. "q_vec_idx",
  220. IndexInfo(
  221. vector_name,
  222. IndexType.Hnsw,
  223. {
  224. "M": "16",
  225. "ef_construction": "50",
  226. "metric": "cosine",
  227. "encode": "lvq",
  228. },
  229. ),
  230. ConflictType.Ignore,
  231. )
  232. for field_name, field_info in schema.items():
  233. if field_info["type"] != "varchar" or "analyzer" not in field_info:
  234. continue
  235. inf_table.create_index(
  236. f"text_idx_{field_name}",
  237. IndexInfo(
  238. field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
  239. ),
  240. ConflictType.Ignore,
  241. )
  242. self.connPool.release_conn(inf_conn)
  243. logger.info(
  244. f"INFINITY created table {table_name}, vector size {vectorSize}"
  245. )
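
    # Naming sketch with illustrative arguments: createIdx("ragflow_abc", "kb42", 1024)
    # creates table "ragflow_abc_kb42" holding a "q_1024_vec" float vector column,
    # builds an HNSW index (cosine metric, LVQ encoding) on it, and adds one
    # full-text index per analyzer-bearing varchar field in the mapping file.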

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        table_name = f"{indexName}_{knowledgebaseId}"
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        db_instance.drop_table(table_name, ConflictType.Ignore)
        self.connPool.release_conn(inf_conn)
        logger.info(f"INFINITY dropped table {table_name}")

    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            inf_conn = self.connPool.get_conn()
            db_instance = inf_conn.get_database(self.dbName)
            _ = db_instance.get_table(table_name)
            self.connPool.release_conn(inf_conn)
            return True
        except Exception as e:
            logger.warning(f"INFINITY indexExist {str(e)}")
        return False

    """
    CRUD operations
    """

    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        aggFields: list[str] = [],
        rank_feature: dict | None = None,
    ) -> tuple[pd.DataFrame, int]:
        """
        TODO: Infinity doesn't provide highlighting, so highlightFields is ignored.
        """
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        output = selectFields.copy()
        for essential_field in ["id"] + aggFields:
            if essential_field not in output:
                output.append(essential_field)
        score_func = ""
        score_column = ""
        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                score_func = "score()"
                score_column = "SCORE"
                break
        if not score_func:
            for matchExpr in matchExprs:
                if isinstance(matchExpr, MatchDenseExpr):
                    score_func = "similarity()"
                    score_column = "SIMILARITY"
                    break
        if matchExprs:
            if score_func not in output:
                output.append(score_func)
            if PAGERANK_FLD not in output:
                output.append(PAGERANK_FLD)
        output = [f for f in output if f != "_score"]
        if limit <= 0:
            # The Elasticsearch default limit is 10000
            limit = 10000

        # Prepare expressions common to all tables
        filter_cond = None
        filter_fulltext = ""
        if condition:
            table_found = False
            for indexName in indexNames:
                for kb_id in knowledgebaseIds:
                    table_name = f"{indexName}_{kb_id}"
                    try:
                        filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
                        table_found = True
                        break
                    except Exception:
                        pass
                if table_found:
                    break
            if not table_found:
                logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
                return pd.DataFrame(), 0

        for matchExpr in matchExprs:
            if isinstance(matchExpr, MatchTextExpr):
                if filter_cond and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                fields = ",".join(matchExpr.fields)
                filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                if filter_cond:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                    matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
                # Add rank_feature support
                if rank_feature and "rank_features" not in matchExpr.extra_options:
                    # Convert the rank_feature dict to Infinity's rank_features string format:
                    # "field^feature_name^weight,field^feature_name^weight"
                    rank_features_list = []
                    for feature_name, weight in rank_feature.items():
                        # Use TAG_FLD as the field containing rank features
                        rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
                    if rank_features_list:
                        matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, MatchDenseExpr):
                if filter_fulltext and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_fulltext})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                similarity = matchExpr.extra_options.get("similarity")
                if similarity:
                    matchExpr.extra_options["threshold"] = similarity
                    del matchExpr.extra_options["similarity"]
                logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, FusionExpr):
                logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")

        order_by_expr_list = list()
        if orderBy.fields:
            for order_field in orderBy.fields:
                if order_field[1] == 0:
                    order_by_expr_list.append((order_field[0], SortType.Asc))
                else:
                    order_by_expr_list.append((order_field[0], SortType.Desc))

        total_hits_count = 0
        # Scatter the search across tables and gather the results
        for indexName in indexNames:
            for knowledgebaseId in knowledgebaseIds:
                table_name = f"{indexName}_{knowledgebaseId}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    continue
                table_list.append(table_name)
                builder = table_instance.output(output)
                if len(matchExprs) > 0:
                    for matchExpr in matchExprs:
                        if isinstance(matchExpr, MatchTextExpr):
                            fields = ",".join(matchExpr.fields)
                            builder = builder.match_text(
                                fields,
                                matchExpr.matching_text,
                                matchExpr.topn,
                                matchExpr.extra_options.copy(),
                            )
                        elif isinstance(matchExpr, MatchDenseExpr):
                            builder = builder.match_dense(
                                matchExpr.vector_column_name,
                                matchExpr.embedding_data,
                                matchExpr.embedding_data_type,
                                matchExpr.distance_type,
                                matchExpr.topn,
                                matchExpr.extra_options.copy(),
                            )
                        elif isinstance(matchExpr, FusionExpr):
                            builder = builder.fusion(
                                matchExpr.method, matchExpr.topn, matchExpr.fusion_params
                            )
                else:
                    if filter_cond and len(filter_cond) > 0:
                        builder.filter(filter_cond)
                if orderBy.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
                if extra_result:
                    total_hits_count += int(extra_result["total_hits_count"])
                logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
                df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, output)
        if matchExprs:
            res['Sum'] = res[score_column] + res[PAGERANK_FLD]
            res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
        res = res.head(limit)
        logger.debug(f"INFINITY search final result: {str(res)}")
        return res, total_hits_count
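
    # Usage sketch (hypothetical expressions built elsewhere in RAGFlow): a hybrid
    # query over one index and two knowledge bases might look like
    #   res, total = conn.search(
    #       selectFields=["id", "docnm_kwd"], highlightFields=[],
    #       condition={"available_int": 1},
    #       matchExprs=[text_expr, dense_expr, fusion_expr],
    #       orderBy=OrderByExpr(), offset=0, limit=30,
    #       indexNames="ragflow_abc", knowledgebaseIds=["kb1", "kb2"])
    # Per-table results are concatenated, re-ranked by score plus pagerank, and
    # truncated to `limit`.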

    def get(
        self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
    ) -> dict | None:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        assert isinstance(knowledgebaseIds, list)
        table_list = list()
        for knowledgebaseId in knowledgebaseIds:
            table_name = f"{indexName}_{knowledgebaseId}"
            table_list.append(table_name)
            table_instance = None
            try:
                table_instance = db_instance.get_table(table_name)
            except Exception:
                logger.warning(
                    f"Table not found: {table_name}. This knowledge base wasn't created in Infinity; it may have been created in another document engine.")
                continue
            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
            logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
            df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, ["id"])
        res_fields = self.getFields(res, res.columns.tolist())
        return res_fields.get(chunkId, None)

    def insert(
        self, documents: list[dict], indexName: str, knowledgebaseId: str = None
    ) -> list[str]:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)

        # Embedding fields can't have a default value, so fill missing ones with zero vectors.
        embedding_clmns = []
        clmns = table_instance.show_columns().rows()
        for n, ty, _, _ in clmns:
            r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
            if not r:
                continue
            embedding_clmns.append((n, int(r.group(1))))

        docs = copy.deepcopy(documents)
        for d in docs:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                if field_keyword(k):
                    if isinstance(v, list):
                        d[k] = "###".join(v)
                    else:
                        d[k] = v
                elif re.search(r"_feas$", k):
                    d[k] = json.dumps(v)
                elif k == 'kb_id':
                    if isinstance(d[k], list):
                        d[k] = d[k][0]  # d[k] is a list, but a str is needed
                elif k == "position_int":
                    assert isinstance(v, list)
                    arr = [num for row in v for num in row]
                    d[k] = "_".join(f"{num:08x}" for num in arr)
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)
                else:
                    d[k] = v
            for n, vs in embedding_clmns:
                if n in d:
                    continue
                d[n] = [0] * vs
        ids = ["'{}'".format(d["id"]) for d in docs]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        # for doc in documents:
        #     logger.info(f"insert position_int: {doc['position_int']}")
        # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
        table_instance.insert(docs)
        self.connPool.release_conn(inf_conn)
        logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
        return []
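
    # Note: insert() behaves as an upsert, since rows with matching ids are deleted
    # before the insert. Illustrative field encoding: position_int [[1, 2, 3, 4, 5]]
    # is flattened and stored as "00000001_00000002_00000003_00000004_00000005".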

    def update(
        self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
    ) -> bool:
        # if 'position_int' in newValue:
        #     logger.info(f"update position_int: {newValue['position_int']}")
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)
        # if "exists" in condition:
        #     del condition["exists"]
        clmns = {}
        if table_instance:
            for n, ty, de, _ in table_instance.show_columns().rows():
                clmns[n] = (ty, de)
        filter = equivalent_condition_to_str(condition, table_instance)
        removeValue = {}
        for k, v in list(newValue.items()):
            if field_keyword(k):
                if isinstance(v, list):
                    newValue[k] = "###".join(v)
                else:
                    newValue[k] = v
            elif re.search(r"_feas$", k):
                newValue[k] = json.dumps(v)
            elif k == 'kb_id':
                if isinstance(newValue[k], list):
                    newValue[k] = newValue[k][0]  # the value is a list, but a str is needed
            elif k == "position_int":
                assert isinstance(v, list)
                arr = [num for row in v for num in row]
                newValue[k] = "_".join(f"{num:08x}" for num in arr)
            elif k in ["page_num_int", "top_int"]:
                assert isinstance(v, list)
                newValue[k] = "_".join(f"{num:08x}" for num in v)
            elif k == "remove":
                if isinstance(v, str):
                    assert v in clmns, f"'{v}' should be in '{clmns}'."
                    ty, de = clmns[v]
                    # char-typed columns are reset to their (string) default
                    if "cha" in ty.lower():
                        if not de:
                            de = ""
                        newValue[v] = de
                else:
                    for kk, vv in v.items():
                        removeValue[kk] = vv
                del newValue[k]
            else:
                newValue[k] = v

        remove_opt = {}  # "[k, new_value]": [id_to_update, ...]
        if removeValue:
            col_to_remove = list(removeValue.keys())
            row_to_opt = table_instance.output(col_to_remove + ['id']).filter(filter).to_df()
            logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
            row_to_opt = self.getFields(row_to_opt, col_to_remove)
            for id, old_v in row_to_opt.items():
                for k, remove_v in removeValue.items():
                    if remove_v in old_v[k]:
                        new_v = old_v[k].copy()
                        new_v.remove(remove_v)
                        kv_key = json.dumps([k, new_v])
                        if kv_key not in remove_opt:
                            remove_opt[kv_key] = [id]
                        else:
                            remove_opt[kv_key].append(id)

        logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
        for update_kv, ids in remove_opt.items():
            k, v = json.loads(update_kv)
            table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
        table_instance.update(filter, newValue)
        self.connPool.release_conn(inf_conn)
        return True
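
    # "remove" semantics, sketched with a hypothetical call:
    #   update(cond, {"remove": {"tag_kwd": "obsolete"}}, idx, kb) reads the matching
    # rows, drops "obsolete" from each row's tag list, and writes the remaining tags
    # back as a "###"-joined string; {"remove": "some_column"} instead resets that
    # column to its default value.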

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except Exception:
            logger.warning(
                f"Skipped deleting from table {table_name} since the table doesn't exist."
            )
            return 0
        filter = equivalent_condition_to_str(condition, table_instance)
        logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
        res = table_instance.delete(filter)
        self.connPool.release_conn(inf_conn)
        return res.deleted_rows
  617. """
  618. Helper functions for search result
  619. """

    def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
        if isinstance(res, tuple):
            return res[1]
        return len(res)

    def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
        if isinstance(res, tuple):
            res = res[0]
        return list(res["id"])

    def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
        if isinstance(res, tuple):
            res = res[0]
        if not fields:
            return {}
        fieldsAll = fields.copy()
        fieldsAll.append('id')
        column_map = {col.lower(): col for col in res.columns}
        matched_columns = {column_map[col.lower()]: col for col in set(fieldsAll) if col.lower() in column_map}
        none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
        res2 = res[matched_columns.keys()]
        res2 = res2.rename(columns=matched_columns)
        res2.drop_duplicates(subset=['id'], inplace=True)

        for column in res2.columns:
            k = column.lower()
            if field_keyword(k):
                res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
            elif re.search(r"_feas$", k):
                res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
            elif k == "position_int":
                def to_position_int(v):
                    if v:
                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
                        v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
                    else:
                        v = []
                    return v
                res2[column] = res2[column].apply(to_position_int)
            elif k in ["page_num_int", "top_int"]:
                res2[column] = res2[column].apply(lambda v: [int(hex_val, 16) for hex_val in v.split('_')] if v else [])
            else:
                pass
        for column in none_columns:
            res2[column] = None
        return res2.set_index("id").to_dict(orient="index")
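
    # Decoding sketch: getFields() reverses the encodings applied by insert(), e.g.
    # a stored "tag1###tag2" keyword string becomes ["tag1", "tag2"], a "*_feas"
    # JSON string is parsed back into a dict, and a "position_int" hex string is
    # unpacked into rows of five ints.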

    def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
        if isinstance(res, tuple):
            res = res[0]
        ans = {}
        num_rows = len(res)
        column_id = res["id"]
        if fieldnm not in res:
            return {}
        for i in range(num_rows):
            id = column_id[i]
            txt = res[fieldnm][i]
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
                        % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                if not re.search(
                    r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
                ):
                    continue
                txts.append(t)
            ans[id] = "...".join(txts)
        return ans
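
    # Highlight sketch: with keyword "infinity", the fragment "uses Infinity engine"
    # becomes "uses <em>Infinity</em> engine" (matching is case-insensitive);
    # fragments without any <em> match are dropped and the rest are joined by "...".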

    def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
        """
        Manual aggregation for tag fields since Infinity doesn't provide native aggregation
        """
        from collections import Counter

        # Extract DataFrame from result
        if isinstance(res, tuple):
            df, _ = res
        else:
            df = res
        if df.empty or fieldnm not in df.columns:
            return []

        # Aggregate tag counts
        tag_counter = Counter()
        for value in df[fieldnm]:
            if pd.isna(value) or not value:
                continue
            # Handle different tag formats
            if isinstance(value, str):
                # Split by ### for tag_kwd field or comma for other formats
                if fieldnm == "tag_kwd" and "###" in value:
                    tags = [tag.strip() for tag in value.split("###") if tag.strip()]
                else:
                    # Try comma separation as fallback
                    tags = [tag.strip() for tag in value.split(",") if tag.strip()]
                for tag in tags:
                    if tag:  # Only count non-empty tags
                        tag_counter[tag] += 1
            elif isinstance(value, list):
                # Handle list format
                for tag in value:
                    if tag and isinstance(tag, str):
                        tag_counter[tag.strip()] += 1
        # Return as list of [tag, count] pairs, sorted by count descending
        return [[tag, count] for tag, count in tag_counter.most_common()]
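
    # Illustrative aggregation (made-up rows): for a "tag_kwd" column holding
    # "ai###nlp" and "ai", getAggregation(res, "tag_kwd") returns
    # [["ai", 2], ["nlp", 1]].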
  727. """
  728. SQL
  729. """
  730. def sql(sql: str, fetch_size: int, format: str):
  731. raise NotImplementedError("Not implemented")