- # Copyright (c) 2024 Microsoft Corporation.
- # Licensed under the MIT License
- """
- Reference:
- - [graphrag](https://github.com/microsoft/graphrag)
- """
-
- import logging
- import json
- import re
- from typing import Callable
- from dataclasses import dataclass
- import networkx as nx
- import pandas as pd
- from graphrag.general import leiden
- from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
- from graphrag.general.extractor import Extractor
- from graphrag.general.leiden import add_community_info2graph
- from rag.llm.chat_model import Base as CompletionLLM
- from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
- from rag.utils import num_tokens_from_string
- import trio
-
-
@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    # Rendered markdown report text, one entry per community
    # (built by CommunityReportsExtractor._get_text_output).
    output: list[str]
    # Parsed JSON report dicts from the LLM, augmented with
    # "weight" and "entities" keys; parallel to `output`.
    structured_output: list[dict]
-
-
class CommunityReportsExtractor(Extractor):
    """Community reports extractor class definition.

    Runs Leiden community detection over an entity graph, then prompts the
    LLM to produce a structured JSON report for every community that has at
    least two entities. Reports are collected both as rendered markdown text
    and as parsed dicts (see CommunityReportsResult).
    """

    _extraction_prompt: str
    _output_formatter_prompt: str
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        max_report_length: int | None = None,
    ):
        """Init method definition."""
        super().__init__(llm_invoker)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        # Default report-length cap when the caller does not provide one.
        self._max_report_length = max_report_length or 1500

    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        """Detect communities in *graph* and generate one LLM report each.

        Args:
            graph: entity graph; nodes are expected to carry a "description"
                attribute (missing descriptions are treated as empty).
            callback: optional progress hook, invoked with a ``msg=`` kwarg.

        Returns:
            CommunityReportsResult with rendered texts and parsed report dicts.
        """
        # Expose each node's degree as a "rank" attribute so the report
        # prompt can weight more-connected entities higher.
        for node, degree in graph.degree:
            graph.nodes[str(node)]["rank"] = int(degree)

        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        # Total community count across all hierarchy levels (progress display only).
        total = sum(len(comm) for comm in communities.values())
        res_str = []
        res_dict = []
        over, token_count = 0, 0

        async def extract_community_report(community):
            # res_str/res_dict are only mutated in place (.append), so only the
            # rebound counters need `nonlocal`. trio runs these tasks on a single
            # event loop, and the counters are only touched between awaits, so
            # the increments do not race.
            nonlocal over, token_count
            _, cm = community
            weight = cm["weight"]
            ents = cm["nodes"]
            # A single-entity community has no relations worth reporting.
            if len(ents) < 2:
                return
            # .get guards nodes without a "description"; a KeyError here would
            # cancel every sibling task in the nursery.
            ent_list = [{"entity": ent, "description": graph.nodes[ent].get("description", "")} for ent in ents]
            ent_df = pd.DataFrame(ent_list)

            # Collect intra-community edges, capped at 10000 to bound prompt size.
            rela_list = []
            k = 0
            for i in range(0, len(ents)):
                if k >= 10000:
                    break
                for j in range(i + 1, len(ents)):
                    if k >= 10000:
                        break
                    edge = graph.get_edge_data(ents[i], ents[j])
                    if edge is None:
                        continue
                    rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
                    k += 1
            rela_df = pd.DataFrame(rela_list)

            prompt_variables = {
                "entity_df": ent_df.to_csv(index_label="id"),
                "relation_df": rela_df.to_csv(index_label="id")
            }
            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
            gen_conf = {"temperature": 0.3}
            async with chat_limiter:
                # _chat blocks on the LLM call; run it in a worker thread so the
                # nursery can keep other community reports in flight.
                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
            token_count += num_tokens_from_string(text + response)
            # Strip any non-JSON prefix/suffix the model emitted, and collapse
            # doubled braces (template escaping) so the payload parses as JSON.
            response = re.sub(r"^[^\{]*", "", response)
            response = re.sub(r"[^\}]*$", "", response)
            response = re.sub(r"\{\{", "{", response)
            response = re.sub(r"\}\}", "}", response)
            logging.debug(response)
            try:
                response = json.loads(response)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse JSON response: {e}")
                logging.error(f"Response content: {response}")
                return
            # NOTE(review): ("rating", float) silently drops reports whose rating
            # comes back as an int, if dict_has_keys_with_types checks strictly —
            # confirm the helper's semantics upstream.
            if not dict_has_keys_with_types(response, [
                ("title", str),
                ("summary", str),
                ("findings", list),
                ("rating", float),
                ("rating_explanation", str),
            ]):
                return
            response["weight"] = weight
            response["entities"] = ents
            add_community_info2graph(graph, ents, response["title"])
            res_str.append(self._get_text_output(response))
            res_dict.append(response)
            over += 1
            if callback:
                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

        st = trio.current_time()
        async with trio.open_nursery() as nursery:
            for level, comm in communities.items():
                logging.info(f"Level {level}: Community: {len(comm.keys())}")
                for community in comm.items():
                    nursery.start_soon(extract_community_report, community)
        if callback:
            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report dict as markdown: # title, summary, ## findings."""
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: dict):
            # Findings may be plain strings or {"summary", "explanation"} dicts.
            if isinstance(finding, str):
                return finding
            return finding.get("summary")

        def finding_explanation(finding: dict):
            if isinstance(finding, str):
                return ""
            return finding.get("explanation")

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"