
fix(nursery): Fix Closure Trap Issues in Trio Concurrent Tasks (#7106)

## Problem Description
Multiple files in the RAGFlow project contain closure trap issues when
using lambda functions with `trio.open_nursery()`. Because every lambda
created in a loop captures the same loop variable, all of the spawned
concurrent tasks end up processing the same data (the data from the
last iteration) instead of the data from their own iteration.

## Issue Details
When a `lambda` is created inside a loop and handed to
`nursery.start_soon()`, it captures a reference to the loop variable
rather than its current value. For example:

```python
# Problematic code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, topn))
```

In this pattern, no task actually starts until the loop has finished
and the nursery begins waiting at the end of the `async with` block. By
then `d` holds the value it had when the loop ended (typically the last
element), so every task operates on the same data.
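
The failure is easy to reproduce standalone. In the sketch below, `process` and the `docs` list are hypothetical stand-ins; since the loop body contains no checkpoint, every task first runs after the loop ends:

```python
import trio

async def process(item):
    print(f"processing {item}")

async def main():
    docs = ["a", "b", "c"]
    async with trio.open_nursery() as nursery:
        for d in docs:
            # Each lambda closes over the variable d, not its current value.
            nursery.start_soon(lambda: process(d))
        # No checkpoint occurs inside the loop, so every task first runs
        # once the nursery starts waiting here, when d == "c".

trio.run(main)  # prints "processing c" three times
```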

## Fix Solution
The calls were changed to use the calling convention
`nursery.start_soon()` is designed for, passing the function and its
arguments separately:

```python
# Fixed code
async with trio.open_nursery() as nursery:
    for d in docs:
        nursery.start_soon(doc_keyword_extraction, chat_mdl, d, topn)
```

This way, each argument is evaluated and bound at spawn time, so every
task receives the values from its own loop iteration rather than a
reference captured through a closure.
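
Where a wrapper is still unavoidable, the commit uses the other standard escape hatch: binding the loop variable as a lambda default argument (see the `graphrag/utils.py` hunk below). A default is evaluated when the lambda is defined, so each task keeps its own copy. A sketch reusing the names from the example above:

```python
async with trio.open_nursery() as nursery:
    for d in docs:
        # d=d snapshots the current value of d into the lambda's default,
        # so each task sees the value from its own iteration.
        nursery.start_soon(lambda d=d: doc_keyword_extraction(chat_mdl, d, topn))
```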

## Fixed Files
Fixed closure traps in the following files:

1. `rag/svr/task_executor.py`: 3 fixes, involving document keyword
extraction, question generation, and tag processing
2. `rag/raptor.py`: 1 fix, involving document summarization
3. `graphrag/utils.py`: 2 fixes, involving graph node and edge
processing
4. `graphrag/entity_resolution.py`: 2 fixes, involving entity resolution
and graph node merging
5. `graphrag/general/mind_map_extractor.py`: 2 fixes, involving document
processing
6. `graphrag/general/extractor.py`: 3 fixes, involving content
processing and graph node/edge merging
7. `graphrag/general/community_reports_extractor.py`: 1 fix, involving
community report extraction

## Potential Impact
This fix resolves a serious concurrency issue that could have caused:
- Data processing errors (the same item processed over and over)
- Performance degradation (all tasks doing identical, redundant work)
- Inconsistent results (the remaining items never processed at all)

After the fix, all concurrent tasks should correctly process their
respective data, improving system correctness and reliability.
tags/v0.18.0
aniaan, 6 months ago
Commit: 8b8a2f2949

### graphrag/entity_resolution.py (+2 -2)

```diff
@@ -103,7 +103,7 @@ class EntityResolution(Extractor):
                     continue
                 for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
                     candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
-                    nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result))
+                    nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result)
         callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
 
         change = GraphChange()
@@ -112,7 +112,7 @@ class EntityResolution(Extractor):
         async with trio.open_nursery() as nursery:
             for sub_connect_graph in nx.connected_components(connect_graph):
                 merging_nodes = list(sub_connect_graph)
-                nursery.start_soon(lambda: self._merge_graph_nodes(graph, merging_nodes, change))
+                nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change)
 
         # Update pagerank
         pr = nx.pagerank(graph)
```

### graphrag/general/community_reports_extractor.py (+1 -1)

```diff
@@ -124,7 +124,7 @@ class CommunityReportsExtractor(Extractor):
             for level, comm in communities.items():
                 logging.info(f"Level {level}: Community: {len(comm.keys())}")
                 for community in comm.items():
-                    nursery.start_soon(lambda: extract_community_report(community))
+                    nursery.start_soon(extract_community_report, community)
         if callback:
             callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
```


### graphrag/general/extractor.py (+3 -3)

```diff
@@ -97,7 +97,7 @@ class Extractor:
         async with trio.open_nursery() as nursery:
             for i, ck in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
+                nursery.start_soon(self._process_single_content, (doc_id, ck), i, len(chunks), out_results)
 
         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
@@ -116,7 +116,7 @@ class Extractor:
         all_entities_data = []
         async with trio.open_nursery() as nursery:
             for en_nm, ents in maybe_nodes.items():
-                nursery.start_soon(lambda: self._merge_nodes(en_nm, ents, all_entities_data))
+                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
         now = trio.current_time()
         if callback:
             callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
@@ -126,7 +126,7 @@ class Extractor:
         all_relationships_data = []
         async with trio.open_nursery() as nursery:
             for (src, tgt), rels in maybe_edges.items():
-                nursery.start_soon(lambda: self._merge_edges(src, tgt, rels, all_relationships_data))
+                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
         now = trio.current_time()
         if callback:
             callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
```

### graphrag/general/mind_map_extractor.py (+2 -2)

```diff
@@ -93,13 +93,13 @@ class MindMapExtractor(Extractor):
                 for i in range(len(sections)):
                     section_cnt = num_tokens_from_string(sections[i])
                     if cnt + section_cnt >= token_count and texts:
-                        nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
+                        nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
                         texts = []
                         cnt = 0
                     texts.append(sections[i])
                     cnt += section_cnt
                 if texts:
-                    nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
+                    nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
             if not res:
                 return MindMapResult(output={"id": "root", "children": []})
             merge_json = reduce(self._merge, res)
```
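
A side effect of the positional form matters in this file: arguments are evaluated at spawn time, so `"".join(texts)` is computed before the `texts = []` rebinding on the next line, whereas the old lambda would have joined whatever `texts` referred to when the task finally ran. A contrast sketch (`consume` is a hypothetical coroutine):

```python
import trio

async def consume(joined):
    print(repr(joined))

async def main():
    texts = ["alpha", "beta"]
    async with trio.open_nursery() as nursery:
        # Eager: "".join(texts) runs here, so this task sees "alphabeta".
        nursery.start_soon(consume, "".join(texts))
        # Lazy: the lambda runs only after texts is rebound below,
        # so that task sees "" instead.
        nursery.start_soon(lambda: consume("".join(texts)))
        texts = []

trio.run(main)
```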

### graphrag/utils.py (+3 -3)

```diff
@@ -439,7 +439,7 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     if change.removed_edges:
         async with trio.open_nursery() as nursery:
             for from_node, to_node in change.removed_edges:
-                nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
+                nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
     now = trio.current_time()
     if callback:
         callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
@@ -457,13 +457,13 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     async with trio.open_nursery() as nursery:
         for node in change.added_updated_nodes:
             node_attrs = graph.nodes[node]
-            nursery.start_soon(lambda: graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks))
+            nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
         for from_node, to_node in change.added_updated_edges:
             edge_attrs = graph.get_edge_data(from_node, to_node)
             if not edge_attrs:
                 # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
                 continue
-            nursery.start_soon(lambda: graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks))
+            nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
     now = trio.current_time()
     if callback:
         callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
```

### rag/raptor.py (+1 -1)

```diff
@@ -152,7 +152,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                 ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
                 assert len(ck_idx) > 0
                 async with chat_limiter:
-                    nursery.start_soon(lambda: summarize(ck_idx))
+                    nursery.start_soon(summarize, ck_idx)
 
             assert len(chunks) - end == n_clusters, "{} vs. {}".format(
                 len(chunks) - end, n_clusters
```

### rag/svr/task_executor.py (+6 -6)

```diff
@@ -309,7 +309,7 @@ async def build_chunks(task, progress_callback):
             return
         async with trio.open_nursery() as nursery:
             for d in docs:
-                nursery.start_soon(lambda: doc_keyword_extraction(chat_mdl, d, task["parser_config"]["auto_keywords"]))
+                nursery.start_soon(doc_keyword_extraction, chat_mdl, d, task["parser_config"]["auto_keywords"])
         progress_callback(msg="Keywords generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     if task["parser_config"].get("auto_questions", 0):
@@ -328,7 +328,7 @@ async def build_chunks(task, progress_callback):
                 d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
         async with trio.open_nursery() as nursery:
             for d in docs:
-                nursery.start_soon(lambda: doc_question_proposal(chat_mdl, d, task["parser_config"]["auto_questions"]))
+                nursery.start_soon(doc_question_proposal, chat_mdl, d, task["parser_config"]["auto_questions"])
         progress_callback(msg="Question generation {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     if task["kb_parser_config"].get("tag_kb_ids", []):
@@ -370,7 +370,7 @@ async def build_chunks(task, progress_callback):
                     d[TAG_FLD] = json.loads(cached)
         async with trio.open_nursery() as nursery:
             for d in docs_to_tag:
-                nursery.start_soon(lambda: doc_content_tagging(chat_mdl, d, topn_tags))
+                nursery.start_soon(doc_content_tagging, chat_mdl, d, topn_tags)
         progress_callback(msg="Tagging {} chunks completed in {:.2f}s".format(len(docs), timer() - st))
 
     return docs
@@ -653,11 +653,11 @@ async def report_status():
 
 async def main():
     logging.info(r"""
-  ______           __      ______                     __
+  ______           __      ______                     __
  /_  __/___ ______/ /__   / ____/  _____  _______  __/ /_____  _____
   / / / __ `/ ___/ //_/  / __/ |  |/_/ _ \/ ___/ / / / __/ __ \/ ___/
- / / / /_/ (__  ) ,<   / /____>  </  __/ /__/ /_/ / /_/ /_/ / /
-/_/  \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
+ / / / /_/ (__  ) ,<   / /____>  </  __/ /__/ /_/ / /_/ /_/ / /
+/_/  \__,_/____/_/|_| /_____/_/|_|\___/\___/\__,_/\__/\____/_/
     """)
     logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
     settings.init_settings()
```
