
**community_reports_extractor.py** (6.2 KB)

# fix(nursery): Fix Closure Trap Issues in Trio Concurrent Tasks (#7106)

## Problem Description

Multiple files in the RAGFlow project contained closure-trap bugs when lambda functions were used with `trio.open_nursery()`. Concurrent tasks created in a loop all referenced the same loop variable, so every task processed the same data (the data from the last iteration) instead of the data it was created for.

## Issue Details

When a `lambda` is used to build the closure passed to `nursery.start_soon()` inside a loop, the lambda captures a reference to the loop variable rather than its value. For example:

```python
# Problematic code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, topn))
```

With this pattern, the loop has already finished by the time the concurrent tasks begin execution, so `d` holds its final value (typically the last element) and all tasks use the same data.

## Fix Solution

Task creation now uses Trio's intended API: the function and its arguments are passed to `nursery.start_soon()` separately, instead of being wrapped in a lambda:

```python
# Fixed code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(doc_keyword_extraction, chat_mdl, d, topn)
```

This way each task binds the argument values at the moment `start_soon()` is called, rather than referencing them through a closure.

## Fixed Files

Closure traps were fixed in the following files:

1. `rag/svr/task_executor.py`: 3 fixes, involving document keyword extraction, question generation, and tag processing
2. `rag/raptor.py`: 1 fix, involving document summarization
3. `graphrag/utils.py`: 2 fixes, involving graph node and edge processing
4. `graphrag/entity_resolution.py`: 2 fixes, involving entity resolution and graph node merging
5. `graphrag/general/mind_map_extractor.py`: 2 fixes, involving document processing
6. `graphrag/general/extractor.py`: 3 fixes, involving content processing and graph node/edge merging
7. `graphrag/general/community_reports_extractor.py`: 1 fix, involving community report extraction

## Potential Impact

This fix resolves a serious concurrency issue that could have caused:

- Data processing errors (the same item processed repeatedly)
- Performance degradation (all tasks working on the same data)
- Inconsistent results (some data never processed)

After the fix, every concurrent task processes its own data, improving the system's correctness and reliability.
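To make the failure mode concrete, here is a minimal, self-contained sketch of the trap and the fix; `process` and the item list are illustrative placeholders, not RAGFlow code:

```python
import trio

async def process(item):
    await trio.sleep(0)  # checkpoint so tasks interleave
    print("processing", item)

async def buggy():
    async with trio.open_nursery() as nursery:
        for d in [1, 2, 3]:
            # Trap: the lambda closes over the variable d, not its value.
            # No task runs until the loop has finished, so every task
            # sees d == 3 and "processing 3" is printed three times.
            nursery.start_soon(lambda: process(d))

async def fixed():
    async with trio.open_nursery() as nursery:
        for d in [1, 2, 3]:
            # start_soon(fn, *args) binds the argument values immediately,
            # so the tasks print 1, 2 and 3 (in some order).
            nursery.start_soon(process, d)

trio.run(buggy)
trio.run(fixed)
```

When a lambda is unavoidable, binding the value as a default argument (`lambda d=d: ...`) also avoids the trap, but passing the function and its arguments to `start_soon` separately is the simpler, more idiomatic Trio pattern, and it is what this PR standardizes on.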
```python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import json
import re
from typing import Callable
from dataclasses import dataclass

import networkx as nx
import pandas as pd
import trio

from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
from graphrag.general.leiden import add_community_info2graph
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from rag.utils import num_tokens_from_string

@dataclass
class CommunityReportsResult:
    """Community reports result class definition."""

    output: list[str]
    structured_output: list[dict]


class CommunityReportsExtractor(Extractor):
    """Community reports extractor class definition."""

    _extraction_prompt: str
    _output_formatter_prompt: str
    _max_report_length: int

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        max_report_length: int | None = None,
    ):
        """Init method definition."""
        super().__init__(llm_invoker)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        self._max_report_length = max_report_length or 1500
    async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
        # Rank every node by its degree, then detect communities.
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        total = sum([len(comm.items()) for _, comm in communities.items()])
        res_str = []
        res_dict = []
        over, token_count = 0, 0

        @timeout(120)
        async def extract_community_report(community):
            nonlocal res_str, res_dict, over, token_count
            cm_id, cm = community
            weight = cm["weight"]
            ents = cm["nodes"]
            if len(ents) < 2:
                return
            ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
            ent_df = pd.DataFrame(ent_list)
            # Collect intra-community edges, capped at 10,000 relations.
            rela_list = []
            k = 0
            for i in range(0, len(ents)):
                if k >= 10000:
                    break
                for j in range(i + 1, len(ents)):
                    if k >= 10000:
                        break
                    edge = graph.get_edge_data(ents[i], ents[j])
                    if edge is None:
                        continue
                    rela_list.append({"source": ents[i], "target": ents[j], "description": edge["description"]})
                    k += 1
            rela_df = pd.DataFrame(rela_list)
            prompt_variables = {
                "entity_df": ent_df.to_csv(index_label="id"),
                "relation_df": rela_df.to_csv(index_label="id"),
            }
            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
            # Run the blocking chat call in a worker thread, under the rate
            # limiter and with a 180-second timeout.
            async with chat_limiter:
                try:
                    with trio.move_on_after(180) as cancel_scope:
                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
                    if cancel_scope.cancelled_caught:
                        logging.warning("extract_community_report._chat timeout, skipping...")
                        return
                except Exception as e:
                    logging.error(f"extract_community_report._chat failed: {e}")
                    return
            token_count += num_tokens_from_string(text + response)
            # Trim everything outside the outermost braces and collapse
            # doubled braces before parsing the JSON report.
            response = re.sub(r"^[^\{]*", "", response)
            response = re.sub(r"[^\}]*$", "", response)
            response = re.sub(r"\{\{", "{", response)
            response = re.sub(r"\}\}", "}", response)
            logging.debug(response)
            try:
                response = json.loads(response)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse JSON response: {e}")
                logging.error(f"Response content: {response}")
                return
            if not dict_has_keys_with_types(response, [
                ("title", str),
                ("summary", str),
                ("findings", list),
                ("rating", float),
                ("rating_explanation", str),
            ]):
                return
            response["weight"] = weight
            response["entities"] = ents
            add_community_info2graph(graph, ents, response["title"])
            res_str.append(self._get_text_output(response))
            res_dict.append(response)
            over += 1
            if callback:
                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
        st = trio.current_time()
        async with trio.open_nursery() as nursery:
            for level, comm in communities.items():
                logging.info(f"Level {level}: Community: {len(comm.keys())}")
                for community in comm.items():
                    # Pass `community` as a start_soon argument (not via a
                    # lambda) so each task binds its own value; this is the
                    # closure-trap fix described above.
                    nursery.start_soon(extract_community_report, community)
        if callback:
            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )
    def _get_text_output(self, parsed_output: dict) -> str:
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding: dict):
            if isinstance(finding, str):
                return finding
            return finding.get("summary")

        def finding_explanation(finding: dict):
            if isinstance(finding, str):
                return ""
            return finding.get("explanation")

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"
```
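For orientation, a hypothetical driver for this extractor might look like the sketch below; `my_chat_model` is a placeholder for however the caller obtains a `CompletionLLM` instance, and the toy graph stands in for the entity graph RAGFlow builds upstream:

```python
import networkx as nx
import trio

from graphrag.general.community_reports_extractor import CommunityReportsExtractor

async def main():
    # Toy graph; in RAGFlow the graph comes from upstream entity extraction.
    graph = nx.Graph()
    graph.add_node("Entity A", description="First example entity")
    graph.add_node("Entity B", description="Second example entity")
    graph.add_edge("Entity A", "Entity B", description="A is related to B")

    extractor = CommunityReportsExtractor(my_chat_model)  # placeholder CompletionLLM
    result = await extractor(graph, callback=lambda msg: print(msg))
    for report in result.output:
        print(report)

trio.run(main)
```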