entity_resolution.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import itertools
import re
import traceback
from dataclasses import dataclass
from typing import Any

import editdistance
import networkx as nx

from graphrag.extractor import Extractor
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.nlp import is_english

DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
DEFAULT_RESOLUTION_RESULT_DELIMITER = "&&"


@dataclass
class EntityResolutionResult:
    """Entity resolution result class definition."""

    output: nx.Graph


class EntityResolution(Extractor):
    """Entity resolution class definition."""

    _resolution_prompt: str
    _output_formatter_prompt: str
    _on_error: ErrorHandlerFn
    _record_delimiter_key: str
    _entity_index_delimiter_key: str
    _resolution_result_delimiter_key: str

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        resolution_prompt: str | None = None,
        on_error: ErrorHandlerFn | None = None,
        record_delimiter_key: str | None = None,
        entity_index_delimiter_key: str | None = None,
        resolution_result_delimiter_key: str | None = None,
        input_text_key: str | None = None
    ):
        """Init method definition."""
        self._llm = llm_invoker
        self._resolution_prompt = resolution_prompt or ENTITY_RESOLUTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._entity_index_delimiter_key = entity_index_delimiter_key or "entity_index_delimiter"
        self._resolution_result_delimiter_key = resolution_result_delimiter_key or "resolution_result_delimiter"
        self._input_text_key = input_text_key or "input_text"

    def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}

        # Wire defaults into the prompt variables
        prompt_variables = {
            **prompt_variables,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._entity_index_delimiter_key: prompt_variables.get(self._entity_index_delimiter_key)
            or DEFAULT_ENTITY_INDEX_DELIMITER,
            self._resolution_result_delimiter_key: prompt_variables.get(self._resolution_result_delimiter_key)
            or DEFAULT_RESOLUTION_RESULT_DELIMITER,
        }

        # Group nodes by entity type, then pair up similar-looking names within
        # each type as candidates for resolution.
        nodes = graph.nodes
        entity_types = list(set(graph.nodes[node]['entity_type'] for node in nodes))
        node_clusters = {entity_type: [] for entity_type in entity_types}
        for node in nodes:
            node_clusters[graph.nodes[node]['entity_type']].append(node)

        candidate_resolution = {entity_type: [] for entity_type in entity_types}
        for k, v in node_clusters.items():
            candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]

        gen_conf = {"temperature": 0.5}
        resolution_result = set()
        for entity_type, candidates in candidate_resolution.items():
            if not candidates:
                continue
            try:
                # Ask the LLM to judge all candidate pairs of this entity type in a single call.
                pair_txt = [
                    f'When determining whether two {entity_type}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
                for index, candidate in enumerate(candidates):
                    pair_txt.append(
                        f'Question {index + 1}: name of {entity_type} A is {candidate[0]}, name of {entity_type} B is {candidate[1]}')
                sent = 'question above' if len(candidates) == 1 else f'above {len(candidates)} questions'
                pair_txt.append(
                    f'\nUse domain knowledge of {entity_type}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {entity_type} A and {entity_type} B are the same {entity_type}./No, {entity_type} A and {entity_type} B are different {entity_type}s. For Question i+1, (repeat the above procedures)')
                pair_prompt = '\n'.join(pair_txt)
                variables = {
                    **prompt_variables,
                    self._input_text_key: pair_prompt
                }
                text = perform_variable_replacements(self._resolution_prompt, variables=variables)
                response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
                # Keep only the pairs the LLM judged to be the same entity.
                result = self._process_results(len(candidates), response,
                                               prompt_variables.get(self._record_delimiter_key,
                                                                    DEFAULT_RECORD_DELIMITER),
                                               prompt_variables.get(self._entity_index_delimiter_key,
                                                                    DEFAULT_ENTITY_INDEX_DELIMITER),
                                               prompt_variables.get(self._resolution_result_delimiter_key,
                                                                    DEFAULT_RESOLUTION_RESULT_DELIMITER))
                for result_i in result:
                    resolution_result.add(candidates[result_i[0] - 1])
            except Exception as e:
                logging.exception("error entity resolution")
                self._on_error(e, traceback.format_exc(), None)

        # Merge each connected component of resolved duplicates into a single
        # node, folding descriptions, weights and edges into the kept node.
        connect_graph = nx.Graph()
        connect_graph.add_edges_from(resolution_result)
        for sub_connect_graph in nx.connected_components(connect_graph):
            sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
            remove_nodes = list(sub_connect_graph.nodes)
            keep_node = remove_nodes.pop()
            for remove_node in remove_nodes:
                remove_node_neighbors = graph[remove_node]
                graph.nodes[keep_node]['description'] += graph.nodes[remove_node]['description']
                graph.nodes[keep_node]['weight'] += graph.nodes[remove_node]['weight']
                remove_node_neighbors = list(remove_node_neighbors)
                for remove_node_neighbor in remove_node_neighbors:
                    if remove_node_neighbor == keep_node:
                        graph.remove_edge(keep_node, remove_node)
                        continue
                    if graph.has_edge(keep_node, remove_node_neighbor):
                        # The kept node already has this edge: accumulate the duplicate's weight and description.
                        graph[keep_node][remove_node_neighbor]['weight'] += \
                            graph[remove_node][remove_node_neighbor]['weight']
                        graph[keep_node][remove_node_neighbor]['description'] += \
                            graph[remove_node][remove_node_neighbor]['description']
                        graph.remove_edge(remove_node, remove_node_neighbor)
                    else:
                        # Re-attach the duplicate's edge to the kept node.
                        graph.add_edge(keep_node, remove_node_neighbor,
                                       weight=graph[remove_node][remove_node_neighbor]['weight'],
                                       description=graph[remove_node][remove_node_neighbor]['description'],
                                       source_id="")
                        graph.remove_edge(remove_node, remove_node_neighbor)
                graph.remove_node(remove_node)

        # Rank remaining nodes by degree.
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])

        return EntityResolutionResult(
            output=graph,
        )

    def _process_results(
        self,
        records_length: int,
        results: str,
        record_delimiter: str,
        entity_index_delimiter: str,
        resolution_result_delimiter: str
    ) -> list:
        ans_list = []
        records = [r.strip() for r in results.split(record_delimiter)]
        for record in records:
            # Extract the question index, e.g. "<|>3<|>" with the default delimiter.
            pattern_int = rf"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
            match_int = re.search(pattern_int, record)
            res_int = int(str(match_int.group(1) if match_int else '0'))
            if res_int > records_length:
                continue
            # Extract the yes/no verdict, e.g. "&&yes&&" with the default delimiter.
            pattern_bool = rf"{re.escape(resolution_result_delimiter)}([a-zA-Z]+){re.escape(resolution_result_delimiter)}"
            match_bool = re.search(pattern_bool, record)
            res_bool = str(match_bool.group(1) if match_bool else '')
            if res_int and res_bool:
                if res_bool.lower() == 'yes':
                    ans_list.append((res_int, "yes"))
        return ans_list

    def is_similarity(self, a, b):
        # Two English names count as similar when their edit distance is at most
        # half the length of the shorter name; otherwise fall through to checking
        # whether the names share any character (useful for CJK names).
        if is_english(a) and is_english(b):
            if editdistance.eval(a, b) <= min(len(a), len(b)) // 2:
                return True
        if len(set(a) & set(b)) > 0:
            return True
        return False
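
For reference, here is a minimal, standalone sketch of the answer format that _process_results parses, using the default delimiters defined above ("##", "<|>", "&&"). The sample response string is invented for illustration only; the real text is produced by the LLM following ENTITY_RESOLUTION_PROMPT.

import re

RECORD_DELIM = "##"    # DEFAULT_RECORD_DELIMITER
INDEX_DELIM = "<|>"    # DEFAULT_ENTITY_INDEX_DELIMITER
RESULT_DELIM = "&&"    # DEFAULT_RESOLUTION_RESULT_DELIMITER

# Hypothetical LLM output: one record per candidate pair, separated by "##".
response = (
    "For Question <|>1<|>, &&yes&& the two organizations are the same.##"
    "For Question <|>2<|>, &&no&& they are different organizations."
)

same_pairs = []
for record in (r.strip() for r in response.split(RECORD_DELIM)):
    idx = re.search(rf"{re.escape(INDEX_DELIM)}(\d+){re.escape(INDEX_DELIM)}", record)
    verdict = re.search(rf"{re.escape(RESULT_DELIM)}([a-zA-Z]+){re.escape(RESULT_DELIM)}", record)
    if idx and verdict and verdict.group(1).lower() == "yes":
        same_pairs.append(int(idx.group(1)))

print(same_pairs)  # [1] -> only candidate pair 1 was judged to be the same entity

The indices collected this way are 1-based question numbers; __call__ maps them back to the candidate pair list with candidates[result_i[0] - 1] before merging the corresponding nodes.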