Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

entity_resolution.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import re
  18. import traceback
  19. from dataclasses import dataclass
  20. from typing import Any
  21. import networkx as nx
  22. from rag.nlp import is_english
  23. import editdistance
  24. from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
  25. from rag.llm.chat_model import Base as CompletionLLM
  26. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
  27. DEFAULT_RECORD_DELIMITER = "##"
  28. DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
  29. DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"
  30. @dataclass
  31. class EntityResolutionResult:
  32. """Entity resolution result class definition."""
  33. output: nx.Graph
  34. class EntityResolution:
  35. """Entity resolution class definition."""
  36. _llm: CompletionLLM
  37. _resolution_prompt: str
  38. _output_formatter_prompt: str
  39. _on_error: ErrorHandlerFn
  40. _record_delimiter_key: str
  41. _entity_index_delimiter_key: str
  42. _resolution_result_delimiter_key: str
  43. def __init__(
  44. self,
  45. llm_invoker: CompletionLLM,
  46. resolution_prompt: str | None = None,
  47. on_error: ErrorHandlerFn | None = None,
  48. record_delimiter_key: str | None = None,
  49. entity_index_delimiter_key: str | None = None,
  50. resolution_result_delimiter_key: str | None = None,
  51. input_text_key: str | None = None
  52. ):
  53. """Init method definition."""
  54. self._llm = llm_invoker
  55. self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
  56. self._on_error = on_error or (lambda _e, _s, _d: None)
  57. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  58. self._entity_index_dilimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
  59. self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
  60. self._input_text_key = input_text_key or "input_text"
  61. def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
  62. """Call method definition."""
  63. if prompt_variables is None:
  64. prompt_variables = {}
  65. # Wire defaults into the prompt variables
  66. prompt_variables = {
  67. **prompt_variables,
  68. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  69. or DEFAULT_RECORD_DELIMITER,
  70. self._entity_index_dilimiter_key: prompt_variables.get(self._entity_index_dilimiter_key)
  71. or DEFAULT_ENTITY_INDEX_DELIMITER,
  72. self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
  73. or DEFAULT_RESOLUTION_RESULT_DELIMITER,
  74. }
  75. nodes = graph.nodes
  76. entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
  77. node_clusters = {entity_type: [] for entity_type in entity_types}
  78. for node in nodes:
  79. node_clusters[graph.nodes[node]['entity_type']].append(node)
  80. candidate_resolution = {entity_type: [] for entity_type in entity_types}
  81. for node_cluster in node_clusters.items():
  82. candidate_resolution_tmp = []
  83. for a in node_cluster[1]:
  84. for b in node_cluster[1]:
  85. if a == b:
  86. continue
  87. if self.is_similarity(a, b) and (b, a) not in candidate_resolution_tmp:
  88. candidate_resolution_tmp.append((a, b))
  89. if candidate_resolution_tmp:
  90. candidate_resolution[node_cluster[0]] = candidate_resolution_tmp
  91. gen_conf = {"temperature": 0.5}
  92. resolution_result = set()
  93. for candidate_resolution_i in candidate_resolution.items():
  94. if candidate_resolution_i[1]:
  95. try:
  96. pair_txt = [
  97. f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
  98. for index, candidate in enumerate(candidate_resolution_i[1]):
  99. pair_txt.append(
  100. f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
  101. sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
  102. pair_txt.append(
  103. f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
  104. pair_prompt = '\n'.join(pair_txt)
  105. variables = {
  106. **prompt_variables,
  107. self._input_text_key: pair_prompt
  108. }
  109. text = perform_variable_replacements(self._resolution_prompt, variables=variables)
  110. response = self._llm.chat(text, [], gen_conf)
  111. result = self._process_results(len(candidate_resolution_i[1]), response,
  112. prompt_variables.get(self._record_delimiter_key,
  113. DEFAULT_RECORD_DELIMITER),
  114. prompt_variables.get(self._entity_index_dilimiter_key,
  115. DEFAULT_ENTITY_INDEX_DELIMITER),
  116. prompt_variables.get(self._resolution_result_delimiter_key,
  117. DEFAULT_RESOLUTION_RESULT_DELIMITER))
  118. for result_i in result:
  119. resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
  120. except Exception as e:
  121. logging.exception("error entity resolution")
  122. self._on_error(e, traceback.format_exc(), None)
  123. connect_graph = nx.Graph()
  124. connect_graph.add_edges_from(resolution_result)
  125. for sub_connect_graph in nx.connected_components(connect_graph):
  126. sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
  127. remove_nodes = list(sub_connect_graph.nodes)
  128. keep_node = remove_nodes.pop()
  129. for remove_node in remove_nodes:
  130. remove_node_neighbors = graph[remove_node]
  131. graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
  132. graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
  133. remove_node_neighbors = list(remove_node_neighbors)
  134. for remove_node_neighbor in remove_node_neighbors:
  135. if remove_node_neighbor == keep_node:
  136. graph.remove_edge(keep_node, remove_node)
  137. continue
  138. if graph.has_edge(keep_node, remove_node_neighbor):
  139. graph[keep_node][remove_node_neighbor]['weight'] += graph[remove_node][remove_node_neighbor][
  140. 'weight']
  141. graph[keep_node][remove_node_neighbor]['description'] += \
  142. graph[remove_node][remove_node_neighbor]['description']
  143. graph.remove_edge(remove_node, remove_node_neighbor)
  144. else:
  145. graph.add_edge(keep_node, remove_node_neighbor,
  146. weight=graph[remove_node][remove_node_neighbor]['weight'],
  147. description=graph[remove_node][remove_node_neighbor]['description'],
  148. source_id="")
  149. graph.remove_edge(remove_node, remove_node_neighbor)
  150. graph.remove_node(remove_node)
  151. for node_degree in graph.degree:
  152. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  153. return EntityResolutionResult(
  154. output=graph,
  155. )
  156. def _process_results(
  157. self,
  158. records_length: int,
  159. results: str,
  160. record_delimiter: str,
  161. entity_index_delimiter: str,
  162. resolution_result_delimiter: str
  163. ) -> list:
  164. ans_list = []
  165. records = [r.strip() for r in results.split(record_delimiter)]
  166. for record in records:
  167. pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
  168. match_int = re.search(pattern_int, record)
  169. res_int = int(str(match_int.group(1) if match_int else '0'))
  170. if res_int > records_length:
  171. continue
  172. pattern_bool = f"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
  173. match_bool = re.search(pattern_bool, record)
  174. res_bool = str(match_bool.group(1) if match_bool else '')
  175. if res_int and res_bool:
  176. if res_bool.lower() == 'yes':
  177. ans_list.append((res_int, "yes"))
  178. return ans_list
  179. def is_similarity(self, a, b):
  180. if is_english(a) and is_english(b):
  181. if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
  182. return True
  183. if len(set(a) & set(b)) > 0:
  184. return True
  185. return False