
community_reports_extractor.py

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import json
import re
import traceback
from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer


@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    output: list[str]
    structured_output: list[dict]


class CommunityReportsExtractor(Extractor):
    """Community reports extractor class definition."""

    _extraction_prompt: str
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
    _max_report_length: int
    def __init__(
        self,
        llm_invoker: CompletionLLM,
        get_entity: Callable | None = None,
        set_entity: Callable | None = None,
        get_relation: Callable | None = None,
        set_relation: Callable | None = None,
        max_report_length: int | None = None,
    ):
        """Init method definition."""
        super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        self._max_report_length = max_report_length or 1500

    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
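        # Group graph nodes into hierarchical communities with the Leiden algorithm;
        # the result maps each level to its communities and their member nodes/weights.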
        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        total = sum([len(comm.items()) for _, comm in communities.items()])
        res_str = []
        res_dict = []
        over, token_count = 0, 0
        st = timer()
        for level, comm in communities.items():
            logging.info(f"Level {level}: Community: {len(comm.keys())}")
            for cm_id, ents in comm.items():
                weight = ents["weight"]
                ents = ents["nodes"]
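                # Fetch stored entity and relation records for this community's members.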
                ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()  # [{"entity": n, **graph.nodes[n]} for n in ents])
                if ent_df.empty:
                    continue
                ent_df["entity"] = ent_df["entity_name"]
                del ent_df["entity_name"]
                rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
                if rela_df.empty:
                    continue
                rela_df["source"] = rela_df["src_id"]
                rela_df["target"] = rela_df["tgt_id"]
                del rela_df["src_id"]
                del rela_df["tgt_id"]

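                # Serialize both tables as CSV and substitute them into the community-report prompt.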
                prompt_variables = {
                    "entity_df": ent_df.to_csv(index_label="id"),
                    "relation_df": rela_df.to_csv(index_label="id")
                }
                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
                gen_conf = {"temperature": 0.3}
                try:
                    response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                    token_count += num_tokens_from_string(text + response)
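                    # Strip any text before/after the JSON object and collapse doubled braces
                    # before parsing the model output.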
                    response = re.sub(r"^[^\{]*", "", response)
                    response = re.sub(r"[^\}]*$", "", response)
                    response = re.sub(r"\{\{", "{", response)
                    response = re.sub(r"\}\}", "}", response)
                    logging.debug(response)
                    response = json.loads(response)
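                    # Discard reports that are missing required fields or carry wrong types.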
                    if not dict_has_keys_with_types(response, [
                        ("title", str),
                        ("summary", str),
                        ("findings", list),
                        ("rating", float),
                        ("rating_explanation", str),
                    ]):
                        continue
                    response["weight"] = weight
                    response["entities"] = ents
                except Exception as e:
                    logging.exception("CommunityReportsExtractor got exception")
                    self._on_error(e, traceback.format_exc(), None)
                    continue

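                # Record the community title on the graph and collect both text and structured reports.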
                add_community_info2graph(graph, ents, response["title"])
                res_str.append(self._get_text_output(response))
                res_dict.append(response)
                over += 1
                if callback:
                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as markdown: title, summary, then one section per finding."""
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: dict):
            if isinstance(finding, str):
                return finding
            return finding.get("summary")

        def finding_explanation(finding: dict):
            if isinstance(finding, str):
                return ""
            return finding.get("explanation")

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"
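
# Usage sketch (illustrative; `chat_model`, `get_entity_fn`, and `get_relation_fn` are
# hypothetical names): the extractor is constructed with an LLM invoker plus optional
# entity/relation accessors, then called on a knowledge graph.
#
#     extractor = CommunityReportsExtractor(
#         llm_invoker=chat_model,        # a rag.llm.chat_model.Base instance
#         get_entity=get_entity_fn,      # callables backed by your entity/relation store
#         get_relation=get_relation_fn,
#     )
#     result = extractor(graph, callback=lambda msg: logging.info(msg))
#     print(result.output[0])            # markdown report for the first community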