
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from typing import List

import networkx as nx

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer
from rag.utils import num_tokens_from_string


def be_children(obj: dict):
    """Recursively convert a nested dict into the mind-map children format."""
    arr = []
    for k, v in obj.items():
        k = re.sub(r"\*+", "", k)
        if not k:
            continue
        arr.append({
            "id": k,
            "children": be_children(v) if isinstance(v, dict) else []
        })
    return arr


def graph_merge(g1, g2):
    """Merge g1 into a copy of g2, accumulating node and edge weights."""
    g = g2.copy()
    for n, attr in g1.nodes(data=True):
        if n not in g2.nodes():
            # The node only exists in g1, so copy it into the merged graph.
            g.add_node(n, **attr)
            continue
        g.nodes[n]["weight"] += 1
        if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
            g.nodes[n]["description"] += "\n" + attr["description"]

    for source, target, attr in g1.edges(data=True):
        if g.has_edge(source, target):
            g[source][target].update({"weight": attr["weight"] + 1})
            continue
        g.add_edge(source, target, **attr)

    # Use each node's degree as its rank.
    for node_degree in g.degree:
        g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    return g


def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback,
                                entity_types=["organization", "person", "location", "event", "time"]):
    llm_bdl = LLMBundle(tenant_id, LLMType.CHAT)
    ext = GraphExtractor(llm_bdl)
    # Token budget per batch: whatever is left of the context window after the
    # extraction prompt, capped at 40% of the model's maximum length.
    left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
    left_token_count = min(left_token_count, llm_bdl.max_length * 0.4)
    assert left_token_count > 0, \
        f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
    texts = []
    cnt = 0
    threads = []
    exe = ThreadPoolExecutor(max_workers=12)
    # Pack chunks (at most the first 512) into batches that fit the token budget,
    # then submit each batch to the entity/relation extractor in parallel.
    for i in range(len(chunks[:512])):
        tkn_cnt = num_tokens_from_string(chunks[i])
        if cnt + tkn_cnt >= left_token_count and texts:
            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
            texts = []
            cnt = 0
        texts.append(chunks[i])
        cnt += tkn_cnt
    if texts:
        threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))

    callback(0.5, "Extracting entities.")
    graphs = []
    for i, thr in enumerate(threads):
        graphs.append(thr.result().output)
        callback(0.5 + 0.1 * i / len(threads))

    # Merge the per-batch graphs into one, then resolve duplicate entities.
    graph = reduce(graph_merge, graphs)
    er = EntityResolution(llm_bdl)
    graph = er(graph).output

    # Keep the original text chunks for the mind map; reuse `chunks` for the output.
    _chunks = chunks
    chunks = []
    for n, attr in graph.nodes(data=True):
        if attr.get("rank", 0) == 0:
            print(f"Ignore entity: {n}")
            continue
        chunk = {
            "name_kwd": n,
            "important_kwd": [n],
            "title_tks": rag_tokenizer.tokenize(n),
            "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(attr["description"]),
            "knowledge_graph_kwd": "entity",
            "rank_int": attr["rank"],
            "weight_int": attr["weight"]
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    callback(0.6, "Extracting community reports.")
    cr = CommunityReportsExtractor(llm_bdl)
    cr = cr(graph)
    for community, desc in zip(cr.structured_output, cr.output):
        chunk = {
            "title_tks": rag_tokenizer.tokenize(community["title"]),
            "content_with_weight": desc,
            "content_ltks": rag_tokenizer.tokenize(desc),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": community["weight"],
            "entities_kwd": community["entities"],
            "important_kwd": community["entities"]
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    chunks.append(
        {
            "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
            "knowledge_graph_kwd": "graph"
        })

    callback(0.75, "Extracting mind graph.")
    mindmap = MindMapExtractor(llm_bdl)
    mg = mindmap(_chunks).output
    if not len(mg.keys()):
        return chunks
    if len(mg.keys()) > 1:
        md_map = {"id": "root",
                  "children": [{"id": re.sub(r"\*+", "", k), "children": be_children(v)}
                               for k, v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)]}
    else:
        # With a single top-level branch, use it as the root of the mind map.
        k, v = list(mg.items())[0]
        md_map = {"id": re.sub(r"\*+", "", k), "children": be_children(v) if isinstance(v, dict) else []}
    print(json.dumps(md_map, ensure_ascii=False, indent=2))
    chunks.append(
        {
            "content_with_weight": json.dumps(md_map, ensure_ascii=False, indent=2),
            "knowledge_graph_kwd": "mind_map"
        })
    return chunks
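

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical; not part of the pipeline). It assumes a
# tenant with a configured chat model already exists in the database; the
# tenant id and sample chunks below are placeholders to replace with real
# values before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def progress(prog=None, msg=""):
        # Stand-in for the task-queue progress callback used in production.
        print(f"[{prog}] {msg}")

    sample_chunks = [
        "Alice Chen joined the Acme Corporation in Shanghai in 2021.",
        "Acme Corporation acquired Beta Labs after its 2023 product launch.",
    ]
    kg_chunks = build_knowlege_graph_chunks("YOUR_TENANT_ID", sample_chunks, progress)
    print(f"Produced {len(kg_chunks)} knowledge-graph chunks.")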