
community_reports_extractor.py

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import json
import os
import re
from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd
from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from rag.utils import num_tokens_from_string
import trio


@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    output: list[str]
    structured_output: list[dict]


class CommunityReportsExtractor(Extractor):
    """Community reports extractor class definition."""

    _extraction_prompt: str
    _output_formatter_prompt: str
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        max_report_length: int | None = None,
    ):
        """Init method definition."""
        super().__init__(llm_invoker)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        self._max_report_length = max_report_length or 1500

    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
        # Use each node's degree as its rank inside the community report.
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

        # Detect hierarchical communities with the Leiden algorithm.
        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        total = sum([len(comm.items()) for _, comm in communities.items()])
        res_str = []
        res_dict = []
        over, token_count = 0, 0

        @timeout(120)
        async def extract_community_report(community):
            nonlocal res_str, res_dict, over, token_count
            cm_id, cm = community
            weight = cm["weight"]
            ents = cm["nodes"]
            # Skip trivial communities with fewer than two entities.
            if len(ents) < 2:
                return
            ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
            ent_df = pd.DataFrame(ent_list)

            # Collect intra-community edges (capped at 10000) as the relationship table.
            rela_list = []
            k = 0
            for i in range(0, len(ents)):
                if k >= 10000:
                    break
                for j in range(i + 1, len(ents)):
                    if k >= 10000:
                        break
                    edge = graph.get_edge_data(ents[i], ents[j])
                    if edge is None:
                        continue
                    rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
                    k += 1
            rela_df = pd.DataFrame(rela_list)

            # Render the entity and relationship tables as CSV for the report prompt.
            prompt_variables = {
                "entity_df": ent_df.to_csv(index_label="id"),
                "relation_df": rela_df.to_csv(index_label="id")
            }
            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)

            async with chat_limiter:
                try:
                    # Apply a hard timeout to the LLM call only when ENABLE_TIMEOUT_ASSERTION is set.
                    with trio.move_on_after(180 if enable_timeout_assertion else 1000000000) as cancel_scope:
                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
                    if cancel_scope.cancelled_caught:
                        logging.warning("extract_community_report._chat timeout, skipping...")
                        return
                except Exception as e:
                    logging.error(f"extract_community_report._chat failed: {e}")
                    return

            token_count += num_tokens_from_string(text + response)
            # Trim everything outside the outermost JSON braces and collapse doubled braces.
            response = re.sub(r"^[^\{]*", "", response)
            response = re.sub(r"[^\}]*$", "", response)
            response = re.sub(r"\{\{", "{", response)
            response = re.sub(r"\}\}", "}", response)
            logging.debug(response)
            try:
                response = json.loads(response)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse JSON response: {e}")
                logging.error(f"Response content: {response}")
                return
            # Drop responses that do not match the expected report schema.
            if not dict_has_keys_with_types(response, [
                        ("title", str),
                        ("summary", str),
                        ("findings", list),
                        ("rating", float),
                        ("rating_explanation", str),
                    ]):
                return
            response["weight"] = weight
            response["entities"] = ents
            add_community_info2graph(graph, ents, response["title"])
            res_str.append(self._get_text_output(response))
            res_dict.append(response)
            over += 1
            if callback:
                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

        # Generate all community reports concurrently inside a Trio nursery.
        st = trio.current_time()
        async with trio.open_nursery() as nursery:
            for level, comm in communities.items():
                logging.info(f"Level {level}: Community: {len(comm.keys())}")
                for community in comm.items():
                    nursery.start_soon(extract_community_report, community)
        if callback:
            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _get_text_output(self, parsed_output: dict) -> str:
        # Render the structured report dict as a markdown document.
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: dict):
            if isinstance(finding, str):
                return finding
            return finding.get("summary")

        def finding_explanation(finding: dict):
            if isinstance(finding, str):
                return ""
            return finding.get("explanation")

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"
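
For context, here is a minimal usage sketch, not part of the file above: it assumes the class is importable as `graphrag.general.community_reports_extractor`, that `llm` stands in for an already-configured chat-model instance (a placeholder, supplied by the caller), and that `leiden.run` accepts a small weighted demo graph whose nodes and edges carry the `description` attributes this extractor reads. Because `__call__` is an async Trio function, it is driven with `trio.run`.

# Hypothetical usage sketch -- not part of community_reports_extractor.py.
# `llm` is a placeholder for a configured chat-model instance; the import
# path below is assumed from the package layout shown in the file's imports.
import networkx as nx
import trio

from graphrag.general.community_reports_extractor import CommunityReportsExtractor


def build_demo_graph() -> nx.Graph:
    # The extractor reads a "description" attribute from every node and edge;
    # an edge weight is assumed here for the Leiden community detection step.
    g = nx.Graph()
    g.add_node("Acme Corp", description="A fictional manufacturing company.")
    g.add_node("Jane Doe", description="CEO of Acme Corp.")
    g.add_edge("Acme Corp", "Jane Doe", description="Jane Doe leads Acme Corp.", weight=1.0)
    return g


async def main(llm):
    extractor = CommunityReportsExtractor(llm_invoker=llm, max_report_length=1500)
    result = await extractor(build_demo_graph(), callback=lambda msg: print(msg))
    # Each structured report passed the schema check, so "title" and "rating" exist.
    for report in result.structured_output:
        print(report["title"], report["rating"])


# trio.run(main, llm)  # `llm` must be provided by the caller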