Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

entity_resolution.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import itertools
  18. import os
  19. import re
  20. from dataclasses import dataclass
  21. from typing import Any, Callable
  22. import networkx as nx
  23. import trio
  24. from graphrag.general.extractor import Extractor
  25. from rag.nlp import is_english
  26. import editdistance
  27. from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
  28. from rag.llm.chat_model import Base as CompletionLLM
  29. from graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
  30. DEFAULT_RECORD_DELIMITER = "##"
  31. DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
  32. DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
  33. @dataclass
  34. class EntityResolutionResult:
  35. """Entity resolution result class definition."""
  36. graph: nx.Graph
  37. change: GraphChange
  38. class EntityResolution(Extractor):
  39. """Entity resolution class definition."""
  40. _resolution_prompt: str
  41. _output_formatter_prompt: str
  42. _record_delimiter_key: str
  43. _entity_index_delimiter_key: str
  44. _resolution_result_delimiter_key: str
  45. def __init__(
  46. self,
  47. llm_invoker: CompletionLLM,
  48. ):
  49. super().__init__(llm_invoker)
  50. """Init method definition."""
  51. self._llm = llm_invoker
  52. self._resolution_prompt = ENTITY_RESOLUTION_PROMPT
  53. self._record_delimiter_key = "record_delimiter"
  54. self._entity_index_dilimiter_key = "entity_index_delimiter"
  55. self._resolution_result_delimiter_key = "resolution_result_delimiter"
  56. self._input_text_key = "input_text"
  57. async def __call__(self, graph: nx.Graph,
  58. subgraph_nodes: set[str],
  59. prompt_variables: dict[str, Any] | None = None,
  60. callback: Callable | None = None) -> EntityResolutionResult:
  61. """Call method definition."""
  62. if prompt_variables is None:
  63. prompt_variables = {}
  64. # Wire defaults into the prompt variables
  65. self.prompt_variables = {
  66. **prompt_variables,
  67. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  68. or DEFAULT_RECORD_DELIMITER,
  69. self._entity_index_dilimiter_key: prompt_variables.get(self._entity_index_dilimiter_key)
  70. or DEFAULT_ENTITY_INDEX_DELIMITER,
  71. self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
  72. or DEFAULT_RESOLUTION_RESULT_DELIMITER,
  73. }
  74. nodes = sorted(graph.nodes())
  75. entity_types = sorted(set(graph.nodes[node].get('entity_type', '-') for node in nodes))
  76. node_clusters = {entity_type: [] for entity_type in entity_types}
  77. for node in nodes:
  78. node_clusters[graph.nodes[node].get('entity_type', '-')].append(node)
  79. candidate_resolution = {entity_type: [] for entity_type in entity_types}
  80. for k, v in node_clusters.items():
  81. candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if (a in subgraph_nodes or b in subgraph_nodes) and self.is_similarity(a, b)]
  82. num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
  83. callback(msg=f"Identified {num_candidates} candidate pairs")
  84. remain_candidates_to_resolve = num_candidates
  85. resolution_result = set()
  86. resolution_result_lock = trio.Lock()
  87. resolution_batch_size = 100
  88. max_concurrent_tasks = 5
  89. semaphore = trio.Semaphore(max_concurrent_tasks)
  90. async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
  91. nonlocal remain_candidates_to_resolve, callback
  92. async with semaphore:
  93. try:
  94. enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
  95. with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
  96. await self._resolve_candidate(candidate_batch, result_set, result_lock)
  97. remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
  98. callback(msg=f"Resolved {len(candidate_batch[1])} pairs, {remain_candidates_to_resolve} are remained to resolve. ")
  99. if cancel_scope.cancelled_caught:
  100. logging.warning(f"Timeout resolving {candidate_batch}, skipping...")
  101. remain_candidates_to_resolve = remain_candidates_to_resolve - len(candidate_batch[1])
  102. callback(msg=f"Fail to resolved {len(candidate_batch[1])} pairs due to timeout reason, skipped. {remain_candidates_to_resolve} are remained to resolve. ")
  103. except Exception as e:
  104. logging.error(f"Error resolving candidate batch: {e}")
  105. async with trio.open_nursery() as nursery:
  106. for candidate_resolution_i in candidate_resolution.items():
  107. if not candidate_resolution_i[1]:
  108. continue
  109. for i in range(0, len(candidate_resolution_i[1]), resolution_batch_size):
  110. candidate_batch = candidate_resolution_i[0], candidate_resolution_i[1][i:i + resolution_batch_size]
  111. nursery.start_soon(limited_resolve_candidate, candidate_batch, resolution_result, resolution_result_lock)
  112. callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
  113. change = GraphChange()
  114. connect_graph = nx.Graph()
  115. connect_graph.add_edges_from(resolution_result)
  116. async def limited_merge_nodes(graph, nodes, change):
  117. async with semaphore:
  118. await self._merge_graph_nodes(graph, nodes, change)
  119. async with trio.open_nursery() as nursery:
  120. for sub_connect_graph in nx.connected_components(connect_graph):
  121. merging_nodes = list(sub_connect_graph)
  122. nursery.start_soon(limited_merge_nodes, graph, merging_nodes, change)
  123. # Update pagerank
  124. pr = nx.pagerank(graph)
  125. for node_name, pagerank in pr.items():
  126. graph.nodes[node_name]["pagerank"] = pagerank
  127. return EntityResolutionResult(
  128. graph=graph,
  129. change=change,
  130. )
  131. async def _resolve_candidate(self, candidate_resolution_i: tuple[str, list[tuple[str, str]]], resolution_result: set[str], resolution_result_lock: trio.Lock):
  132. pair_txt = [
  133. f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
  134. for index, candidate in enumerate(candidate_resolution_i[1]):
  135. pair_txt.append(
  136. f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
  137. sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
  138. pair_txt.append(
  139. f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
  140. pair_prompt = '\n'.join(pair_txt)
  141. variables = {
  142. **self.prompt_variables,
  143. self._input_text_key: pair_prompt
  144. }
  145. text = perform_variable_replacements(self._resolution_prompt, variables=variables)
  146. logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
  147. async with chat_limiter:
  148. try:
  149. enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
  150. with trio.move_on_after(280 if enable_timeout_assertion else 1000000000) as cancel_scope:
  151. response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], {})
  152. if cancel_scope.cancelled_caught:
  153. logging.warning("_resolve_candidate._chat timeout, skipping...")
  154. return
  155. except Exception as e:
  156. logging.error(f"_resolve_candidate._chat failed: {e}")
  157. return
  158. logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
  159. result = self._process_results(len(candidate_resolution_i[1]), response,
  160. self.prompt_variables.get(self._record_delimiter_key,
  161. DEFAULT_RECORD_DELIMITER),
  162. self.prompt_variables.get(self._entity_index_dilimiter_key,
  163. DEFAULT_ENTITY_INDEX_DELIMITER),
  164. self.prompt_variables.get(self._resolution_result_delimiter_key,
  165. DEFAULT_RESOLUTION_RESULT_DELIMITER))
  166. async with resolution_result_lock:
  167. for result_i in result:
  168. resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
  169. def _process_results(
  170. self,
  171. records_length: int,
  172. results: str,
  173. record_delimiter: str,
  174. entity_index_delimiter: str,
  175. resolution_result_delimiter: str
  176. ) -> list:
  177. ans_list = []
  178. records = [r.strip() for r in results.split(record_delimiter)]
  179. for record in records:
  180. pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
  181. match_int = re.search(pattern_int, record)
  182. res_int = int(str(match_int.group(1) if match_int else '0'))
  183. if res_int > records_length:
  184. continue
  185. pattern_bool = f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
  186. match_bool = re.search(pattern_bool, record)
  187. res_bool = str(match_bool.group(1) if match_bool else '')
  188. if res_int and res_bool:
  189. if res_bool.lower() == 'yes':
  190. ans_list.append((res_int, "yes"))
  191. return ans_list
  192. def _has_digit_in_2gram_diff(self, a, b):
  193. def to_2gram_set(s):
  194. return {s[i:i+2] for i in range(len(s) - 1)}
  195. set_a = to_2gram_set(a)
  196. set_b = to_2gram_set(b)
  197. diff = set_a ^ set_b
  198. return any(any(c.isdigit() for c in pair) for pair in diff)
  199. def is_similarity(self, a, b):
  200. if self._has_digit_in_2gram_diff(a, b):
  201. return False
  202. if is_english(a) and is_english(b):
  203. if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
  204. return True
  205. return False
  206. a, b = set(a), set(b)
  207. max_l = max(len(a), len(b))
  208. if max_l < 4:
  209. return len(a & b) > 1
  210. return len(a & b)*1./max_l >= 0.8