You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

search.py 4.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from copy import deepcopy
  18. import pandas as pd
  19. from elasticsearch_dsl import Q, Search
  20. from rag.nlp.search import Dealer
  21. class KGSearch(Dealer):
  22. def search(self, req, idxnm, emb_mdl=None):
  23. def merge_into_first(sres, title=""):
  24. df,texts = [],[]
  25. for d in sres["hits"]["hits"]:
  26. try:
  27. df.append(json.loads(d["_source"]["content_with_weight"]))
  28. except Exception as e:
  29. texts.append(d["_source"]["content_with_weight"])
  30. pass
  31. if not df and not texts: return False
  32. if df:
  33. try:
  34. sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
  35. except Exception as e:
  36. pass
  37. else:
  38. sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
  39. return True
  40. src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
  41. "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
  42. "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
  43. "weight_int", "weight_flt", "rank_int"
  44. ])
  45. qst = req.get("question", "")
  46. binary_query, keywords = self.qryr.question(qst, min_match="5%")
  47. binary_query = self._add_filters(binary_query, req)
  48. ## Entity retrieval
  49. bqry = deepcopy(binary_query)
  50. bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
  51. s = Search()
  52. s = s.query(bqry)[0: 32]
  53. s = s.to_dict()
  54. q_vec = []
  55. if req.get("vector"):
  56. assert emb_mdl, "No embedding model selected"
  57. s["knn"] = self._vector(
  58. qst, emb_mdl, req.get(
  59. "similarity", 0.1), 1024)
  60. s["knn"]["filter"] = bqry.to_dict()
  61. q_vec = s["knn"]["query_vector"]
  62. ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
  63. entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
  64. ent_ids = self.es.getDocIds(ent_res)
  65. if merge_into_first(ent_res, "-Entities-"):
  66. ent_ids = ent_ids[0:1]
  67. ## Community retrieval
  68. bqry = deepcopy(binary_query)
  69. bqry.filter.append(Q("terms", entities_kwd=entities))
  70. bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
  71. s = Search()
  72. s = s.query(bqry)[0: 32]
  73. s = s.to_dict()
  74. comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
  75. comm_ids = self.es.getDocIds(comm_res)
  76. if merge_into_first(comm_res, "-Community Report-"):
  77. comm_ids = comm_ids[0:1]
  78. ## Text content retrieval
  79. bqry = deepcopy(binary_query)
  80. bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
  81. s = Search()
  82. s = s.query(bqry)[0: 6]
  83. s = s.to_dict()
  84. txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
  85. txt_ids = self.es.getDocIds(comm_res)
  86. if merge_into_first(txt_res, "-Original Content-"):
  87. txt_ids = comm_ids[0:1]
  88. return self.SearchResult(
  89. total=len(ent_ids) + len(comm_ids) + len(txt_ids),
  90. ids=[*ent_ids, *comm_ids, *txt_ids],
  91. query_vector=q_vec,
  92. aggregation=None,
  93. highlight=None,
  94. field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
  95. keywords=[]
  96. )