You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

community_reports_extractor.py 5.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. """
  17. Reference:
  18. - [graphrag](https://github.com/microsoft/graphrag)
  19. """
  20. import json
  21. import logging
  22. import re
  23. import traceback
  24. from dataclasses import dataclass
  25. from typing import Any, List
  26. import networkx as nx
  27. import pandas as pd
  28. from graphrag import leiden
  29. from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
  30. from graphrag.leiden import add_community_info2graph
  31. from rag.llm.chat_model import Base as CompletionLLM
  32. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
  33. log = logging.getLogger(__name__)
  34. @dataclass
  35. class CommunityReportsResult:
  36. """Community reports result class definition."""
  37. output: List[str]
  38. structured_output: List[dict]
  39. class CommunityReportsExtractor:
  40. """Community reports extractor class definition."""
  41. _llm: CompletionLLM
  42. _extraction_prompt: str
  43. _output_formatter_prompt: str
  44. _on_error: ErrorHandlerFn
  45. _max_report_length: int
  46. def __init__(
  47. self,
  48. llm_invoker: CompletionLLM,
  49. extraction_prompt: str | None = None,
  50. on_error: ErrorHandlerFn | None = None,
  51. max_report_length: int | None = None,
  52. ):
  53. """Init method definition."""
  54. self._llm = llm_invoker
  55. self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
  56. self._on_error = on_error or (lambda _e, _s, _d: None)
  57. self._max_report_length = max_report_length or 1500
  58. def __call__(self, graph: nx.Graph):
  59. communities: dict[str, dict[str, List]] = leiden.run(graph, {})
  60. relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
  61. res_str = []
  62. res_dict = []
  63. for level, comm in communities.items():
  64. for cm_id, ents in comm.items():
  65. weight = ents["weight"]
  66. ents = ents["nodes"]
  67. ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
  68. rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
  69. prompt_variables = {
  70. "entity_df": ent_df.to_csv(index_label="id"),
  71. "relation_df": rela_df.to_csv(index_label="id")
  72. }
  73. text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
  74. gen_conf = {"temperature": 0.5}
  75. try:
  76. response = self._llm.chat(text, [], gen_conf)
  77. response = re.sub(r"^[^\{]*", "", response)
  78. response = re.sub(r"[^\}]*$", "", response)
  79. print(response)
  80. response = json.loads(response)
  81. if not dict_has_keys_with_types(response, [
  82. ("title", str),
  83. ("summary", str),
  84. ("findings", list),
  85. ("rating", float),
  86. ("rating_explanation", str),
  87. ]): continue
  88. response["weight"] = weight
  89. response["entities"] = ents
  90. except Exception as e:
  91. print("ERROR: ", traceback.format_exc())
  92. self._on_error(e, traceback.format_exc(), None)
  93. continue
  94. add_community_info2graph(graph, ents, response["title"])
  95. res_str.append(self._get_text_output(response))
  96. res_dict.append(response)
  97. return CommunityReportsResult(
  98. structured_output=res_dict,
  99. output=res_str,
  100. )
  101. def _get_text_output(self, parsed_output: dict) -> str:
  102. title = parsed_output.get("title", "Report")
  103. summary = parsed_output.get("summary", "")
  104. findings = parsed_output.get("findings", [])
  105. def finding_summary(finding: dict):
  106. if isinstance(finding, str):
  107. return finding
  108. return finding.get("summary")
  109. def finding_explanation(finding: dict):
  110. if isinstance(finding, str):
  111. return ""
  112. return finding.get("explanation")
  113. report_sections = "\n\n".join(
  114. f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
  115. )
  116. return f"# {title}\n\n{summary}\n\n{report_sections}"