|
|
|
@@ -94,25 +94,52 @@ class EntityResolution(Extractor): |
|
|
|
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)] |
|
|
|
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()]) |
|
|
|
callback(msg=f"Identified {num_candidates} candidate pairs") |
|
|
|
remain_candidates_to_resolve = num_candidates |
|
|
|
|
|
|
|
resolution_result = set() |
|
|
|
resolution_result_lock = trio.Lock() |
|
|
|
resolution_batch_size = 100 |
|
|
|
max_concurrent_tasks = 5 |
|
|
|
semaphore = trio.Semaphore(max_concurrent_tasks) |
|
|
|
|
|
|
|
async def limited_resolve_candidate(candidate_batch, result_set, result_lock): |
|
|
|
nonlocal remain_candidates_to_resolve, callback |
|
|
|
async with semaphore: |
|
|
|
try: |
|
|
|
with trio.move_on_after(180) as cancel_scope: |
|
|
|
await self._resolve_candidate(candidate_batch, result_set, result_lock) |
|
|
|
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) |
|
|
|
callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ") |
|
|
|
if cancel_scope.cancelled_caught: |
|
|
|
logging.warning(f"Timeout resolving {candidate_batch}, skipping...") |
|
|
|
remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1]) |
|
|
|
callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ") |
|
|
|
except Exception as e: |
|
|
|
logging.error(f"Error resolving candidate batch: {e}") |
|
|
|
|
|
|
|
|
|
|
|
async with trio.open_nursery() as nursery: |
|
|
|
for candidate_resolution_i in candidate_resolution.items(): |
|
|
|
if not candidate_resolution_i[1]: |
|
|
|
continue |
|
|
|
for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size): |
|
|
|
candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size] |
|
|
|
nursery.start_soon(self._resolve_candidate, candidate_batch, resolution_result) |
|
|
|
nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock) |
|
|
|
|
|
|
|
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.") |
|
|
|
|
|
|
|
change = GraphChange() |
|
|
|
connect_graph = nx.Graph() |
|
|
|
connect_graph.add_edges_from(resolution_result) |
|
|
|
|
|
|
|
async def limited_merge_nodes(graph, nodes, change): |
|
|
|
async with semaphore: |
|
|
|
await self._merge_graph_nodes(graph, nodes, change) |
|
|
|
|
|
|
|
async with trio.open_nursery() as nursery: |
|
|
|
for sub_connect_graph in nx.connected_components(connect_graph): |
|
|
|
merging_nodes = list(sub_connect_graph) |
|
|
|
nursery.start_soon(self._merge_graph_nodes, graph, merging_nodes, change) |
|
|
|
nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change) |
|
|
|
|
|
|
|
# Update pagerank |
|
|
|
pr = nx.pagerank(graph) |
|
|
|
@@ -124,7 +151,7 @@ class EntityResolution(Extractor): |
|
|
|
change=change, |
|
|
|
) |
|
|
|
|
|
|
|
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str]): |
|
|
|
async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock): |
|
|
|
gen_conf = {"temperature": 0.5} |
|
|
|
pair_txt = [ |
|
|
|
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n'] |
|
|
|
@@ -142,7 +169,16 @@ class EntityResolution(Extractor): |
|
|
|
text = perform_variable_replacements(self._resolution_prompt, variables=variables) |
|
|
|
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}") |
|
|
|
async with chat_limiter: |
|
|
|
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)) |
|
|
|
try: |
|
|
|
with trio.move_on_after(120) as cancel_scope: |
|
|
|
response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf) |
|
|
|
if cancel_scope.cancelled_caught: |
|
|
|
logging.warning("_resolve_candidate._chat timeout, skipping...") |
|
|
|
return |
|
|
|
except Exception as e: |
|
|
|
logging.error(f"_resolve_candidate._chat failed: {e}") |
|
|
|
return |
|
|
|
|
|
|
|
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}") |
|
|
|
result = self._process_results(len(candidate_resolution_i[1]), response, |
|
|
|
self.prompt_variables.get(self._record_delimiter_key, |
|
|
|
@@ -151,8 +187,9 @@ class EntityResolution(Extractor): |
|
|
|
DEFAULT_ENTITY_INDEX_DELIMITER), |
|
|
|
self.prompt_variables.get(self._resolution_result_delimiter_key, |
|
|
|
DEFAULT_RESOLUTION_RESULT_DELIMITER)) |
|
|
|
for result_i in result: |
|
|
|
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1]) |
|
|
|
async with resolution_result_lock: |
|
|
|
for result_i in result: |
|
|
|
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1]) |
|
|
|
|
|
|
|
def _process_results( |
|
|
|
self, |
|
|
|
@@ -185,6 +222,7 @@ class EntityResolution(Extractor): |
|
|
|
if is_english(a) and is_english(b): |
|
|
|
if editdistance.eval(a, b) <= min(len(a), len(b)) // 2: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
if len(set(a) & set(b)) > 1: |
|
|
|
return True |