浏览代码

EntityResolution batch. Close #6570 (#6602)

### What problem does this PR solve?

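Run entity resolution in batches: only candidate pairs that involve at least one node of the newly merged subgraph are considered, and each entity type's candidate pairs are resolved in concurrent batches of up to 100 pairs.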

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.18.0
Zhichang Yu 7 个月前
父节点
当前提交
36b62e0fab
没有帐户链接到提交者的电子邮件
共有 2 个文件被更改,包括 28 次插入 和 20 次删除
  1. 10
    4
      graphrag/entity_resolution.py
  2. 18
    16
      graphrag/general/index.py

+ 10
- 4
graphrag/entity_resolution.py 查看文件

self._resolution_result_delimiter_key = "resolution_result_delimiter" self._resolution_result_delimiter_key = "resolution_result_delimiter"
self._input_text_key = "input_text" self._input_text_key = "input_text"


async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
async def __call__(self, graph: nx.Graph,
subgraph_nodes: set[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None) -> EntityResolutionResult:
"""Call method definition.""" """Call method definition."""
if prompt_variables is None: if prompt_variables is None:
prompt_variables = {} prompt_variables = {}


candidate_resolution = {entity_type: [] for entity_type in entity_types} candidate_resolution = {entity_type: [] for entity_type in entity_types}
for k, v in node_clusters.items(): for k, v in node_clusters.items():
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()]) num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
callback(msg=f"Identified {num_candidates} candidate pairs") callback(msg=f"Identified {num_candidates} candidate pairs")


resolution_result = set() resolution_result = set()
resolution_batch_size = 100
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for candidate_resolution_i in candidate_resolution.items(): for candidate_resolution_i in candidate_resolution.items():
if not candidate_resolution_i[1]: if not candidate_resolution_i[1]:
continue continue
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
nursery.start_soon(lambda: self._resolve_candidate(candidate_batch, resolution_result))
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")


change = GraphChange() change = GraphChange()
change=change, change=change,
) )


async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]):
gen_conf = {"temperature": 0.5} gen_conf = {"temperature": 0.5}
pair_txt = [ pair_txt = [
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']

+ 18
- 16
graphrag/general/index.py 查看文件

embedding_model, embedding_model,
callback, callback,
) )
new_graph = None
if subgraph:
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
subgraph,
embedding_model,
callback,
)
if not subgraph:
return

subgraph_nodes = set(subgraph.nodes())
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
subgraph,
embedding_model,
callback,
)
assert new_graph is not None


if not with_resolution or not with_community: if not with_resolution or not with_community:
return return


if new_graph is None:
new_graph = await get_graph(tenant_id, kb_id)

if with_resolution and new_graph is not None:
if with_resolution:
await resolve_entities( await resolve_entities(
new_graph, new_graph,
subgraph_nodes,
tenant_id, tenant_id,
kb_id, kb_id,
doc_id, doc_id,
embedding_model, embedding_model,
callback, callback,
) )
if with_community and new_graph is not None:
if with_community:
await extract_community( await extract_community(
new_graph, new_graph,
tenant_id, tenant_id,


async def resolve_entities( async def resolve_entities(
graph, graph,
subgraph_nodes: set[str],
tenant_id: str, tenant_id: str,
kb_id: str, kb_id: str,
doc_id: str, doc_id: str,
er = EntityResolution( er = EntityResolution(
llm_bdl, llm_bdl,
) )
reso = await er(graph, callback=callback)
reso = await er(graph, subgraph_nodes, callback=callback)
graph = reso.graph graph = reso.graph
change = reso.change change = reso.change
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.") callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")

正在加载...
取消
保存