Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

community_reports_extractor.py 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import logging
  8. import json
  9. import re
  10. from typing import Callable
  11. from dataclasses import dataclass
  12. import networkx as nx
  13. import pandas as pd
  14. from graphrag.general import leiden
  15. from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
  16. from graphrag.general.extractor import Extractor
  17. from graphrag.general.leiden import add_community_info2graph
  18. from rag.llm.chat_model import Base as CompletionLLM
  19. from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
  20. from rag.utils import num_tokens_from_string
  21. import trio
@dataclass
class CommunityReportsResult:
    """Community reports result class definition.

    Attributes:
        output: rendered markdown report texts, one per community.
        structured_output: the parsed JSON report dicts, one per community.
    """
    output: list[str]
    structured_output: list[dict]
  27. class CommunityReportsExtractor(Extractor):
  28. """Community reports extractor class definition."""
  29. _extraction_prompt: str
  30. _output_formatter_prompt: str
  31. _max_report_length: int
  32. def __init__(
  33. self,
  34. llm_invoker: CompletionLLM,
  35. get_entity: Callable | None = None,
  36. set_entity: Callable | None = None,
  37. get_relation: Callable | None = None,
  38. set_relation: Callable | None = None,
  39. max_report_length: int | None = None,
  40. ):
  41. super().__init__(llm_invoker, get_entity=get_entity, set_entity=set_entity, get_relation=get_relation, set_relation=set_relation)
  42. """Init method definition."""
  43. self._llm = llm_invoker
  44. self._extraction_prompt = COMMUNITY_REPORT_PROMPT
  45. self._max_report_length = max_report_length or 1500
  46. async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
  47. for node_degree in graph.degree:
  48. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  49. communities: dict[str, dict[str, list]] = leiden.run(graph, {})
  50. total = sum([len(comm.items()) for _, comm in communities.items()])
  51. res_str = []
  52. res_dict = []
  53. over, token_count = 0, 0
  54. async def extract_community_report(community):
  55. nonlocal res_str, res_dict, over, token_count
  56. cm_id, ents = community
  57. weight = ents["weight"]
  58. ents = ents["nodes"]
  59. ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
  60. if ent_df.empty or "entity_name" not in ent_df.columns:
  61. return
  62. ent_df["entity"] = ent_df["entity_name"]
  63. del ent_df["entity_name"]
  64. rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
  65. if rela_df.empty:
  66. return
  67. rela_df["source"] = rela_df["src_id"]
  68. rela_df["target"] = rela_df["tgt_id"]
  69. del rela_df["src_id"]
  70. del rela_df["tgt_id"]
  71. prompt_variables = {
  72. "entity_df": ent_df.to_csv(index_label="id"),
  73. "relation_df": rela_df.to_csv(index_label="id")
  74. }
  75. text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
  76. gen_conf = {"temperature": 0.3}
  77. async with chat_limiter:
  78. response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
  79. token_count += num_tokens_from_string(text + response)
  80. response = re.sub(r"^[^\{]*", "", response)
  81. response = re.sub(r"[^\}]*$", "", response)
  82. response = re.sub(r"\{\{", "{", response)
  83. response = re.sub(r"\}\}", "}", response)
  84. logging.debug(response)
  85. try:
  86. response = json.loads(response)
  87. except json.JSONDecodeError as e:
  88. logging.error(f"Failed to parse JSON response: {e}")
  89. logging.error(f"Response content: {response}")
  90. return
  91. if not dict_has_keys_with_types(response, [
  92. ("title", str),
  93. ("summary", str),
  94. ("findings", list),
  95. ("rating", float),
  96. ("rating_explanation", str),
  97. ]):
  98. return
  99. response["weight"] = weight
  100. response["entities"] = ents
  101. add_community_info2graph(graph, ents, response["title"])
  102. res_str.append(self._get_text_output(response))
  103. res_dict.append(response)
  104. over += 1
  105. if callback:
  106. callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
  107. st = trio.current_time()
  108. async with trio.open_nursery() as nursery:
  109. for level, comm in communities.items():
  110. logging.info(f"Level {level}: Community: {len(comm.keys())}")
  111. for community in comm.items():
  112. nursery.start_soon(lambda: extract_community_report(community))
  113. if callback:
  114. callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
  115. return CommunityReportsResult(
  116. structured_output=res_dict,
  117. output=res_str,
  118. )
  119. def _get_text_output(self, parsed_output: dict) -> str:
  120. title = parsed_output.get("title", "Report")
  121. summary = parsed_output.get("summary", "")
  122. findings = parsed_output.get("findings", [])
  123. def finding_summary(finding: dict):
  124. if isinstance(finding, str):
  125. return finding
  126. return finding.get("summary")
  127. def finding_explanation(finding: dict):
  128. if isinstance(finding, str):
  129. return ""
  130. return finding.get("explanation")
  131. report_sections = "\n\n".join(
  132. f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
  133. )
  134. return f"# {title}\n\n{summary}\n\n{report_sections}"