
search.py 4.5KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from copy import deepcopy
from typing import Dict

import pandas as pd

from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
from rag.nlp.search import Dealer


class KGSearch(Dealer):
    def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
        def merge_into_first(sres, title="") -> Dict[str, str]:
            """Collapse all retrieved chunks into one pseudo-chunk keyed by the first hit's id."""
            if not sres:
                return {}
            df, texts = [], []
            for d in sres.values():
                # JSON payloads are collected as table rows; anything else is kept as raw text.
                try:
                    df.append(json.loads(d["content_with_weight"]))
                except Exception:
                    texts.append(d["content_with_weight"])
            if df:
                content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
            else:
                content_with_weight = title + "\n" + "\n".join(texts)
            # Reuse the first hit as the carrier for the merged content.
            first_id = next(iter(sres))
            first_source = deepcopy(sres[first_id])
            first_source["content_with_weight"] = content_with_weight
            return {first_id: first_source}
        qst = req.get("question", "")
        matchText, keywords = self.qryr.question(qst, min_match=0.05)
        condition = self.get_filters(req)

        ## Entity retrieval
        condition.update({"knowledge_graph_kwd": ["entity"]})
        assert emb_mdl, "No embedding model selected"
        matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
        q_vec = matchDense.embedding_data
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
                                 "weight_int", "weight_flt", "rank_int"
                                 ])

        # Full-text and dense scores are fused with equal weights.
        fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})

        ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
        ent_res_fields = self.dataStore.getFields(ent_res, src)
        entities = [d["name_kwd"] for d in ent_res_fields.values()]
        ent_ids = self.dataStore.getChunkIds(ent_res)
        ent_content = merge_into_first(ent_res_fields, "-Entities-")
        if ent_content:
            ent_ids = list(ent_content.keys())

        ## Community retrieval, restricted to reports mentioning the retrieved entities
        condition = self.get_filters(req)
        condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
        comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
        comm_res_fields = self.dataStore.getFields(comm_res, src)
        comm_ids = self.dataStore.getChunkIds(comm_res)
        comm_content = merge_into_first(comm_res_fields, "-Community Report-")
        if comm_content:
            comm_ids = list(comm_content.keys())

        ## Text content retrieval
        condition = self.get_filters(req)
        condition.update({"knowledge_graph_kwd": ["text"]})
        txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
        txt_res_fields = self.dataStore.getFields(txt_res, src)
        txt_ids = self.dataStore.getChunkIds(txt_res)
        txt_content = merge_into_first(txt_res_fields, "-Original Content-")
        if txt_content:
            txt_ids = list(txt_content.keys())

        # Each of the three passes contributes at most one merged pseudo-chunk.
        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            highlight=None,
            field={**ent_content, **comm_content, **txt_content},
            keywords=[]
        )
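
Because merge_into_first is nested inside search() and cannot be imported on its own, the standalone sketch below reproduces its collapsing behavior on fabricated hit records. The chunk ids and payloads are made up for illustration; only json, copy, and pandas from the file's own imports are used.

# Standalone sketch of the merge_into_first behavior above.
# "chunk-1"/"chunk-2" and their payloads are fabricated examples.
import json
from copy import deepcopy

import pandas as pd

hits = {
    "chunk-1": {"content_with_weight": json.dumps({"entity": "A", "weight": 0.9})},
    "chunk-2": {"content_with_weight": json.dumps({"entity": "B", "weight": 0.7})},
}

rows, texts = [], []
for d in hits.values():
    try:
        rows.append(json.loads(d["content_with_weight"]))   # JSON rows -> CSV table
    except Exception:
        texts.append(d["content_with_weight"])              # raw-text fallback

merged = "-Entities-\n" + (pd.DataFrame(rows).to_csv() if rows else "\n".join(texts))

first_id = next(iter(hits))                  # first hit carries the merged blob
first_source = deepcopy(hits[first_id])
first_source["content_with_weight"] = merged
print({first_id: first_source})

The effect, as in each of the three retrieval passes above, is that a whole hit set collapses into at most one chunk whose content_with_weight holds a titled CSV table (or plain-text block), which keeps downstream prompt assembly simple.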