You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

infinity_conn.py 27KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. import re
  19. import json
  20. import time
  21. import copy
  22. import infinity
  23. from infinity.common import ConflictType, InfinityException, SortType
  24. from infinity.index import IndexInfo, IndexType
  25. from infinity.connection_pool import ConnectionPool
  26. from infinity.errors import ErrorCode
  27. from rag import settings
  28. from rag.settings import PAGERANK_FLD
  29. from rag.utils import singleton
  30. import polars as pl
  31. from polars.series.series import Series
  32. from api.utils.file_utils import get_project_base_directory
  33. from rag.utils.doc_store_conn import (
  34. DocStoreConnection,
  35. MatchExpr,
  36. MatchTextExpr,
  37. MatchDenseExpr,
  38. FusionExpr,
  39. OrderByExpr,
  40. )
  41. logger = logging.getLogger('ragflow.infinity_conn')
  42. def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
  43. assert "_id" not in condition
  44. clmns = {}
  45. if table_instance:
  46. for n, ty, de, _ in table_instance.show_columns().rows():
  47. clmns[n] = (ty, de)
  48. def exists(cln):
  49. nonlocal clmns
  50. assert cln in clmns, f"'{cln}' should be in '{clmns}'."
  51. ty, de = clmns[cln]
  52. if ty.lower().find("cha"):
  53. if not de:
  54. de = ""
  55. return f" {cln}!='{de}' "
  56. return f"{cln}!={de}"
  57. cond = list()
  58. for k, v in condition.items():
  59. if not isinstance(k, str) or k in ["kb_id"] or not v:
  60. continue
  61. if isinstance(v, list):
  62. inCond = list()
  63. for item in v:
  64. if isinstance(item, str):
  65. inCond.append(f"'{item}'")
  66. else:
  67. inCond.append(str(item))
  68. if inCond:
  69. strInCond = ", ".join(inCond)
  70. strInCond = f"{k} IN ({strInCond})"
  71. cond.append(strInCond)
  72. elif k == "must_not":
  73. if isinstance(v, dict):
  74. for kk, vv in v.items():
  75. if kk == "exists":
  76. cond.append("NOT (%s)" % exists(vv))
  77. elif isinstance(v, str):
  78. cond.append(f"{k}='{v}'")
  79. elif k == "exists":
  80. cond.append(exists(v))
  81. else:
  82. cond.append(f"{k}={str(v)}")
  83. return " AND ".join(cond) if cond else "1=1"
  84. def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
  85. """
  86. Concatenate multiple dataframes into one.
  87. """
  88. df_list = [df for df in df_list if not df.is_empty()]
  89. if df_list:
  90. return pl.concat(df_list)
  91. schema = dict()
  92. for field_name in selectFields:
  93. if field_name == 'score()': # Workaround: fix schema is changed to score()
  94. schema['SCORE'] = str
  95. else:
  96. schema[field_name] = str
  97. return pl.DataFrame(schema=schema)
  98. @singleton
  99. class InfinityConnection(DocStoreConnection):
  100. def __init__(self):
  101. self.dbName = settings.INFINITY.get("db_name", "default_db")
  102. infinity_uri = settings.INFINITY["uri"]
  103. if ":" in infinity_uri:
  104. host, port = infinity_uri.split(":")
  105. infinity_uri = infinity.common.NetworkAddress(host, int(port))
  106. self.connPool = None
  107. logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
  108. for _ in range(24):
  109. try:
  110. connPool = ConnectionPool(infinity_uri)
  111. inf_conn = connPool.get_conn()
  112. res = inf_conn.show_current_node()
  113. if res.error_code == ErrorCode.OK and res.server_status == "started":
  114. self._migrate_db(inf_conn)
  115. self.connPool = connPool
  116. connPool.release_conn(inf_conn)
  117. break
  118. connPool.release_conn(inf_conn)
  119. logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
  120. time.sleep(5)
  121. except Exception as e:
  122. logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
  123. time.sleep(5)
  124. if self.connPool is None:
  125. msg = f"Infinity {infinity_uri} is unhealthy in 120s."
  126. logger.error(msg)
  127. raise Exception(msg)
  128. logger.info(f"Infinity {infinity_uri} is healthy.")
  129. def _migrate_db(self, inf_conn):
  130. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  131. fp_mapping = os.path.join(
  132. get_project_base_directory(), "conf", "infinity_mapping.json"
  133. )
  134. if not os.path.exists(fp_mapping):
  135. raise Exception(f"Mapping file not found at {fp_mapping}")
  136. schema = json.load(open(fp_mapping))
  137. table_names = inf_db.list_tables().table_names
  138. for table_name in table_names:
  139. inf_table = inf_db.get_table(table_name)
  140. index_names = inf_table.list_indexes().index_names
  141. if "q_vec_idx" not in index_names:
  142. # Skip tables not created by me
  143. continue
  144. column_names = inf_table.show_columns()["name"]
  145. column_names = set(column_names)
  146. for field_name, field_info in schema.items():
  147. if field_name in column_names:
  148. continue
  149. res = inf_table.add_columns({field_name: field_info})
  150. assert res.error_code == infinity.ErrorCode.OK
  151. logger.info(
  152. f"INFINITY added following column to table {table_name}: {field_name} {field_info}"
  153. )
  154. if field_info["type"] != "varchar" or "analyzer" not in field_info:
  155. continue
  156. inf_table.create_index(
  157. f"text_idx_{field_name}",
  158. IndexInfo(
  159. field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
  160. ),
  161. ConflictType.Ignore,
  162. )
  163. """
  164. Database operations
  165. """
  166. def dbType(self) -> str:
  167. return "infinity"
  168. def health(self) -> dict:
  169. """
  170. Return the health status of the database.
  171. """
  172. inf_conn = self.connPool.get_conn()
  173. res = inf_conn.show_current_node()
  174. self.connPool.release_conn(inf_conn)
  175. res2 = {
  176. "type": "infinity",
  177. "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
  178. "error": res.error_msg,
  179. }
  180. return res2
  181. """
  182. Table operations
  183. """
  184. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  185. table_name = f"{indexName}_{knowledgebaseId}"
  186. inf_conn = self.connPool.get_conn()
  187. inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
  188. fp_mapping = os.path.join(
  189. get_project_base_directory(), "conf", "infinity_mapping.json"
  190. )
  191. if not os.path.exists(fp_mapping):
  192. raise Exception(f"Mapping file not found at {fp_mapping}")
  193. schema = json.load(open(fp_mapping))
  194. vector_name = f"q_{vectorSize}_vec"
  195. schema[vector_name] = {"type": f"vector,{vectorSize},float"}
  196. inf_table = inf_db.create_table(
  197. table_name,
  198. schema,
  199. ConflictType.Ignore,
  200. )
  201. inf_table.create_index(
  202. "q_vec_idx",
  203. IndexInfo(
  204. vector_name,
  205. IndexType.Hnsw,
  206. {
  207. "M": "16",
  208. "ef_construction": "50",
  209. "metric": "cosine",
  210. "encode": "lvq",
  211. },
  212. ),
  213. ConflictType.Ignore,
  214. )
  215. for field_name, field_info in schema.items():
  216. if field_info["type"] != "varchar" or "analyzer" not in field_info:
  217. continue
  218. inf_table.create_index(
  219. f"text_idx_{field_name}",
  220. IndexInfo(
  221. field_name, IndexType.FullText, {"ANALYZER": field_info["analyzer"]}
  222. ),
  223. ConflictType.Ignore,
  224. )
  225. self.connPool.release_conn(inf_conn)
  226. logger.info(
  227. f"INFINITY created table {table_name}, vector size {vectorSize}"
  228. )
  229. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  230. table_name = f"{indexName}_{knowledgebaseId}"
  231. inf_conn = self.connPool.get_conn()
  232. db_instance = inf_conn.get_database(self.dbName)
  233. db_instance.drop_table(table_name, ConflictType.Ignore)
  234. self.connPool.release_conn(inf_conn)
  235. logger.info(f"INFINITY dropped table {table_name}")
  236. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  237. table_name = f"{indexName}_{knowledgebaseId}"
  238. try:
  239. inf_conn = self.connPool.get_conn()
  240. db_instance = inf_conn.get_database(self.dbName)
  241. _ = db_instance.get_table(table_name)
  242. self.connPool.release_conn(inf_conn)
  243. return True
  244. except Exception as e:
  245. logger.warning(f"INFINITY indexExist {str(e)}")
  246. return False
  247. """
  248. CRUD operations
  249. """
  250. def search(
  251. self, selectFields: list[str],
  252. highlightFields: list[str],
  253. condition: dict,
  254. matchExprs: list[MatchExpr],
  255. orderBy: OrderByExpr,
  256. offset: int,
  257. limit: int,
  258. indexNames: str | list[str],
  259. knowledgebaseIds: list[str],
  260. aggFields: list[str] = [],
  261. rank_feature: dict | None = None
  262. ) -> list[dict] | pl.DataFrame:
  263. """
  264. TODO: Infinity doesn't provide highlight
  265. """
  266. if isinstance(indexNames, str):
  267. indexNames = indexNames.split(",")
  268. assert isinstance(indexNames, list) and len(indexNames) > 0
  269. inf_conn = self.connPool.get_conn()
  270. db_instance = inf_conn.get_database(self.dbName)
  271. df_list = list()
  272. table_list = list()
  273. for essential_field in ["id"]:
  274. if essential_field not in selectFields:
  275. selectFields.append(essential_field)
  276. score_func = ""
  277. score_column = ""
  278. for matchExpr in matchExprs:
  279. if isinstance(matchExpr, MatchTextExpr):
  280. score_func = "score()"
  281. score_column = "SCORE"
  282. break
  283. if not score_func:
  284. for matchExpr in matchExprs:
  285. if isinstance(matchExpr, MatchDenseExpr):
  286. score_func = "similarity()"
  287. score_column = "SIMILARITY"
  288. break
  289. if matchExprs:
  290. selectFields.append(score_func)
  291. selectFields.append(PAGERANK_FLD)
  292. selectFields = [f for f in selectFields if f != "_score"]
  293. # Prepare expressions common to all tables
  294. filter_cond = None
  295. filter_fulltext = ""
  296. if condition:
  297. for indexName in indexNames:
  298. table_name = f"{indexName}_{knowledgebaseIds[0]}"
  299. filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
  300. break
  301. for matchExpr in matchExprs:
  302. if isinstance(matchExpr, MatchTextExpr):
  303. if filter_cond and "filter" not in matchExpr.extra_options:
  304. matchExpr.extra_options.update({"filter": filter_cond})
  305. fields = ",".join(matchExpr.fields)
  306. filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
  307. if filter_cond:
  308. filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
  309. minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
  310. if isinstance(minimum_should_match, float):
  311. str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
  312. matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
  313. for k, v in matchExpr.extra_options.items():
  314. if not isinstance(v, str):
  315. matchExpr.extra_options[k] = str(v)
  316. logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
  317. elif isinstance(matchExpr, MatchDenseExpr):
  318. if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options:
  319. matchExpr.extra_options.update({"filter": filter_fulltext})
  320. for k, v in matchExpr.extra_options.items():
  321. if not isinstance(v, str):
  322. matchExpr.extra_options[k] = str(v)
  323. logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
  324. elif isinstance(matchExpr, FusionExpr):
  325. logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
  326. order_by_expr_list = list()
  327. if orderBy.fields:
  328. for order_field in orderBy.fields:
  329. if order_field[1] == 0:
  330. order_by_expr_list.append((order_field[0], SortType.Asc))
  331. else:
  332. order_by_expr_list.append((order_field[0], SortType.Desc))
  333. total_hits_count = 0
  334. # Scatter search tables and gather the results
  335. for indexName in indexNames:
  336. for knowledgebaseId in knowledgebaseIds:
  337. table_name = f"{indexName}_{knowledgebaseId}"
  338. try:
  339. table_instance = db_instance.get_table(table_name)
  340. except Exception:
  341. continue
  342. table_list.append(table_name)
  343. builder = table_instance.output(selectFields)
  344. if len(matchExprs) > 0:
  345. for matchExpr in matchExprs:
  346. if isinstance(matchExpr, MatchTextExpr):
  347. fields = ",".join(matchExpr.fields)
  348. builder = builder.match_text(
  349. fields,
  350. matchExpr.matching_text,
  351. matchExpr.topn,
  352. matchExpr.extra_options,
  353. )
  354. elif isinstance(matchExpr, MatchDenseExpr):
  355. builder = builder.match_dense(
  356. matchExpr.vector_column_name,
  357. matchExpr.embedding_data,
  358. matchExpr.embedding_data_type,
  359. matchExpr.distance_type,
  360. matchExpr.topn,
  361. matchExpr.extra_options,
  362. )
  363. elif isinstance(matchExpr, FusionExpr):
  364. builder = builder.fusion(
  365. matchExpr.method, matchExpr.topn, matchExpr.fusion_params
  366. )
  367. else:
  368. if len(filter_cond) > 0:
  369. builder.filter(filter_cond)
  370. if orderBy.fields:
  371. builder.sort(order_by_expr_list)
  372. builder.offset(offset).limit(limit)
  373. kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
  374. if extra_result:
  375. total_hits_count += int(extra_result["total_hits_count"])
  376. logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
  377. df_list.append(kb_res)
  378. self.connPool.release_conn(inf_conn)
  379. res = concat_dataframes(df_list, selectFields)
  380. if matchExprs:
  381. res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
  382. if score_column and score_column != "SCORE":
  383. res = res.rename({score_column: "_score"})
  384. res = res.limit(limit)
  385. logger.debug(f"INFINITY search final result: {str(res)}")
  386. return res, total_hits_count
  387. def get(
  388. self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
  389. ) -> dict | None:
  390. inf_conn = self.connPool.get_conn()
  391. db_instance = inf_conn.get_database(self.dbName)
  392. df_list = list()
  393. assert isinstance(knowledgebaseIds, list)
  394. table_list = list()
  395. for knowledgebaseId in knowledgebaseIds:
  396. table_name = f"{indexName}_{knowledgebaseId}"
  397. table_list.append(table_name)
  398. table_instance = None
  399. try:
  400. table_instance = db_instance.get_table(table_name)
  401. except Exception:
  402. logger.warning(
  403. f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
  404. continue
  405. kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
  406. logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
  407. df_list.append(kb_res)
  408. self.connPool.release_conn(inf_conn)
  409. res = concat_dataframes(df_list, ["id"])
  410. res_fields = self.getFields(res, res.columns)
  411. return res_fields.get(chunkId, None)
    def insert(
        self, documents: list[dict], indexName: str, knowledgebaseId: str | None = None
    ) -> list[str]:
        """Upsert chunk documents into the table for (indexName, knowledgebaseId).

        Creates the table on first use (embedding size inferred from the first
        document's ``q_<size>_vec`` key), serializes list/dict fields into
        Infinity-storable strings, deletes any rows with the same ids, then
        inserts. Always returns an empty list (no per-document errors are
        reported).
        """
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            # Table is missing: infer the embedding dimension from the first
            # document and create the table lazily.
            vector_size = 0
            patt = re.compile(r"q_(?P<vector_size>\d+)_vec")
            for k in documents[0].keys():
                m = patt.match(k)
                if m:
                    vector_size = int(m.group("vector_size"))
                    break
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)
        # embedding fields can't have a default value....
        # Collect (column, dimension) for every embedding column so missing
        # vectors can be zero-padded below.
        embedding_clmns = []
        clmns = table_instance.show_columns().rows()
        for n, ty, _, _ in clmns:
            r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
            if not r:
                continue
            embedding_clmns.append((n, int(r.group(1))))
        # Deep-copy so the caller's documents are never mutated.
        docs = copy.deepcopy(documents)
        for d in docs:
            assert "_id" not in d
            assert "id" in d
            for k, v in d.items():
                # Keyword-list fields are stored as one "###"-joined string.
                if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
                    assert isinstance(v, list)
                    d[k] = "###".join(v)
                elif re.search(r"_feas$", k):
                    # Feature dicts are stored as JSON text.
                    d[k] = json.dumps(v)
                elif k == 'kb_id':
                    if isinstance(d[k], list):
                        d[k] = d[k][0] # since d[k] is a list, but we need a str
                elif k == "position_int":
                    # Flatten the list of rows and pack each int as 8 hex digits
                    # joined by "_" (decoded again in getFields).
                    assert isinstance(v, list)
                    arr = [num for row in v for num in row]
                    d[k] = "_".join(f"{num:08x}" for num in arr)
                elif k in ["page_num_int", "top_int"]:
                    assert isinstance(v, list)
                    d[k] = "_".join(f"{num:08x}" for num in v)
            # Zero-fill any embedding column the document lacks.
            for n, vs in embedding_clmns:
                if n in d:
                    continue
                d[n] = [0] * vs
        ids = ["'{}'".format(d["id"]) for d in docs]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        # Delete-then-insert gives upsert semantics.
        table_instance.delete(str_filter)
        # for doc in documents:
        # logger.info(f"insert position_int: {doc['position_int']}")
        # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
        table_instance.insert(docs)
        self.connPool.release_conn(inf_conn)
        logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
        return []
    def update(
        self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
    ) -> bool:
        """Update all rows matching ``condition`` with ``newValue``.

        ``newValue`` entries are serialized with the same encodings used by
        ``insert`` (###-joined keyword lists, JSON feature dicts, hex-packed
        position ints). The special key ``"remove"`` is an instruction, not a
        column: it is dropped, and when its value names PAGERANK_FLD that
        field is reset to 0. NOTE: ``newValue`` is mutated in place.
        """
        # if 'position_int' in newValue:
        # logger.info(f"update position_int: {newValue['position_int']}")
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{indexName}_{knowledgebaseId}"
        table_instance = db_instance.get_table(table_name)
        #if "exists" in condition:
        # del condition["exists"]
        filter = equivalent_condition_to_str(condition, table_instance)
        # Iterate over a snapshot because "remove" entries delete keys in place.
        for k, v in list(newValue.items()):
            if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
                assert isinstance(v, list)
                newValue[k] = "###".join(v)
            elif re.search(r"_feas$", k):
                newValue[k] = json.dumps(v)
            elif k.endswith("_kwd") and isinstance(v, list):
                # Other keyword lists are stored space-joined.
                newValue[k] = " ".join(v)
            elif k == 'kb_id':
                if isinstance(newValue[k], list):
                    newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
            elif k == "position_int":
                assert isinstance(v, list)
                arr = [num for row in v for num in row]
                newValue[k] = "_".join(f"{num:08x}" for num in arr)
            elif k in ["page_num_int", "top_int"]:
                assert isinstance(v, list)
                newValue[k] = "_".join(f"{num:08x}" for num in v)
            elif k == "remove":
                del newValue[k]
                if v in [PAGERANK_FLD]:
                    newValue[v] = 0
        logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
        table_instance.update(filter, newValue)
        self.connPool.release_conn(inf_conn)
        return True
  516. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  517. inf_conn = self.connPool.get_conn()
  518. db_instance = inf_conn.get_database(self.dbName)
  519. table_name = f"{indexName}_{knowledgebaseId}"
  520. try:
  521. table_instance = db_instance.get_table(table_name)
  522. except Exception:
  523. logger.warning(
  524. f"Skipped deleting from table {table_name} since the table doesn't exist."
  525. )
  526. return 0
  527. filter = equivalent_condition_to_str(condition, table_instance)
  528. logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
  529. res = table_instance.delete(filter)
  530. self.connPool.release_conn(inf_conn)
  531. return res.deleted_rows
  532. """
  533. Helper functions for search result
  534. """
  535. def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
  536. if isinstance(res, tuple):
  537. return res[1]
  538. return len(res)
  539. def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
  540. if isinstance(res, tuple):
  541. res = res[0]
  542. return list(res["id"])
    def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> dict[str, dict]:
        """Decode result rows into ``{chunk_id: {field: value}}``.

        Reverses the encodings applied by ``insert``: "###"-joined keyword
        lists, hex-packed ``position_int`` (5 ints per position row) and
        ``page_num_int``/``top_int``. Fields absent from the result are set
        to None; anything else is coerced to str.
        """
        if isinstance(res, tuple):
            res = res[0]
        res_fields = {}
        if not fields:
            return {}
        num_rows = len(res)
        column_id = res["id"]
        for i in range(num_rows):
            id = column_id[i]
            m = {"id": id}
            for fieldnm in fields:
                if fieldnm not in res:
                    m[fieldnm] = None
                    continue
                v = res[fieldnm][i]
                if isinstance(v, Series):
                    v = list(v)
                elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
                    assert isinstance(v, str)
                    v = [kwd for kwd in v.split("###") if kwd]
                elif fieldnm == "position_int":
                    assert isinstance(v, str)
                    if v:
                        # Stored as 8-hex-digit ints joined by "_"; regroup in
                        # chunks of 5 to rebuild the position rows.
                        arr = [int(hex_val, 16) for hex_val in v.split('_')]
                        v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
                    else:
                        v = []
                elif fieldnm in ["page_num_int", "top_int"]:
                    assert isinstance(v, str)
                    if v:
                        v = [int(hex_val, 16) for hex_val in v.split('_')]
                    else:
                        v = []
                else:
                    if not isinstance(v, str):
                        v = str(v)
                    # if fieldnm.endswith("_tks"):
                    # v = rmSpace(v)
                m[fieldnm] = v
            res_fields[id] = m
        return res_fields
    def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
        """Build ``{chunk_id: highlighted_text}`` by wrapping keyword matches
        in ``<em>`` tags (case-insensitive, delimiter-bounded).

        The field text is split on sentence punctuation; only sentences that
        actually contain a highlighted keyword are kept, re-joined with "...".
        Returns {} when ``fieldnm`` is absent from the result.
        """
        if isinstance(res, tuple):
            res = res[0]
        ans = {}
        num_rows = len(res)
        column_id = res["id"]
        if fieldnm not in res:
            return {}
        for i in range(num_rows):
            id = column_id[i]
            txt = res[fieldnm][i]
            # Collapse newlines so sentence splitting below behaves.
            txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
            txts = []
            for t in re.split(r"[.?!;\n]", txt):
                for w in keywords:
                    # Keyword must be bounded by start-of-string or punctuation
                    # on both sides to avoid highlighting inside words.
                    t = re.sub(
                        r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"
                        % re.escape(w),
                        r"\1<em>\2</em>\3",
                        t,
                        flags=re.IGNORECASE | re.MULTILINE,
                    )
                if not re.search(
                    r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
                ):
                    continue
                txts.append(t)
            ans[id] = "...".join(txts)
        return ans
  614. def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
  615. """
  616. TODO: Infinity doesn't provide aggregation
  617. """
  618. return list()
  619. """
  620. SQL
  621. """
  622. def sql(sql: str, fetch_size: int, format: str):
  623. raise NotImplementedError("Not implemented")