
fix(nursery): Fix Closure Trap Issues in Trio Concurrent Tasks (#7106)

## Problem Description

Multiple files in the RAGFlow project contain closure-trap issues where lambda functions are used with `trio.open_nursery()`. Concurrent tasks created in a loop all reference the same variable, so every task ends up processing the same data (the data from the last iteration) instead of the data from its own iteration.

## Issue Details

When a `lambda` closure is created inside a loop and passed to `nursery.start_soon()`, the lambda captures a reference to the loop variable rather than its value. For example:

```python
# Problematic code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, topn))
```

With this pattern, by the time the concurrent tasks run, `d` has already advanced to its post-loop value (typically the last element), so all tasks operate on the same data.

## Fix Solution

Concurrent tasks are now created by leveraging Trio's API design and passing the function and its arguments to `nursery.start_soon()` separately:

```python
# Fixed code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(doc_keyword_extraction, chat_mdl, d, topn)
```

This way each task is bound to the parameter values at the time of the call, rather than to references captured through a closure. A standalone sketch of the trap and the fixed style follows this description.

## Fixed Files

Closure traps were fixed in the following files:

1. `rag/svr/task_executor.py`: 3 fixes, involving document keyword extraction, question generation, and tag processing
2. `rag/raptor.py`: 1 fix, involving document summarization
3. `graphrag/utils.py`: 2 fixes, involving graph node and edge processing
4. `graphrag/entity_resolution.py`: 2 fixes, involving entity resolution and graph node merging
5. `graphrag/general/mind_map_extractor.py`: 2 fixes, involving document processing
6. `graphrag/general/extractor.py`: 3 fixes, involving content processing and graph node/edge merging
7. `graphrag/general/community_reports_extractor.py`: 1 fix, involving community report extraction

## Potential Impact

This fix resolves a serious concurrency issue that could have caused:

- Data processing errors (the same data processed repeatedly)
- Performance degradation (all tasks working on the same data)
- Inconsistent results (some data never processed)

After the fix, all concurrent tasks process their respective data correctly, improving system correctness and reliability.
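For readers unfamiliar with the trap, here is a minimal, self-contained sketch. The function `process` and the sample items are illustrative stand-ins, not code from the repository. The first snippet shows the classic late-binding behavior of closures; the second shows the argument-passing style the fix adopts:

```python
import trio

# Classic late binding: each lambda captures the variable `i`, not its
# value, so calls made after the loop all see the final value.
fns = [lambda: i for i in range(3)]
print([fn() for fn in fns])  # [2, 2, 2], not [0, 1, 2]

async def process(item):
    # Illustrative stand-in for doc_keyword_extraction and friends.
    print("processing", item)

async def main():
    async with trio.open_nursery() as nursery:
        for item in ("a", "b", "c"):
            # The fixed style: start_soon receives the function and its
            # argument separately, so the current value of `item` is
            # bound when the task is created.
            nursery.start_soon(process, item)

trio.run(main)
```

Had lambdas been kept, an alternative would be binding the value eagerly via a default argument (`lambda d=d: ...`) or `functools.partial`; passing arguments to `start_soon` is the idiomatic Trio form and also keeps task names meaningful instead of showing `<lambda>`.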
The full `rag/raptor.py` as of this commit:

```python
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import re
import umap
import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from graphrag.utils import (
    get_llm_cache,
    get_embed_cache,
    set_embed_cache,
    set_llm_cache,
    chat_limiter,
)
from rag.utils import truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    def __init__(
        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
    ):
        self._max_cluster = max_cluster
        self._llm_model = llm_model
        self._embd_model = embd_model
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token

    async def _chat(self, system, history, gen_conf):
        response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        if response:
            return response
        response = await trio.to_thread.run_sync(
            lambda: self._llm_model.chat(system, history, gen_conf)
        )
        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    async def _embedding_encode(self, txt):
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response is not None:
            return response
        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
        if len(embds) < 1 or len(embds[0]) < 1:
            raise Exception("Embedding error: ")
        embds = embds[0]
        set_embed_cache(self._embd_model.llm_name, txt, embds)
        return embds

    # Choose the GMM component count that minimizes BIC.
    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
        max_clusters = min(self._max_cluster, len(embeddings))
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:
            gm = GaussianMixture(n_components=n, random_state=random_state)
            gm.fit(embeddings)
            bics.append(gm.bic(embeddings))
        optimal_clusters = n_clusters[np.argmin(bics)]
        return optimal_clusters

    async def __call__(self, chunks, random_state, callback=None):
        if len(chunks) <= 1:
            return []
        chunks = [(s, a) for s, a in chunks if s and len(a) > 0]
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        async def summarize(ck_idx: list[int]):
            nonlocal chunks
            texts = [chunks[i][0] for i in ck_idx]
            len_per_chunk = int(
                (self._llm_model.max_length - self._max_token) / len(texts)
            )
            cluster_content = "\n".join(
                [truncate(t, max(1, len_per_chunk)) for t in texts]
            )
            async with chat_limiter:
                cnt = await self._chat(
                    "You're a helpful assistant.",
                    [
                        {
                            "role": "user",
                            "content": self._prompt.format(
                                cluster_content=cluster_content
                            ),
                        }
                    ],
                    {"temperature": 0.3, "max_tokens": self._max_token},
                )
            cnt = re.sub(
                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                "",
                cnt,
            )
            logging.debug(f"SUM: {cnt}")
            embds = await self._embedding_encode(cnt)
            chunks.append((cnt, embds))

        labels = []
        while end - start > 1:
            embeddings = [embd for _, embd in chunks[start:end]]
            if len(embeddings) == 2:
                await summarize([start, start + 1])
                if callback:
                    callback(
                        msg="Cluster one layer: {} -> {}".format(
                            end - start, len(chunks) - end
                        )
                    )
                labels.extend([0, 0])
                layers.append((end, len(chunks)))
                start = end
                end = len(chunks)
                continue
            n_neighbors = int((len(embeddings) - 1) ** 0.8)
            reduced_embeddings = umap.UMAP(
                n_neighbors=max(2, n_neighbors),
                n_components=min(12, len(embeddings) - 2),
                metric="cosine",
            ).fit_transform(embeddings)
            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
            if n_clusters == 1:
                lbls = [0 for _ in range(len(reduced_embeddings))]
            else:
                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
                gm.fit(reduced_embeddings)
                probs = gm.predict_proba(reduced_embeddings)
                lbls = [np.where(prob > self._threshold)[0] for prob in probs]
                lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
            async with trio.open_nursery() as nursery:
                for c in range(n_clusters):
                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                    assert len(ck_idx) > 0
                    async with chat_limiter:
                        # Fixed in this commit: pass summarize and its argument
                        # separately instead of wrapping them in a lambda.
                        nursery.start_soon(summarize, ck_idx)
            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
                len(chunks) - end, n_clusters
            )
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(
                    msg="Cluster one layer: {} -> {}".format(
                        end - start, len(chunks) - end
                    )
                )
            start = end
            end = len(chunks)
        return chunks
```
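As a side note on the clustering step above, `_get_optimal_clusters` picks the Gaussian-mixture component count by minimizing BIC over the candidate range. Here is a standalone sketch of that selection on synthetic data (all values below are invented for illustration):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Two well-separated synthetic blobs, so the expected optimum is 2.
X = np.vstack([rng.normal(0.0, 1.0, (50, 2)),
               rng.normal(6.0, 1.0, (50, 2))])

candidates = np.arange(1, 8)
bics = []
for k in candidates:
    gm = GaussianMixture(n_components=k, random_state=0).fit(X)
    bics.append(gm.bic(X))  # lower BIC = better fit/complexity trade-off
print("optimal n_components:", candidates[int(np.argmin(bics))])
```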