Pārlūkot izejas kodu

Refa: GraphRAG and explaining GraphRAG stalling behavior on large files (#8223)

### What problem does this PR solve?

This PR investigates the cause of #7957.

TL;DR: Incorrect similarity calculations lead to too many candidates.
Since candidate selection involves interaction with the LLM, this causes
significant delays in the program.

What this PR does:

1. **Fix similarity calculation**:
When processing a 64 pages government document, the corrected similarity
calculation reduces the number of candidates from over 100,000 to around
16,000. With a default batch size of 100 pairs per LLM call, this fix
reduces unnecessary LLM interactions from over 1,000 calls to around
160, a roughly 10x improvement.
2. **Add concurrency and timeout limits**: 
Up to 5 entity types are processed in "parallel", each with a 180-second
timeout. These limits may be configurable in future updates.
3. **Improve logging**:
The candidate resolution process now reports progress in real time.
4. **Mitigates potential concurrency risks**


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
tags/v0.19.1
Yongteng Lei pirms 4 mēnešiem
vecāks
revīzija
24ca4cc6b7
Revīzijas autora e-pasta adrese nav piesaistīta nevienam kontam

+ 44
- 6
graphrag/entity_resolution.py Parādīt failu

@@ -94,25 +94,52 @@ class EntityResolution(Extractor):
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
callback(msg=f"Identified {num_candidates} candidate pairs")
remain_candidates_to_resolve = num_candidates

resolution_result = set()
resolution_result_lock = trio.Lock()
resolution_batch_size = 100
max_concurrent_tasks = 5
semaphore = trio.Semaphore(max_concurrent_tasks)

async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
nonlocal remain_candidates_to_resolve, callback
async with semaphore:
try:
with trio.move_on_after(180) as cancel_scope:
await self._resolve_candidate(candidate_batch, result_set, result_lock)
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
if cancel_scope.cancelled_caught:
logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
except Exception as e:
logging.error(f"Error resolving candidate batch: {e}")


async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]:
continue
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result)
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)

callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")

change = GraphChange()
connect_graph = nx.Graph()
connect_graph.add_edges_from(resolution_result)

async def limited_merge_nodes(graph, nodes, change):
async with semaphore:
await self._merge_graph_nodes(graph, nodes, change)

async with trio.open_nursery() as nursery:
for sub_connect_graph in nx.connected_components(connect_graph):
merging_nodes = list(sub_connect_graph)
nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change)
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)

# Update pagerank
pr = nx.pagerank(graph)
@@ -124,7 +151,7 @@ class EntityResolution(Extractor):
change=change,
)

async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]):
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
gen_conf = {"temperature": 0.5}
pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
@@ -142,7 +169,16 @@ class EntityResolution(Extractor):
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
try:
with trio.move_on_after(120) as cancel_scope:
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
if cancel_scope.cancelled_caught:
logging.warning("_resolve_candidate._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}")
return

logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
result = self._process_results(len(candidate_resolution_i[1]), response,
self.prompt_variables.get(self._record_delimiter_key,
@@ -151,8 +187,9 @@ class EntityResolution(Extractor):
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
async with resolution_result_lock:
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])

def _process_results(
self,
@@ -185,6 +222,7 @@ class EntityResolution(Extractor):
if is_english(a) and is_english(b):
if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
return True
return False

if len(set(a) & set(b)) > 1:
return True

+ 9
- 1
graphrag/general/community_reports_extractor.py Parādīt failu

@@ -89,7 +89,15 @@ class CommunityReportsExtractor(Extractor):
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
gen_conf = {"temperature": 0.3}
async with chat_limiter:
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
try:
with trio.move_on_after(120) as cancel_scope:
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
if cancel_scope.cancelled_caught:
logging.warning("extract_community_report._chat timeout, skipping...")
return
except Exception as e:
logging.error(f"extract_community_report._chat failed: {e}")
return
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)
response = re.sub(r"[^\}]*$", "", response)

+ 37
- 9
rag/svr/task_executor.py Parādīt failu

@@ -21,6 +21,8 @@ import sys
import threading
import time

from valkey import RedisError

from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -187,18 +189,44 @@ async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
global UNACKED_ITERATOR
svr_queue_names = get_svr_queue_names()
redis_msg = None

try:
if not UNACKED_ITERATOR:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
try:
redis_msg = next(UNACKED_ITERATOR)
except StopIteration:
UNACKED_ITERATOR = None
logging.debug("Rebuilding UNACKED_ITERATOR due to it is None")
try:
UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
logging.debug("UNACKED_ITERATOR rebuilt successfully")
except RedisError as e:
UNACKED_ITERATOR = None
logging.warning(f"Failed to rebuild UNACKED_ITERATOR: {e}")

if UNACKED_ITERATOR:
try:
redis_msg = next(UNACKED_ITERATOR)
except StopIteration:
UNACKED_ITERATOR = None
logging.debug("UNACKED_ITERATOR exhausted, clearing")

except Exception as e:
UNACKED_ITERATOR = None
logging.warning(f"UNACKED_ITERATOR raised exception: {e}")

if not redis_msg:
for svr_queue_name in svr_queue_names:
redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
if redis_msg:
break
except Exception:
logging.exception("collect got exception")
try:
redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
if redis_msg:
break
except RedisError as e:
logging.warning(f"queue_consumer failed for {svr_queue_name}: {e}")
continue

except Exception as e:
logging.exception(f"collect task encountered unexpected exception: {e}")
UNACKED_ITERATOR = None
await trio.sleep(1)
return None, None

if not redis_msg:

Notiek ielāde…
Atcelt
Saglabāt