
index.py 9.5KB

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
from functools import reduce, partial

import networkx as nx
import trio

from api import settings
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
    chunk_id, update_nodes_pagerank_nhop_neighbour
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock
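

# The three classes below implement successive stages of GraphRAG indexing:
# `Dealer` extracts entities and relations from chunks and merges them into
# the knowledge base's persisted graph, `WithResolution` deduplicates entities
# in that graph, and `WithCommunity` detects communities and indexes their
# reports as searchable chunks.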
class Dealer:
    def __init__(self,
                 extractor: Extractor,
                 tenant_id: str,
                 kb_id: str,
                 llm_bdl,
                 chunks: list[tuple[str, str]],
                 language,
                 entity_types=DEFAULT_ENTITY_TYPES,
                 embed_bdl=None,
                 callback=None
                 ):
        self.tenant_id = tenant_id
        self.kb_id = kb_id
        self.chunks = chunks
        self.llm_bdl = llm_bdl
        self.embed_bdl = embed_bdl
        self.ext = extractor(self.llm_bdl, language=language,
                             entity_types=entity_types,
                             get_entity=partial(get_entity, tenant_id, kb_id),
                             set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
                             get_relation=partial(get_relation, tenant_id, kb_id),
                             set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
                             )
        self.graph = nx.Graph()
        self.callback = callback
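
    # Extract entities/relations from `self.chunks`, assemble them into an
    # in-memory networkx graph, then merge that graph with any previously
    # stored graph for this KB under a per-KB Redis lock (1h TTL).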
    async def __call__(self):
        docids = list(set([docid for docid, _ in self.chunks]))
        ents, rels = await self.ext(self.chunks, self.callback)
        for en in ents:
            self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])  # , description=en["description"]
        for rel in rels:
            self.graph.add_edge(
                rel["src_id"],
                rel["tgt_id"],
                weight=rel["weight"],
                # description=rel["description"]
            )
        with RedisDistributedLock(self.kb_id, 60 * 60):
            old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
            if old_graph is not None:
                logging.info("Merging with an existing graph...")
                self.graph = reduce(graph_merge, [old_graph, self.graph])
            update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
            if old_doc_ids:
                docids.extend(old_doc_ids)
                docids = list(set(docids))
            set_graph(self.tenant_id, self.kb_id, self.graph, docids)
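

# Entity-resolution pass: loads the stored graph, merges duplicate entities
# via `EntityResolution`, recomputes PageRank over the affected neighborhood,
# persists the resolved graph, and deletes index entries that referenced the
# removed entities.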
class WithResolution(Dealer):
    def __init__(self,
                 tenant_id: str,
                 kb_id: str,
                 llm_bdl,
                 embed_bdl=None,
                 callback=None
                 ):
        self.tenant_id = tenant_id
        self.kb_id = kb_id
        self.llm_bdl = llm_bdl
        self.embed_bdl = embed_bdl
        self.callback = callback
    async def __call__(self):
        with RedisDistributedLock(self.kb_id, 60 * 60):
            self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
            if not self.graph:
                logging.error(f"Failed to fetch the graph. tenant_id:{self.tenant_id}, kb_id:{self.kb_id}")
                if self.callback:
                    self.callback(-1, msg="Failed to fetch the graph.")
                return
            if self.callback:
                self.callback(msg="Fetched the existing graph.")
            er = EntityResolution(self.llm_bdl,
                                  get_entity=partial(get_entity, self.tenant_id, self.kb_id),
                                  set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
                                  get_relation=partial(get_relation, self.tenant_id, self.kb_id),
                                  set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
            reso = await er(self.graph)
            self.graph = reso.graph
            logging.info("Graph resolution is done. Removed {} nodes.".format(len(reso.removed_entities)))
            if self.callback:
                self.callback(msg="Graph resolution is done. Removed {} nodes.".format(len(reso.removed_entities)))
            await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

            # Purge relations pointing from/to the removed entities, then the
            # entities themselves, from the document store index.
            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                "knowledge_graph_kwd": "relation",
                "kb_id": self.kb_id,
                "from_entity_kwd": reso.removed_entities
            }, search.index_name(self.tenant_id), self.kb_id))
            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                "knowledge_graph_kwd": "relation",
                "kb_id": self.kb_id,
                "to_entity_kwd": reso.removed_entities
            }, search.index_name(self.tenant_id), self.kb_id))
            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                "knowledge_graph_kwd": "entity",
                "kb_id": self.kb_id,
                "entity_kwd": reso.removed_entities
            }, search.index_name(self.tenant_id), self.kb_id))
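

# Community pass: loads the stored graph, runs `CommunityReportsExtractor`
# to detect communities and generate a report for each, then replaces the
# previously indexed community reports with the new ones.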
class WithCommunity(Dealer):
    def __init__(self,
                 tenant_id: str,
                 kb_id: str,
                 llm_bdl,
                 embed_bdl=None,
                 callback=None
                 ):
        self.tenant_id = tenant_id
        self.kb_id = kb_id
        self.community_structure = None
        self.community_reports = None
        self.llm_bdl = llm_bdl
        self.embed_bdl = embed_bdl
        self.callback = callback
    async def __call__(self):
        with RedisDistributedLock(self.kb_id, 60 * 60):
            self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
            if not self.graph:
                logging.error(f"Failed to fetch the graph. tenant_id:{self.tenant_id}, kb_id:{self.kb_id}")
                if self.callback:
                    self.callback(-1, msg="Failed to fetch the graph.")
                return
            if self.callback:
                self.callback(msg="Fetched the existing graph.")
            cr = CommunityReportsExtractor(self.llm_bdl,
                                           get_entity=partial(get_entity, self.tenant_id, self.kb_id),
                                           set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
                                           get_relation=partial(get_relation, self.tenant_id, self.kb_id),
                                           set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
            cr = await cr(self.graph, callback=self.callback)
            self.community_structure = cr.structured_output
            self.community_reports = cr.output
            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

            if self.callback:
                self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
            # Drop previously indexed community reports before re-indexing.
            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
                "knowledge_graph_kwd": "community_report",
                "kb_id": self.kb_id
            }, search.index_name(self.tenant_id), self.kb_id))
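
            # Each community becomes one chunk: the report text plus the
            # finding explanations as evidence, tokenized for full-text search.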
            for stru, rep in zip(self.community_structure, self.community_reports):
                obj = {
                    "report": rep,
                    "evidences": "\n".join([f["explanation"] for f in stru["findings"]])
                }
                chunk = {
                    "docnm_kwd": stru["title"],
                    "title_tks": rag_tokenizer.tokenize(stru["title"]),
                    "content_with_weight": json.dumps(obj, ensure_ascii=False),
                    "content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]),
                    "knowledge_graph_kwd": "community_report",
                    "weight_flt": stru["weight"],
                    "entities_kwd": stru["entities"],
                    "important_kwd": stru["entities"],
                    "kb_id": self.kb_id,
                    "source_id": doc_ids,
                    "available_int": 0
                }
                chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
                # try:
                #     ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
                #     chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
                # except Exception as e:
                #     logging.exception(f"Fail to embed entity relation: {e}")
                await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
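

# --- Usage sketch ---
# A minimal illustration of how these stages chain together, not part of the
# original file. The `GraphExtractor` import and the `llm_bdl`/`embed_bdl`
# model bundles are assumptions here; in practice RAGFlow's task executor
# wires these up.
#
#     from graphrag.general.graph_extractor import GraphExtractor
#
#     async def build_kb_graph(tenant_id, kb_id, llm_bdl, embed_bdl, chunks):
#         # 1. Extract entities/relations and persist the merged graph.
#         await Dealer(GraphExtractor, tenant_id, kb_id, llm_bdl,
#                      chunks=chunks, language="English", embed_bdl=embed_bdl)()
#         # 2. Optionally deduplicate entities in the stored graph.
#         await WithResolution(tenant_id, kb_id, llm_bdl, embed_bdl)()
#         # 3. Optionally generate and index community reports.
#         await WithCommunity(tenant_id, kb_id, llm_bdl, embed_bdl)()
#
#     trio.run(build_kb_graph, tenant_id, kb_id, llm_bdl, embed_bdl, chunks)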