#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
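"""Knowledge-graph flavored retrieval.

KGSearch extends the generic Dealer searcher with three Elasticsearch passes
over a knowledge-graph-enriched index: entity chunks, community reports, and
the original text chunks, each selected via the `knowledge_graph_kwd` field.
Each pass's hits are collapsed into one consolidated chunk before all three
are combined into a single SearchResult.
"""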
import json
from copy import deepcopy

import pandas as pd
from elasticsearch_dsl import Q, Search

from rag.nlp.search import Dealer


class KGSearch(Dealer):
    def search(self, req, idxnm, emb_mdl=None):
        def merge_into_first(sres, title=""):
            """Collapse all hits into the first hit's content.

            Hits whose content is JSON-decodable are merged into a single
            CSV table via pandas; everything else is concatenated as plain
            text. Returns False when there are no hits to merge.
            """
            df, texts = [], []
            for d in sres["hits"]["hits"]:
                try:
                    df.append(json.loads(d["_source"]["content_with_weight"]))
                except Exception:
                    texts.append(d["_source"]["content_with_weight"])
            if not df and not texts:
                return False
            if df:
                try:
                    sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
                except Exception:
                    pass  # heterogeneous rows: leave the first hit unchanged
            else:
                sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
            return True

        # Stored fields returned with every hit; callers may narrow the list
        # via req["fields"].
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
                                 "weight_int", "weight_flt", "rank_int"
                                 ])

        qst = req.get("question", "")
        binary_query, keywords = self.qryr.question(qst, min_match="5%")
        binary_query = self._add_filters(binary_query, req)

        ## Entity retrieval: chunks flagged as knowledge-graph entities (top 32)
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
        s = Search()
        s = s.query(bqry)[0:32]
        s = s.to_dict()

        q_vec = []
        if req.get("vector"):
            assert emb_mdl, "No embedding model selected"
            # Hybrid retrieval: attach a kNN clause over the question
            # embedding, constrained by the same boolean filter.
            s["knn"] = self._vector(
                qst, emb_mdl, req.get("similarity", 0.1), 1024)
            s["knn"]["filter"] = bqry.to_dict()
            q_vec = s["knn"]["query_vector"]

        ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
        ent_ids = self.es.getDocIds(ent_res)
        # Collapse all entity hits into the first one and keep only its id.
        if merge_into_first(ent_res, "-Entities-"):
            ent_ids = ent_ids[0:1]

        ## Community retrieval: reports that mention the entities found above
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", entities_kwd=entities))
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
        s = Search()
        s = s.query(bqry)[0:32]
        s = s.to_dict()
        comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        comm_ids = self.es.getDocIds(comm_res)
        if merge_into_first(comm_res, "-Community Report-"):
            comm_ids = comm_ids[0:1]

        ## Text content retrieval: the original text chunks (top 6)
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
        s = Search()
        s = s.query(bqry)[0:6]
        s = s.to_dict()
        txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        txt_ids = self.es.getDocIds(txt_res)
        if merge_into_first(txt_res, "-Original Content-"):
            txt_ids = txt_ids[0:1]

        # One SearchResult covering all three passes; ids and fields are
        # concatenated in entity -> community -> text order.
        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            aggregation=None,
            highlight=None,
            field={**self.getFields(ent_res, src),
                   **self.getFields(comm_res, src),
                   **self.getFields(txt_res, src)},
            keywords=[]
        )
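

# A minimal usage sketch (hypothetical wiring; in RAGFlow the ES connection,
# query analyzer, and embedding model are injected elsewhere, and the exact
# constructor arguments follow Dealer's):
#
#     searcher = KGSearch(es_connection)           # es_connection: assumed
#     res = searcher.search(
#         {"question": "Who acquired company X?", "vector": True},
#         idxnm="ragflow_idx_tenant0",             # hypothetical index name
#         emb_mdl=embedding_model)                 # embedding_model: assumed
#     print(res.total, res.ids[:3])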