
search.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
from collections import defaultdict
from copy import deepcopy

import json_repair
import pandas as pd

from api.utils import get_uuid
from graphrag.query_analyze_prompt import PROMPTS
from graphrag.utils import get_entity_type2sampels, get_llm_cache, set_llm_cache, get_relation
from rag.utils import num_tokens_from_string
from rag.utils.doc_store_conn import OrderByExpr
from rag.nlp.search import Dealer, index_name
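

# KGSearch extends Dealer with knowledge-graph retrieval: it rewrites the question into
# answer-type keywords and entities, pulls matching entities, relations, n-hop paths and
# community reports from the document store, and packs them into one chunk-like result.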
class KGSearch(Dealer):
    def _chat(self, llm_bdl, system, history, gen_conf):
        # Ask the chat model, serving the response from the LLM cache when one exists.
        response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
        if response:
            return response
        response = llm_bdl.chat(system, history, gen_conf)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
        return response

    def query_rewrite(self, llm, question, idxnms, kb_ids):
        # Rewrite the question into answer-type keywords and query entities with the LLM.
        ty2ents = get_entity_type2sampels(idxnms, kb_ids)
        hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
                                                          TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
        result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {"temperature": .5})
        try:
            keywords_data = json_repair.loads(result)
            type_keywords = keywords_data.get("answer_type_keywords", [])
            entities_from_query = keywords_data.get("entities_from_query", [])[:5]
            return type_keywords, entities_from_query
        except json_repair.JSONDecodeError:
            try:
                result = result.replace(hint_prompt[:-1], '').replace('user', '').replace('model', '').strip()
                result = '{' + result.split('{')[1].split('}')[0] + '}'
                keywords_data = json_repair.loads(result)
                type_keywords = keywords_data.get("answer_type_keywords", [])
                entities_from_query = keywords_data.get("entities_from_query", [])[:5]
                return type_keywords, entities_from_query
            # Handle parsing error
            except Exception as e:
                logging.exception(f"JSON parsing error: {result} -> {e}")
                raise e

    def _ent_info_from_(self, es_res, sim_thr=0.3):
        res = {}
        es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "entity_kwd", "rank_flt",
                                                   "n_hop_with_weight"])
        for _, ent in es_res.items():
            if float(ent.get("_score", 0)) < sim_thr:
                continue
            if isinstance(ent["entity_kwd"], list):
                ent["entity_kwd"] = ent["entity_kwd"][0]
            res[ent["entity_kwd"]] = {
                "sim": float(ent.get("_score", 0)),
                "pagerank": float(ent.get("rank_flt", 0)),
                "n_hop_ents": json.loads(ent.get("n_hop_with_weight", "[]")),
                "description": ent.get("content_with_weight", "{}")
            }
        return res

    def _relation_info_from_(self, es_res, sim_thr=0.3):
        res = {}
        es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
                                                   "weight_int"])
        for _, ent in es_res.items():
            if float(ent["_score"]) < sim_thr:
                continue
            f, t = sorted([ent["from_entity_kwd"], ent["to_entity_kwd"]])
            if isinstance(f, list):
                f = f[0]
            if isinstance(t, list):
                t = t[0]
            res[(f, t)] = {
                "sim": float(ent["_score"]),
                "pagerank": float(ent.get("weight_int", 0)),
                "description": ent["content_with_weight"]
            }
        return res

    def get_relevant_ents_by_keywords(self, keywords, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
        if not keywords:
            return {}
        filters = deepcopy(filters)
        filters["knowledge_graph_kwd"] = "entity"
        matchDense = self.get_vector(", ".join(keywords), emb_mdl, 1024, sim_thr)
        es_res = self.dataStore.search(["content_with_weight", "entity_kwd", "rank_flt"], [], filters, [matchDense],
                                       OrderByExpr(), 0, N,
                                       idxnms, kb_ids)
        return self._ent_info_from_(es_res, sim_thr)

    def get_relevant_relations_by_txt(self, txt, filters, idxnms, kb_ids, emb_mdl, sim_thr=0.3, N=56):
        if not txt:
            return {}
        filters = deepcopy(filters)
        filters["knowledge_graph_kwd"] = "relation"
        matchDense = self.get_vector(txt, emb_mdl, 1024, sim_thr)
        es_res = self.dataStore.search(
            ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd", "weight_int"],
            [], filters, [matchDense], OrderByExpr(), 0, N, idxnms, kb_ids)
        return self._relation_info_from_(es_res, sim_thr)

    def get_relevant_ents_by_types(self, types, filters, idxnms, kb_ids, N=56):
        if not types:
            return {}
        filters = deepcopy(filters)
        filters["knowledge_graph_kwd"] = "entity"
        filters["entity_type_kwd"] = types
        ordr = OrderByExpr()
        ordr.desc("rank_flt")
        es_res = self.dataStore.search(["entity_kwd", "rank_flt"], [], filters, [], ordr, 0, N,
                                       idxnms, kb_ids)
        return self._ent_info_from_(es_res, 0)

    def retrieval(self, question: str,
                  tenant_ids: str | list[str],
                  kb_ids: list[str],
                  emb_mdl,
                  llm,
                  max_token: int = 8196,
                  ent_topn: int = 6,
                  rel_topn: int = 6,
                  comm_topn: int = 1,
                  ent_sim_threshold: float = 0.3,
                  rel_sim_threshold: float = 0.3,
                  ):
        # Retrieve relevant entities, relations, n-hop paths and community reports from the
        # knowledge graph and pack them into a single chunk-like dict.
        qst = question
        filters = self.get_filters({"kb_ids": kb_ids})
        if isinstance(tenant_ids, str):
            tenant_ids = tenant_ids.split(",")
        idxnms = [index_name(tid) for tid in tenant_ids]
        ty_kwds = []
        ents = []
        try:
            ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
            logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
        except Exception as e:
            logging.exception(e)
            ents = [qst]
            pass

        ents_from_query = self.get_relevant_ents_by_keywords(ents, filters, idxnms, kb_ids, emb_mdl, ent_sim_threshold)
        ents_from_types = self.get_relevant_ents_by_types(ty_kwds, filters, idxnms, kb_ids, 10000)
        rels_from_txt = self.get_relevant_relations_by_txt(qst, filters, idxnms, kb_ids, emb_mdl, rel_sim_threshold)

        nhop_pathes = defaultdict(dict)
        for _, ent in ents_from_query.items():
            nhops = ent.get("n_hop_ents", [])
            for nbr in nhops:
                path = nbr["path"]
                wts = nbr["weights"]
                for i in range(len(path) - 1):
                    f, t = path[i], path[i + 1]
                    if (f, t) in nhop_pathes:
                        nhop_pathes[(f, t)]["sim"] += ent["sim"] / (2 + i)
                    else:
                        nhop_pathes[(f, t)]["sim"] = ent["sim"] / (2 + i)
                        nhop_pathes[(f, t)]["pagerank"] = wts[i]

        logging.info("Retrieved entities: {}".format(list(ents_from_query.keys())))
        logging.info("Retrieved relations: {}".format(list(rels_from_txt.keys())))
        logging.info("Retrieved entities from types({}): {}".format(ty_kwds, list(ents_from_types.keys())))
        logging.info("Retrieved N-hops: {}".format(list(nhop_pathes.keys())))

        # P(E|Q) => P(E) * P(Q|E) => pagerank * sim
        for ent in ents_from_types.keys():
            if ent not in ents_from_query:
                continue
            ents_from_query[ent]["sim"] *= 2

        for (f, t) in rels_from_txt.keys():
            pair = tuple(sorted([f, t]))
            s = 0
            if pair in nhop_pathes:
                s += nhop_pathes[pair]["sim"]
                del nhop_pathes[pair]
            if f in ents_from_types:
                s += 1
            if t in ents_from_types:
                s += 1
            rels_from_txt[(f, t)]["sim"] *= s + 1

        # This is for the relations from n-hop but not by query search
        for (f, t) in nhop_pathes.keys():
            s = 0
            if f in ents_from_types:
                s += 1
            if t in ents_from_types:
                s += 1
            rels_from_txt[(f, t)] = {
                "sim": nhop_pathes[(f, t)]["sim"] * (s + 1),
                "pagerank": nhop_pathes[(f, t)]["pagerank"]
            }

        ents_from_query = sorted(ents_from_query.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
                          :ent_topn]
        rels_from_txt = sorted(rels_from_txt.items(), key=lambda x: x[1]["sim"] * x[1]["pagerank"], reverse=True)[
                        :rel_topn]

        ents = []
        relas = []
        for n, ent in ents_from_query:
            ents.append({
                "Entity": n,
                "Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
                "Description": json.loads(ent["description"]).get("description", "")
            })
            max_token -= num_tokens_from_string(str(ents[-1]))
            if max_token <= 0:
                ents = ents[:-1]
                break

        for (f, t), rel in rels_from_txt:
            if not rel.get("description"):
                for tid in tenant_ids:
                    rela = get_relation(tid, kb_ids, f, t)
                    if rela:
                        break
                else:
                    continue
                rel["description"] = rela["description"]
            description = rel["description"]
            try:
                # Relation descriptions may be stored as a JSON object or as plain text.
                description = json.loads(description).get("description", "")
            except Exception:
                pass
            relas.append({
                "From Entity": f,
                "To Entity": t,
                "Score": "%.2f" % (rel["sim"] * rel["pagerank"]),
                "Description": description
            })
            max_token -= num_tokens_from_string(str(relas[-1]))
            if max_token <= 0:
                relas = relas[:-1]
                break

        if ents:
            ents = "\n-Entities-\n{}".format(pd.DataFrame(ents).to_csv())
        else:
            ents = ""
        if relas:
            relas = "\n-Relations-\n{}".format(pd.DataFrame(relas).to_csv())
        else:
            relas = ""

        return {
            "chunk_id": get_uuid(),
            "content_ltks": "",
            "content_with_weight": ents + relas + self._community_retrival_([n for n, _ in ents_from_query], filters,
                                                                            kb_ids, idxnms, comm_topn, max_token),
            "doc_id": "",
            "docnm_kwd": "Related content in Knowledge Graph",
            "kb_id": kb_ids,
            "important_kwd": [],
            "image_id": "",
            "similarity": 1.,
            "vector_similarity": 1.,
            "term_similarity": 0,
            "vector": [],
            "positions": [],
        }

    def _community_retrival_(self, entities, condition, kb_ids, idxnms, topn, max_token):
        ## Community retrieval
        fields = ["docnm_kwd", "content_with_weight"]
        odr = OrderByExpr()
        odr.desc("weight_flt")
        fltr = deepcopy(condition)
        fltr["knowledge_graph_kwd"] = "community_report"
        fltr["entities_kwd"] = entities
        comm_res = self.dataStore.search(fields, [], fltr, [],
                                         odr, 0, topn, idxnms, kb_ids)
        comm_res_fields = self.dataStore.getFields(comm_res, fields)
        txts = []
        for ii, (_, row) in enumerate(comm_res_fields.items()):
            obj = json.loads(row["content_with_weight"])
            txts.append("# {}. {}\n## Content\n{}\n## Evidences\n{}\n".format(
                ii + 1, row["docnm_kwd"], obj["report"], obj["evidences"]))
            max_token -= num_tokens_from_string(str(txts[-1]))
            if max_token <= 0:
                break

        if not txts:
            return ""
        return "\n-Community Report-\n" + "\n".join(txts)
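

# A minimal programmatic sketch (hypothetical IDs; assumes the doc store connection and the
# LLM/embedding bundles are configured as in the CLI block below):
#
#   kg = KGSearch(settings.docStoreConn)
#   chunk = kg.retrieval("What is RAGFlow?", [tenant_id], [kb_id], embed_bdl, llm_bdl)
#   print(chunk["content_with_weight"])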


if __name__ == "__main__":
    from api import settings
    import argparse
    from api.db import LLMType
    from api.db.services.knowledgebase_service import KnowledgebaseService
    from api.db.services.llm_service import LLMBundle
    from api.db.services.user_service import TenantService

    settings.init_settings()
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
    parser.add_argument('-d', '--kb_id', default=False, help="Knowledge base ID", action='store', required=True)
    parser.add_argument('-q', '--question', default=False, help="Question", action='store', required=True)
    args = parser.parse_args()

    kb_id = args.kb_id
    _, tenant = TenantService.get_by_id(args.tenant_id)
    llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
    _, kb = KnowledgebaseService.get_by_id(kb_id)
    embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

    kg = KGSearch(settings.docStoreConn)
    print(kg.retrieval(args.question, [kb.tenant_id], [kb_id], embed_bdl, llm_bdl))
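
# Example CLI run (hypothetical IDs; assumes a configured RAGFlow deployment and an importable
# package root):
#   python -m graphrag.search -t <tenant_id> -d <kb_id> -q "What is the relation between A and B?"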