選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

community_reports_extractor.py 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import json
  8. import re
  9. import traceback
  10. from dataclasses import dataclass
  11. from typing import List, Callable
  12. import networkx as nx
  13. import pandas as pd
  14. from graphrag import leiden
  15. from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
  16. from graphrag.leiden import add_community_info2graph
  17. from rag.llm.chat_model import Base as CompletionLLM
  18. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
  19. from rag.utils import num_tokens_from_string
  20. from timeit import default_timer as timer
  21. from api.utils.log_utils import logger
  22. @dataclass
  23. class CommunityReportsResult:
  24. """Community reports result class definition."""
  25. output: List[str]
  26. structured_output: List[dict]
  27. class CommunityReportsExtractor:
  28. """Community reports extractor class definition."""
  29. _llm: CompletionLLM
  30. _extraction_prompt: str
  31. _output_formatter_prompt: str
  32. _on_error: ErrorHandlerFn
  33. _max_report_length: int
  34. def __init__(
  35. self,
  36. llm_invoker: CompletionLLM,
  37. extraction_prompt: str | None = None,
  38. on_error: ErrorHandlerFn | None = None,
  39. max_report_length: int | None = None,
  40. ):
  41. """Init method definition."""
  42. self._llm = llm_invoker
  43. self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT
  44. self._on_error = on_error or (lambda _e, _s, _d: None)
  45. self._max_report_length = max_report_length or 1500
  46. def __call__(self, graph: nx.Graph, callback: Callable | None = None):
  47. communities: dict[str, dict[str, List]] = leiden.run(graph, {})
  48. total = sum([len(comm.items()) for _, comm in communities.items()])
  49. relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
  50. res_str = []
  51. res_dict = []
  52. over, token_count = 0, 0
  53. st = timer()
  54. for level, comm in communities.items():
  55. for cm_id, ents in comm.items():
  56. weight = ents["weight"]
  57. ents = ents["nodes"]
  58. ent_df = pd.DataFrame([{"entity": n, **graph.nodes[n]} for n in ents])
  59. rela_df = relations_df[(relations_df["source"].isin(ents)) | (relations_df["target"].isin(ents))].reset_index(drop=True)
  60. prompt_variables = {
  61. "entity_df": ent_df.to_csv(index_label="id"),
  62. "relation_df": rela_df.to_csv(index_label="id")
  63. }
  64. text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
  65. gen_conf = {"temperature": 0.3}
  66. try:
  67. response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
  68. token_count += num_tokens_from_string(text + response)
  69. response = re.sub(r"^[^\{]*", "", response)
  70. response = re.sub(r"[^\}]*$", "", response)
  71. response = re.sub(r"\{\{", "{", response)
  72. response = re.sub(r"\}\}", "}", response)
  73. logger.info(response)
  74. response = json.loads(response)
  75. if not dict_has_keys_with_types(response, [
  76. ("title", str),
  77. ("summary", str),
  78. ("findings", list),
  79. ("rating", float),
  80. ("rating_explanation", str),
  81. ]): continue
  82. response["weight"] = weight
  83. response["entities"] = ents
  84. except Exception as e:
  85. logger.exception("CommunityReportsExtractor got exception")
  86. self._on_error(e, traceback.format_exc(), None)
  87. continue
  88. add_community_info2graph(graph, ents, response["title"])
  89. res_str.append(self._get_text_output(response))
  90. res_dict.append(response)
  91. over += 1
  92. if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
  93. return CommunityReportsResult(
  94. structured_output=res_dict,
  95. output=res_str,
  96. )
  97. def _get_text_output(self, parsed_output: dict) -> str:
  98. title = parsed_output.get("title", "Report")
  99. summary = parsed_output.get("summary", "")
  100. findings = parsed_output.get("findings", [])
  101. def finding_summary(finding: dict):
  102. if isinstance(finding, str):
  103. return finding
  104. return finding.get("summary")
  105. def finding_explanation(finding: dict):
  106. if isinstance(finding, str):
  107. return ""
  108. return finding.get("explanation")
  109. report_sections = "\n\n".join(
  110. f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
  111. )
  112. return f"# {title}\n\n{summary}\n\n{report_sections}"