您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

community_reports_extractor.py 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. """
  17. Reference:
  18. - [graphrag](https://github.com/microsoft/graphrag)
  19. """
  20. import json
  21. import logging
  22. import re
  23. import traceback
  24. from dataclasses import dataclass
  25. from typing import Any, List, Callable
  26. import networkx as nx
  27. import pandas as pd
  28. from graphrag import leiden
  29. from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
  30. from graphrag.leiden import add_community_info2graph
  31. from rag.llm.chat_model import Base as CompletionLLM
  32. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
  33. from rag.utils import num_tokens_from_string
  34. from timeit import default_timer as timer
  35. log = logging.getLogger(__name__)
@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    # Rendered markdown text of each community report, one entry per community.
    output: List[str]
    # Parsed LLM JSON for each report (title/summary/findings/rating/...),
    # parallel to `output` — built alongside it by CommunityReportsExtractor.__call__.
    structured_output: List[dict]
  41. class CommunityReportsExtractor:
  42. """Community reports extractor class definition."""
  43. _llm: CompletionLLM
  44. _extraction_prompt: str
  45. _output_formatter_prompt: str
  46. _on_error: ErrorHandlerFn
  47. _max_report_length: int
  48. def __init__(
  49. self,
  50. llm_invoker: CompletionLLM,
  51. extraction_prompt: str | None = None,
  52. on_error: ErrorHandlerFn | None = None,
  53. max_report_length: int | None = None,
  54. ):
  55. """Init method definition."""
  56. self._llm = llm_invoker
  57. self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
  58. self._on_error = on_error or (lambda _e, _s, _d: None)
  59. self._max_report_length = max_report_length or 1500
  60. def __call__(self, graph: nx.Graph, callback: Callable | None = None):
  61. communities: dict[str, dict[str, List]] = leiden.run(graph, {})
  62. total = sum([len(comm.items()) for _, comm in communities.items()])
  63. relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
  64. res_str = []
  65. res_dict = []
  66. over, token_count = 0, 0
  67. st = timer()
  68. for level, comm in communities.items():
  69. for cm_id, ents in comm.items():
  70. weight = ents["weight"]
  71. ents = ents["nodes"]
  72. ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
  73. rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
  74. prompt_variables = {
  75. "entity_df": ent_df.to_csv(index_label="id"),
  76. "relation_df": rela_df.to_csv(index_label="id")
  77. }
  78. text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
  79. gen_conf = {"temperature": 0.3}
  80. try:
  81. response = self._llm.chat(text, [], gen_conf)
  82. token_count += num_tokens_from_string(text + response)
  83. response = re.sub(r"^[^\{]*", "", response)
  84. response = re.sub(r"[^\}]*$", "", response)
  85. print(response)
  86. response = json.loads(response)
  87. if not dict_has_keys_with_types(response, [
  88. ("title", str),
  89. ("summary", str),
  90. ("findings", list),
  91. ("rating", float),
  92. ("rating_explanation", str),
  93. ]): continue
  94. response["weight"] = weight
  95. response["entities"] = ents
  96. except Exception as e:
  97. print("ERROR: ", traceback.format_exc())
  98. self._on_error(e, traceback.format_exc(), None)
  99. continue
  100. add_community_info2graph(graph, ents, response["title"])
  101. res_str.append(self._get_text_output(response))
  102. res_dict.append(response)
  103. over += 1
  104. if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
  105. return CommunityReportsResult(
  106. structured_output=res_dict,
  107. output=res_str,
  108. )
  109. def _get_text_output(self, parsed_output: dict) -> str:
  110. title = parsed_output.get("title", "Report")
  111. summary = parsed_output.get("summary", "")
  112. findings = parsed_output.get("findings", [])
  113. def finding_summary(finding: dict):
  114. if isinstance(finding, str):
  115. return finding
  116. return finding.get("summary")
  117. def finding_explanation(finding: dict):
  118. if isinstance(finding, str):
  119. return ""
  120. return finding.get("explanation")
  121. report_sections = "\n\n".join(
  122. f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
  123. )
  124. return f"# {title}\n\n{summary}\n\n{report_sections}"