You can't select more than 25 topics. Topics must start with a letter or number, can include hyphens ('-'), and can be up to 35 characters long.

community_reports_extractor.py 4.9KB

(line numbers 1–128 — page-extraction artifact of the code listing's gutter)
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import json
  8. import logging
  9. import re
  10. import traceback
  11. from dataclasses import dataclass
  12. from typing import Any, List, Callable
  13. import networkx as nx
  14. import pandas as pd
  15. from graphrag import leiden
  16. from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
  17. from graphrag.leiden import add_community_info2graph
  18. from rag.llm.chat_model import Base as CompletionLLM
  19. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
  20. from rag.utils import num_tokens_from_string
  21. from timeit import default_timer as timer
  22. log = logging.getLogger(__name__)
@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    # Rendered text report per community (see _get_text_output), parallel
    # to `structured_output`.
    output: List[str]
    # Parsed JSON report dicts; each carries at least title, summary,
    # findings, rating and rating_explanation (validated before inclusion).
    structured_output: List[dict]
  28. class CommunityReportsExtractor:
  29. """Community reports extractor class definition."""
  30. _llm: CompletionLLM
  31. _extraction_prompt: str
  32. _output_formatter_prompt: str
  33. _on_error: ErrorHandlerFn
  34. _max_report_length: int
  35. def __init__(
  36. self,
  37. llm_invoker: CompletionLLM,
  38. extraction_prompt: str | None = None,
  39. on_error: ErrorHandlerFn | None = None,
  40. max_report_length: int | None = None,
  41. ):
  42. """Init method definition."""
  43. self._llm = llm_invoker
  44. self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
  45. self._on_error = on_error or (lambda _e, _s, _d: None)
  46. self._max_report_length = max_report_length or 1500
  47. def __call__(self, graph: nx.Graph, callback: Callable | None = None):
  48. communities: dict[str, dict[str, List]] = leiden.run(graph, {})
  49. total = sum([len(comm.items()) for _, comm in communities.items()])
  50. relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
  51. res_str = []
  52. res_dict = []
  53. over, token_count = 0, 0
  54. st = timer()
  55. for level, comm in communities.items():
  56. for cm_id, ents in comm.items():
  57. weight = ents["weight"]
  58. ents = ents["nodes"]
  59. ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
  60. rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
  61. prompt_variables = {
  62. "entity_df": ent_df.to_csv(index_label="id"),
  63. "relation_df": rela_df.to_csv(index_label="id")
  64. }
  65. text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
  66. gen_conf = {"temperature": 0.3}
  67. try:
  68. response = self._llm.chat(text, [], gen_conf)
  69. token_count += num_tokens_from_string(text + response)
  70. response = re.sub(r"^[^\{]*", "", response)
  71. response = re.sub(r"[^\}]*$", "", response)
  72. print(response)
  73. response = json.loads(response)
  74. if not dict_has_keys_with_types(response, [
  75. ("title", str),
  76. ("summary", str),
  77. ("findings", list),
  78. ("rating", float),
  79. ("rating_explanation", str),
  80. ]): continue
  81. response["weight"] = weight
  82. response["entities"] = ents
  83. except Exception as e:
  84. print("ERROR: ", traceback.format_exc())
  85. self._on_error(e, traceback.format_exc(), None)
  86. continue
  87. add_community_info2graph(graph, ents, response["title"])
  88. res_str.append(self._get_text_output(response))
  89. res_dict.append(response)
  90. over += 1
  91. if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
  92. return CommunityReportsResult(
  93. structured_output=res_dict,
  94. output=res_str,
  95. )
  96. def _get_text_output(self, parsed_output: dict) -> str:
  97. title = parsed_output.get("title", "Report")
  98. summary = parsed_output.get("summary", "")
  99. findings = parsed_output.get("findings", [])
  100. def finding_summary(finding: dict):
  101. if isinstance(finding, str):
  102. return finding
  103. return finding.get("summary")
  104. def finding_explanation(finding: dict):
  105. if isinstance(finding, str):
  106. return ""
  107. return finding.get("explanation")
  108. report_sections = "\n\n".join(
  109. f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
  110. )
  111. return f"# {title}\n\n{summary}\n\n{report_sections}"