Cannot add more than 25 topics. A topic must start with a letter or number, may contain hyphens ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. from concurrent.futures import ThreadPoolExecutor
  19. import json
  20. from functools import reduce
  21. import networkx as nx
  22. from api.db import LLMType
  23. from api.db.services.llm_service import LLMBundle
  24. from api.db.services.user_service import TenantService
  25. from graphrag.community_reports_extractor import CommunityReportsExtractor
  26. from graphrag.entity_resolution import EntityResolution
  27. from graphrag.graph_extractor import GraphExtractor, DEFAULT_ENTITY_TYPES
  28. from graphrag.mind_map_extractor import MindMapExtractor
  29. from rag.nlp import rag_tokenizer
  30. from rag.utils import num_tokens_from_string
  31. def graph_merge(g1, g2):
  32. g = g2.copy()
  33. for n, attr in g1.nodes(data=True):
  34. if n not in g2.nodes():
  35. g.add_node(n, **attr)
  36. continue
  37. g.nodes[n]["weight"] += 1
  38. if g.nodes[n]["description"].lower().find(attr["description"][:32].lower()) < 0:
  39. g.nodes[n]["description"] += "\n" + attr["description"]
  40. for source, target, attr in g1.edges(data=True):
  41. if g.has_edge(source, target):
  42. g[source][target].update({"weight": attr["weight"]+1})
  43. continue
  44. g.add_edge(source, target, **attr)
  45. for node_degree in g.degree:
  46. g.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  47. return g
  48. def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, entity_types=DEFAULT_ENTITY_TYPES):
  49. _, tenant = TenantService.get_by_id(tenant_id)
  50. llm_bdl = LLMBundle(tenant_id, LLMType.CHAT, tenant.llm_id)
  51. ext = GraphExtractor(llm_bdl)
  52. left_token_count = llm_bdl.max_length - ext.prompt_token_count - 1024
  53. left_token_count = max(llm_bdl.max_length * 0.6, left_token_count)
  54. assert left_token_count > 0, f"The LLM context length({llm_bdl.max_length}) is smaller than prompt({ext.prompt_token_count})"
  55. BATCH_SIZE=4
  56. texts, graphs = [], []
  57. cnt = 0
  58. max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
  59. with ThreadPoolExecutor(max_workers=max_workers) as exe:
  60. threads = []
  61. for i in range(len(chunks)):
  62. tkn_cnt = num_tokens_from_string(chunks[i])
  63. if cnt+tkn_cnt >= left_token_count and texts:
  64. for b in range(0, len(texts), BATCH_SIZE):
  65. threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
  66. texts = []
  67. cnt = 0
  68. texts.append(chunks[i])
  69. cnt += tkn_cnt
  70. if texts:
  71. for b in range(0, len(texts), BATCH_SIZE):
  72. threads.append(exe.submit(ext, ["\n".join(texts[b:b+BATCH_SIZE])], {"entity_types": entity_types}, callback))
  73. callback(0.5, "Extracting entities.")
  74. graphs = []
  75. for i, _ in enumerate(threads):
  76. graphs.append(_.result().output)
  77. callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")
  78. graph = reduce(graph_merge, graphs) if graphs else nx.Graph()
  79. er = EntityResolution(llm_bdl)
  80. graph = er(graph).output
  81. _chunks = chunks
  82. chunks = []
  83. for n, attr in graph.nodes(data=True):
  84. if attr.get("rank", 0) == 0:
  85. logging.debug(f"Ignore entity: {n}")
  86. continue
  87. chunk = {
  88. "name_kwd": n,
  89. "important_kwd": [n],
  90. "title_tks": rag_tokenizer.tokenize(n),
  91. "content_with_weight": json.dumps({"name": n, **attr}, ensure_ascii=False),
  92. "content_ltks": rag_tokenizer.tokenize(attr["description"]),
  93. "knowledge_graph_kwd": "entity",
  94. "rank_int": attr["rank"],
  95. "weight_int": attr["weight"]
  96. }
  97. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  98. chunks.append(chunk)
  99. callback(0.6, "Extracting community reports.")
  100. cr = CommunityReportsExtractor(llm_bdl)
  101. cr = cr(graph, callback=callback)
  102. for community, desc in zip(cr.structured_output, cr.output):
  103. chunk = {
  104. "title_tks": rag_tokenizer.tokenize(community["title"]),
  105. "content_with_weight": desc,
  106. "content_ltks": rag_tokenizer.tokenize(desc),
  107. "knowledge_graph_kwd": "community_report",
  108. "weight_flt": community["weight"],
  109. "entities_kwd": community["entities"],
  110. "important_kwd": community["entities"]
  111. }
  112. chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
  113. chunks.append(chunk)
  114. chunks.append(
  115. {
  116. "content_with_weight": json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2),
  117. "knowledge_graph_kwd": "graph"
  118. })
  119. callback(0.75, "Extracting mind graph.")
  120. mindmap = MindMapExtractor(llm_bdl)
  121. mg = mindmap(_chunks).output
  122. if not len(mg.keys()):
  123. return chunks
  124. logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
  125. chunks.append(
  126. {
  127. "content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),
  128. "knowledge_graph_kwd": "mind_map"
  129. })
  130. return chunks