
mind_map_extractor.py 6.3KB

fix(nursery): Fix Closure Trap Issues in Trio Concurrent Tasks (#7106)

## Problem Description

Multiple files in the RAGFlow project contain closure trap issues when using lambda functions with `trio.open_nursery()`. This problem causes concurrent tasks created in loops to reference the same variable, resulting in all tasks processing the same data (the data from the last iteration) rather than each task processing its corresponding data from the loop.

## Issue Details

When using a `lambda` to create a closure function and passing it to `nursery.start_soon()` within a loop, the lambda function captures a reference to the loop variable rather than its value. For example:

```python
# Problematic code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, topn))
```

In this pattern, when the concurrent tasks begin execution, `d` has already become the value after the loop ends (typically the last element), causing all tasks to use the same data.

## Fix Solution

Changed the way concurrent tasks are created with `nursery.start_soon()` by leveraging Trio's API design to pass the function and its arguments separately:

```python
# Fixed code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(doc_keyword_extraction, chat_mdl, d, topn)
```

This way, each task uses the parameter values at the time of the call, rather than references captured through closures.

## Fixed Files

Fixed closure traps in the following files:

1. `rag/svr/task_executor.py`: 3 fixes, involving document keyword extraction, question generation, and tag processing
2. `rag/raptor.py`: 1 fix, involving document summarization
3. `graphrag/utils.py`: 2 fixes, involving graph node and edge processing
4. `graphrag/entity_resolution.py`: 2 fixes, involving entity resolution and graph node merging
5. `graphrag/general/mind_map_extractor.py`: 2 fixes, involving document processing
6. `graphrag/general/extractor.py`: 3 fixes, involving content processing and graph node/edge merging
7. `graphrag/general/community_reports_extractor.py`: 1 fix, involving community report extraction

## Potential Impact

This fix resolves a serious concurrency issue that could have caused:

- Data processing errors (processing duplicate data)
- Performance degradation (all tasks working on the same data)
- Inconsistent results (some data not being processed)

After the fix, all concurrent tasks should correctly process their respective data, improving system correctness and reliability.
6 months ago
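For reference, a minimal self-contained sketch of the trap described in the commit message above; the coroutine `process` and the `docs` list are illustrative and not taken from the RAGFlow code:

```python
import trio


async def process(item: str) -> None:
    # Stand-in for real per-item work (e.g. one LLM call).
    await trio.sleep(0)
    print("processing", item)


async def main() -> None:
    docs = ["a", "b", "c"]

    # Closure trap: every lambda captures the variable `d`, not its value,
    # so by the time the tasks run they typically all see the last element ("c").
    async with trio.open_nursery() as nursery:
        for d in docs:
            nursery.start_soon(lambda: process(d))

    # Fix: let Trio pass the argument; each task receives the value `d` held
    # at the moment start_soon was called.
    async with trio.open_nursery() as nursery:
        for d in docs:
            nursery.start_soon(process, d)


trio.run(main)
```

Passing the function and its arguments to `start_soon` binds the current value of `d` when the task is scheduled, which is the pattern applied throughout the files listed above, including `mind_map_extractor.py` shown below.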
```python
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import collections
import re
from typing import Any
from dataclasses import dataclass
import trio

from graphrag.general.extractor import Extractor
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
from rag.llm.chat_model import Base as CompletionLLM
import markdown_to_json
from functools import reduce
from rag.utils import num_tokens_from_string


@dataclass
class MindMapResult:
    """Unipartite Mind Graph result class definition."""

    output: dict


class MindMapExtractor(Extractor):
    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k):
        # Strip markdown emphasis markers ("*") from a heading key.
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict, keyset: set):
        # Normalize a parsed markdown value (str, list or dict) into a list of
        # {"id": ..., "children": [...]} nodes, deduplicating keys via keyset.
        if isinstance(obj, str):
            obj = [obj]
        if isinstance(obj, list):
            keyset.update(obj)
            obj = [re.sub(r"\*+", "", i) for i in obj]
            return [{"id": i, "children": []} for i in obj if i]
        arr = []
        for k, v in obj.items():
            k = self._key(k)
            if k and k not in keyset:
                keyset.add(k)
                arr.append(
                    {
                        "id": k,
                        "children": self._be_children(v, keyset)
                    }
                )
        return arr

    async def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}

        res = []
        token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
        texts = []
        cnt = 0
        # Batch the sections so each LLM request stays within the model's
        # context window, and process the batches concurrently.
        async with trio.open_nursery() as nursery:
            for i in range(len(sections)):
                section_cnt = num_tokens_from_string(sections[i])
                if cnt + section_cnt >= token_count and texts:
                    nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
                    texts = []
                    cnt = 0
                texts.append(sections[i])
                cnt += section_cnt
            if texts:
                nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)

        if not res:
            return MindMapResult(output={"id": "root", "children": []})

        # Merge the per-batch mind maps into a single tree rooted at "root".
        merge_json = reduce(self._merge, res)
        if len(merge_json) > 1:
            keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
            keyset = set(i for i in keys if i)
            merge_json = {
                "id": "root",
                "children": [
                    {
                        "id": self._key(k),
                        "children": self._be_children(v, keyset)
                    }
                    for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
                ]
            }
        else:
            k = self._key(list(merge_json.keys())[0])
            merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}

        return MindMapResult(output=merge_json)

    def _merge(self, d1, d2):
        # Recursively merge mind-map dict d1 into d2 and return d2.
        for k in d1:
            if k in d2:
                if isinstance(d1[k], dict) and isinstance(d2[k], dict):
                    self._merge(d1[k], d2[k])
                elif isinstance(d1[k], list) and isinstance(d2[k], list):
                    d2[k].extend(d1[k])
                else:
                    d2[k] = d1[k]
            else:
                d2[k] = d1[k]
        return d2

    def _list_to_kv(self, data):
        # Turn ["heading", ["value"], ...] pairs produced by markdown_to_json
        # into {"heading": "value"} mappings, recursing into nested dicts.
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for i in range(len(value)):
                    if isinstance(value[i], list) and i > 0:
                        new_value[value[i - 1]] = value[i][0]
                data[key] = new_value
            else:
                continue
        return data

    def _todict(self, layer: collections.OrderedDict):
        # Convert the OrderedDict tree returned by markdown_to_json into plain dicts.
        to_ret = layer
        if isinstance(layer, collections.OrderedDict):
            to_ret = dict(layer)

        try:
            for key, value in to_ret.items():
                to_ret[key] = self._todict(value)
        except AttributeError:
            pass

        return self._list_to_kv(to_ret)

    async def _process_document(
        self, text: str, prompt_variables: dict[str, str], out_res
    ) -> str:
        # Render the mind-map prompt for this batch, call the LLM, and append
        # the parsed markdown outline to the shared result list.
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        async with chat_limiter:
            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], {}))
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        logging.debug(self._todict(markdown_to_json.dictify(response)))
        out_res.append(self._todict(markdown_to_json.dictify(response)))
```
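A hedged sketch of how the extractor above might be driven; `chat_mdl` is assumed to be an already-configured chat model (a `rag.llm.chat_model.Base` subclass) supplied by the caller and is not constructed here:

```python
import trio

from graphrag.general.mind_map_extractor import MindMapExtractor


async def build_mind_map(chat_mdl, sections: list[str]) -> dict:
    # The extractor batches the sections by token count, asks the LLM for a
    # markdown outline per batch, and merges the outlines into one tree.
    extractor = MindMapExtractor(llm_invoker=chat_mdl)
    result = await extractor(sections)
    return result.output  # {"id": ..., "children": [...]}


# Hypothetical invocation; `chat_mdl` must be provided by the surrounding code.
# tree = trio.run(build_mind_map, chat_mdl, ["Section one ...", "Section two ..."])
```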