You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

index.py 5.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from concurrent.futures import ThreadPoolExecutor
  17. import json
  18. from functools import reduce
  19. from typing import List
  20. import networkx as nx
  21. from api.db import LLMType
  22. from api.db.services.llm_service import LLMBundle
  23. from api.db.services.user_service import TenantService
  24. from graphrag.community_reports_extractor import CommunityReportsExtractor
  25. from graphrag.entity_resolution import EntityResolution
  26. from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
  27. from graphrag.mind_map_extractor import MindMapExtractor
  28. from rag.nlp import rag_tokenizer
  29. from rag.utils import num_tokens_from_string
  30. def graph_merge(g1, g2):
  31. g = g2.copy()
  32. for n, attr in g1.nodes(data=True):
  33. if n not in g2.nodes():
  34. g.add_node(n, **attr)
  35. continue
  36. g.nodes[n]["weight"] += 1
  37. if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
  38. g.nodes[n]["description"] += "\n" + attr["description"]
  39. for source, target, attr in g1.edges(data=True):
  40. if g.has_edge(source, target):
  41. g[source][target].update({"weight": attr["weight"]+1})
  42. continue
  43. g.add_edge(source, target, **attr)
  44. for node_degree in g.degree:
  45. g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  46. return g
def build_knowledge_graph_chunks(tenant_id: str, chunks: List[str], callback, entity_types=DEFAULT_ENTITY_TYPES) -> List[dict]:
    """Run the GraphRAG pipeline over raw text chunks and return index-ready chunk dicts.

    Pipeline stages: batched, threaded entity/relation extraction; merging the
    per-batch graphs; entity resolution; one chunk per surviving entity;
    community report chunks; a chunk holding the serialized graph itself; and
    (if non-empty) a mind-map chunk built from the original texts.

    :param tenant_id: tenant whose configured chat LLM drives every extraction step.
    :param chunks: raw text chunks to mine for entities and relations.
    :param callback: progress hook, called as ``callback(progress_float, message_str)``.
    :param entity_types: entity type labels forwarded to the graph extractor.
    :return: list of chunk dicts tagged via ``knowledge_graph_kwd`` as
        "entity", "community_report", "graph", or "mind_map".
    """
    _, tenant = TenantService.get_by_id(tenant_id)
    llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
    ext = GraphExtractor(llm_bdl)

    # Token budget for chunk text once the extractor prompt (+1024 headroom)
    # is subtracted, floored at 60% of the model context.
    left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
    left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
    # NOTE(review): after the max() above this can only fail when
    # max_length <= 0, so the prompt-too-large case it describes is
    # effectively unreachable — confirm intent.
    assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"

    BATCH_SIZE=4
    texts, graphs = [], []
    cnt = 0
    threads = []
    # NOTE(review): the executor is never shut down explicitly; futures are
    # drained via .result() below.
    exe = ThreadPoolExecutor(max_workers=50)

    # Accumulate chunks until the token budget is reached, then submit the
    # accumulated texts to the extractor in joined batches of BATCH_SIZE.
    for i in range(len(chunks)):
        tkn_cnt = num_tokens_from_string(chunks[i])
        if cnt+tkn_cnt >= left_token_count and texts:
            for b in range(0, len(texts), BATCH_SIZE):
                threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
            texts = []
            cnt = 0
        texts.append(chunks[i])
        cnt += tkn_cnt
    # Flush the final partial batch.
    if texts:
        for b in range(0, len(texts), BATCH_SIZE):
            threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))

    callback(0.5, "Extracting entities.")
    # Collect one extracted graph per future, reporting progress 0.5 -> 0.6.
    graphs = []
    for i, _ in enumerate(threads):
        graphs.append(_.result().output)
        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")

    # Merge all per-batch graphs, then resolve duplicate entities via the LLM.
    graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
    er = EntityResolution(llm_bdl)
    graph = er(graph).output

    # Keep the original texts for the mind map; reuse `chunks` for output.
    _chunks = chunks
    chunks = []

    # One chunk per entity; rank 0 (degree-0 nodes per graph_merge) is skipped.
    for n, attr in graph.nodes(data=True):
        if attr.get("rank", 0) == 0:
            print(f"Ignore entity: {n}")
            continue
        chunk = {
            "name_kwd": n,
            "important_kwd": [n],
            "title_tks": rag_tokenizer.tokenize(n),
            "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(attr["description"]),
            "knowledge_graph_kwd": "entity",
            "rank_int": attr["rank"],
            "weight_int": attr["weight"]
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    callback(0.6, "Extracting community reports.")
    # One chunk per detected community report.
    cr = CommunityReportsExtractor(llm_bdl)
    cr = cr(graph, callback=callback)
    for community, desc in zip(cr.structured_output, cr.output):
        chunk = {
            "title_tks": rag_tokenizer.tokenize(community["title"]),
            "content_with_weight": desc,
            "content_ltks": rag_tokenizer.tokenize(desc),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": community["weight"],
            "entities_kwd": community["entities"],
            "important_kwd": community["entities"]
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
        chunks.append(chunk)

    # The whole resolved graph, serialized in node-link JSON form.
    chunks.append(
        {
            "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
            "knowledge_graph_kwd": "graph"
        })

    callback(0.75, "Extracting mind graph.")
    # Mind map is built from the ORIGINAL chunks, not the generated ones.
    mindmap = MindMapExtractor(llm_bdl)
    mg = mindmap(_chunks).output
    if not len(mg.keys()): return chunks

    print(json.dumps(mg, ensure_ascii=False, indent=2))
    chunks.append(
        {
            "content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
            "knowledge_graph_kwd": "mind_map"
        })
    return chunks