# index.py
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import re
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from typing import List

import networkx as nx

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from graphrag.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.graph_extractor import GraphExtractor
from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer
from rag.utils import num_tokens_from_string
  31. def be_children(obj: dict, keyset:set):
  32. if isinstance(obj, str):
  33. obj = [obj]
  34. if isinstance(obj, list):
  35. for i in obj: keyset.add(i)
  36. return [{"id": re.sub(r"\*+", "", i), "children":[]} for i in obj]
  37. arr = []
  38. for k,v in obj.items():
  39. k = re.sub(r"\*+", "", k)
  40. if not k or k in keyset:continue
  41. keyset.add(k)
  42. arr.append({
  43. "id": k,
  44. "children": be_children(v, keyset)
  45. })
  46. return arr
  47. def graph_merge(g1, g2):
  48. g = g2.copy()
  49. for n, attr in g1.nodes(data=True):
  50. if n not in g2.nodes():
  51. g.add_node(n, **attr)
  52. continue
  53. g.nodes[n]["weight"] += 1
  54. if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
  55. g.nodes[n]["description"] += "\n" + attr["description"]
  56. for source, target, attr in g1.edges(data=True):
  57. if g.has_edge(source, target):
  58. g[source][target].update({"weight": attr["weight"]+1})
  59. continue
  60. g.add_edge(source, target, **attr)
  61. for node_degree in g.degree:
  62. g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  63. return g
  64. def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=["organization", "person", "location", "event", "time"]):
  65. _, tenant = TenantService.get_by_id(tenant_id)
  66. llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
  67. ext = GraphExtractor(llm_bdl)
  68. left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
  69. left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
  70. assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
  71. BATCH_SIZE=1
  72. texts, graphs = [], []
  73. cnt = 0
  74. threads = []
  75. exe = ThreadPoolExecutor(max_workers=12)
  76. for i in range(len(chunks)):
  77. tkn_cnt = num_tokens_from_string(chunks[i])
  78. if cnt+tkn_cnt >= left_token_count and texts:
  79. for b in range(0, len(texts), BATCH_SIZE):
  80. threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
  81. texts = []
  82. cnt = 0
  83. texts.append(chunks[i])
  84. cnt += tkn_cnt
  85. if texts:
  86. for b in range(0, len(texts), BATCH_SIZE):
  87. threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
  88. callback(0.5, "Extracting entities.")
  89. graphs = []
  90. for i, _ in enumerate(threads):
  91. graphs.append(_.result().output)
  92. callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
  93. graph = reduce(graph_merge, graphs)
  94. er = EntityResolution(llm_bdl)
  95. graph = er(graph).output
  96. _chunks = chunks
  97. chunks = []
  98. for n, attr in graph.nodes(data=True):
  99. if attr.get("rank", 0) == 0:
  100. print(f"Ignore entity: {n}")
  101. continue
  102. chunk = {
  103. "name_kwd": n,
  104. "important_kwd": [n],
  105. "title_tks": rag_tokenizer.tokenize(n),
  106. "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
  107. "content_ltks": rag_tokenizer.tokenize(attr["description"]),
  108. "knowledge_graph_kwd": "entity",
  109. "rank_int": attr["rank"],
  110. "weight_int": attr["weight"]
  111. }
  112. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  113. chunks.append(chunk)
  114. callback(0.6, "Extracting community reports.")
  115. cr = CommunityReportsExtractor(llm_bdl)
  116. cr = cr(graph, callback=callback)
  117. for community, desc in zip(cr.structured_output, cr.output):
  118. chunk = {
  119. "title_tks": rag_tokenizer.tokenize(community["title"]),
  120. "content_with_weight": desc,
  121. "content_ltks": rag_tokenizer.tokenize(desc),
  122. "knowledge_graph_kwd": "community_report",
  123. "weight_flt": community["weight"],
  124. "entities_kwd": community["entities"],
  125. "important_kwd": community["entities"]
  126. }
  127. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  128. chunks.append(chunk)
  129. chunks.append(
  130. {
  131. "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
  132. "knowledge_graph_kwd": "graph"
  133. })
  134. callback(0.75, "Extracting mind graph.")
  135. mindmap = MindMapExtractor(llm_bdl)
  136. mg = mindmap(_chunks).output
  137. if not len(mg.keys()): return chunks
  138. if len(mg.keys()) > 1:
  139. keyset = set([re.sub(r"\*+", "", k) for k,v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)])
  140. md_map = {"id": "root", "children": [{"id": re.sub(r"\*+", "", k), "children": be_children(v, keyset)} for k,v in mg.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)]}
  141. else:
  142. k = re.sub(r"\*+", "", list(mg.keys())[0])
  143. md_map = {"id": k, "children": be_children(list(mg.items())[0][1], set([k]))}
  144. print(json.dumps(md_map, ensure_ascii=False, indent=2))
  145. chunks.append(
  146. {
  147. "content_with_weight": json.dumps(md_map, ensure_ascii=False, indent=2),
  148. "knowledge_graph_kwd": "mind_map"
  149. })
  150. return chunks