Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

index.py 8.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import logging
  18. from functools import reduce, partial
  19. import networkx as nx
  20. from api import settings
  21. from graphrag.general.community_reports_extractor import CommunityReportsExtractor
  22. from graphrag.entity_resolution import EntityResolution
  23. from graphrag.general.extractor import Extractor
  24. from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
  25. from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
  26. chunk_id, update_nodes_pagerank_nhop_neighbour
  27. from rag.nlp import rag_tokenizer, search
  28. from rag.utils.redis_conn import RedisDistributedLock
  29. class Dealer:
  30. def __init__(self,
  31. extractor: Extractor,
  32. tenant_id: str,
  33. kb_id: str,
  34. llm_bdl,
  35. chunks: list[tuple[str, str]],
  36. language,
  37. entity_types=DEFAULT_ENTITY_TYPES,
  38. embed_bdl=None,
  39. callback=None
  40. ):
  41. docids = list(set([docid for docid,_ in chunks]))
  42. self.llm_bdl = llm_bdl
  43. self.embed_bdl = embed_bdl
  44. ext = extractor(self.llm_bdl, language=language,
  45. entity_types=entity_types,
  46. get_entity=partial(get_entity, tenant_id, kb_id),
  47. set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
  48. get_relation=partial(get_relation, tenant_id, kb_id),
  49. set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
  50. )
  51. ents, rels = ext(chunks, callback)
  52. self.graph = nx.Graph()
  53. for en in ents:
  54. self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
  55. for rel in rels:
  56. self.graph.add_edge(
  57. rel["src_id"],
  58. rel["tgt_id"],
  59. weight=rel["weight"],
  60. #description=rel["description"]
  61. )
  62. with RedisDistributedLock(kb_id, 60*60):
  63. old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
  64. if old_graph is not None:
  65. logging.info("Merge with an exiting graph...................")
  66. self.graph = reduce(graph_merge, [old_graph, self.graph])
  67. update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
  68. if old_doc_ids:
  69. docids.extend(old_doc_ids)
  70. docids = list(set(docids))
  71. set_graph(tenant_id, kb_id, self.graph, docids)
  72. class WithResolution(Dealer):
  73. def __init__(self,
  74. tenant_id: str,
  75. kb_id: str,
  76. llm_bdl,
  77. embed_bdl=None,
  78. callback=None
  79. ):
  80. self.llm_bdl = llm_bdl
  81. self.embed_bdl = embed_bdl
  82. with RedisDistributedLock(kb_id, 60*60):
  83. self.graph, doc_ids = get_graph(tenant_id, kb_id)
  84. if not self.graph:
  85. logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
  86. if callback:
  87. callback(-1, msg="Faild to fetch the graph.")
  88. return
  89. if callback:
  90. callback(msg="Fetch the existing graph.")
  91. er = EntityResolution(self.llm_bdl,
  92. get_entity=partial(get_entity, tenant_id, kb_id),
  93. set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
  94. get_relation=partial(get_relation, tenant_id, kb_id),
  95. set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
  96. reso = er(self.graph)
  97. self.graph = reso.graph
  98. logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
  99. if callback:
  100. callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
  101. update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
  102. set_graph(tenant_id, kb_id, self.graph, doc_ids)
  103. settings.docStoreConn.delete({
  104. "knowledge_graph_kwd": "relation",
  105. "kb_id": kb_id,
  106. "from_entity_kwd": reso.removed_entities
  107. }, search.index_name(tenant_id), kb_id)
  108. settings.docStoreConn.delete({
  109. "knowledge_graph_kwd": "relation",
  110. "kb_id": kb_id,
  111. "to_entity_kwd": reso.removed_entities
  112. }, search.index_name(tenant_id), kb_id)
  113. settings.docStoreConn.delete({
  114. "knowledge_graph_kwd": "entity",
  115. "kb_id": kb_id,
  116. "entity_kwd": reso.removed_entities
  117. }, search.index_name(tenant_id), kb_id)
  118. class WithCommunity(Dealer):
  119. def __init__(self,
  120. tenant_id: str,
  121. kb_id: str,
  122. llm_bdl,
  123. embed_bdl=None,
  124. callback=None
  125. ):
  126. self.community_structure = None
  127. self.community_reports = None
  128. self.llm_bdl = llm_bdl
  129. self.embed_bdl = embed_bdl
  130. with RedisDistributedLock(kb_id, 60*60):
  131. self.graph, doc_ids = get_graph(tenant_id, kb_id)
  132. if not self.graph:
  133. logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
  134. if callback:
  135. callback(-1, msg="Faild to fetch the graph.")
  136. return
  137. if callback:
  138. callback(msg="Fetch the existing graph.")
  139. cr = CommunityReportsExtractor(self.llm_bdl,
  140. get_entity=partial(get_entity, tenant_id, kb_id),
  141. set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
  142. get_relation=partial(get_relation, tenant_id, kb_id),
  143. set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
  144. cr = cr(self.graph, callback=callback)
  145. self.community_structure = cr.structured_output
  146. self.community_reports = cr.output
  147. set_graph(tenant_id, kb_id, self.graph, doc_ids)
  148. if callback:
  149. callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
  150. settings.docStoreConn.delete({
  151. "knowledge_graph_kwd": "community_report",
  152. "kb_id": kb_id
  153. }, search.index_name(tenant_id), kb_id)
  154. for stru, rep in zip(self.community_structure, self.community_reports):
  155. obj = {
  156. "report": rep,
  157. "evidences": "\n".join([f["explanation"] for f in stru["findings"]])
  158. }
  159. chunk = {
  160. "docnm_kwd": stru["title"],
  161. "title_tks": rag_tokenizer.tokenize(stru["title"]),
  162. "content_with_weight": json.dumps(obj, ensure_ascii=False),
  163. "content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
  164. "knowledge_graph_kwd": "community_report",
  165. "weight_flt": stru["weight"],
  166. "entities_kwd": stru["entities"],
  167. "important_kwd": stru["entities"],
  168. "kb_id": kb_id,
  169. "source_id": doc_ids,
  170. "available_int": 0
  171. }
  172. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  173. #try:
  174. # ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
  175. # chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
  176. #except Exception as e:
  177. # logging.exception(f"Fail to embed entity relation: {e}")
  178. settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))