You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

extractor.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import re
  18. from collections import defaultdict, Counter
  19. from copy import deepcopy
  20. from typing import Callable
  21. import trio
  22. import networkx as nx
  23. from api.utils.api_utils import timeout
  24. from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
  25. from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
  26. handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
  27. from rag.llm.chat_model import Base as CompletionLLM
  28. from rag.prompts import message_fit_in
  29. from rag.utils import truncate
# Delimiter used to join multiple description fragments into a single field.
GRAPH_FIELD_SEP = "<SEP>"
# Entity types used when the caller does not supply its own list (see Extractor.__init__).
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
# NOTE(review): not referenced anywhere in this file chunk — presumably consumed
# by importers of this module; confirm before removing.
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
  33. class Extractor:
  34. _llm: CompletionLLM
  35. def __init__(
  36. self,
  37. llm_invoker: CompletionLLM,
  38. language: str | None = "English",
  39. entity_types: list[str] | None = None,
  40. ):
  41. self._llm = llm_invoker
  42. self._language = language
  43. self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
  44. @timeout(60*5)
  45. def _chat(self, system, history, gen_conf={}):
  46. hist = deepcopy(history)
  47. conf = deepcopy(gen_conf)
  48. response = get_llm_cache(self._llm.llm_name, system, hist, conf)
  49. if response:
  50. return response
  51. _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
  52. for attempt in range(3):
  53. try:
  54. response = self._llm.chat(system_msg[0]["content"], hist, conf)
  55. response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
  56. if response.find("**ERROR**") >= 0:
  57. raise Exception(response)
  58. set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
  59. except Exception as e:
  60. logging.exception(e)
  61. if attempt == 2:
  62. raise
  63. return response
  64. def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
  65. maybe_nodes = defaultdict(list)
  66. maybe_edges = defaultdict(list)
  67. ent_types = [t.lower() for t in self._entity_types]
  68. for record in records:
  69. record_attributes = split_string_by_multi_markers(
  70. record, [tuple_delimiter]
  71. )
  72. if_entities = handle_single_entity_extraction(
  73. record_attributes, chunk_key
  74. )
  75. if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
  76. maybe_nodes[if_entities["entity_name"]].append(if_entities)
  77. continue
  78. if_relation = handle_single_relationship_extraction(
  79. record_attributes, chunk_key
  80. )
  81. if if_relation is not None:
  82. maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
  83. if_relation
  84. )
  85. return dict(maybe_nodes), dict(maybe_edges)
    async def __call__(
        self, doc_id: str, chunks: list[str],
        callback: Callable | None = None
    ):
        """Run extraction over all chunks concurrently, then merge duplicates.

        Returns (all_entities_data, all_relationships_data): merged entity
        dicts and merged relationship dicts. `callback`, when given, receives
        progress messages via keyword `msg`.
        """
        self.callback = callback
        start_ts = trio.current_time()
        out_results = []
        # One concurrent task per chunk; each task appends its
        # (nodes, edges, token_count) result to out_results.
        # NOTE(review): _process_single_content is not defined in this chunk —
        # presumably defined elsewhere in the class; confirm its signature.
        async with trio.open_nursery() as nursery:
            for i, ck in enumerate(chunks):
                # Keep each chunk within ~80% of the model's context window.
                ck = truncate(ck, int(self._llm.max_length*0.8))
                nursery.start_soon(self._process_single_content, (doc_id, ck), i, len(chunks), out_results)
        # Collect per-chunk results into global candidate maps.
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        sum_token_count = 0
        for m_nodes, m_edges, token_count in out_results:
            for k, v in m_nodes.items():
                maybe_nodes[k].extend(v)
            for k, v in m_edges.items():
                # Canonicalize edge keys so (a, b) and (b, a) merge together.
                maybe_edges[tuple(sorted(k))].extend(v)
            sum_token_count += token_count
        now = trio.current_time()
        if callback:
            callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
        start_ts = now
        logging.info("Entities merging...")
        # Merge duplicate entities concurrently; results land in all_entities_data.
        all_entities_data = []
        async with trio.open_nursery() as nursery:
            for en_nm, ents in maybe_nodes.items():
                nursery.start_soon(self._merge_nodes, en_nm, ents, all_entities_data)
        now = trio.current_time()
        if callback:
            callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
        start_ts = now
        logging.info("Relationships merging...")
        # Merge duplicate relationships concurrently.
        all_relationships_data = []
        async with trio.open_nursery() as nursery:
            for (src, tgt), rels in maybe_edges.items():
                nursery.start_soon(self._merge_edges, src, tgt, rels, all_relationships_data)
        now = trio.current_time()
        if callback:
            callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
        # Empty results usually indicate an LLM problem — warn but still return.
        if not len(all_entities_data) and not len(all_relationships_data):
            logging.warning(
                "Didn't extract any entities and relationships, maybe your LLM is not working"
            )
        if not len(all_entities_data):
            logging.warning("Didn't extract any entities")
        if not len(all_relationships_data):
            logging.warning("Didn't extract any relationships")
        return all_entities_data, all_relationships_data
  136. async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
  137. if not entities:
  138. return
  139. entity_type = sorted(
  140. Counter(
  141. [dp["entity_type"] for dp in entities]
  142. ).items(),
  143. key=lambda x: x[1],
  144. reverse=True,
  145. )[0][0]
  146. description = GRAPH_FIELD_SEP.join(
  147. sorted(set([dp["description"] for dp in entities]))
  148. )
  149. already_source_ids = flat_uniq_list(entities, "source_id")
  150. description = await self._handle_entity_relation_summary(entity_name, description)
  151. node_data = dict(
  152. entity_type=entity_type,
  153. description=description,
  154. source_id=already_source_ids,
  155. )
  156. node_data["entity_name"] = entity_name
  157. all_relationships_data.append(node_data)
  158. async def _merge_edges(
  159. self,
  160. src_id: str,
  161. tgt_id: str,
  162. edges_data: list[dict],
  163. all_relationships_data=None
  164. ):
  165. if not edges_data:
  166. return
  167. weight = sum([edge["weight"] for edge in edges_data])
  168. description = GRAPH_FIELD_SEP.join(sorted(set([edge["description"] for edge in edges_data])))
  169. description = await self._handle_entity_relation_summary(f"{src_id} -> {tgt_id}", description)
  170. keywords = flat_uniq_list(edges_data, "keywords")
  171. source_id = flat_uniq_list(edges_data, "source_id")
  172. edge_data = dict(
  173. src_id=src_id,
  174. tgt_id=tgt_id,
  175. description=description,
  176. keywords=keywords,
  177. weight=weight,
  178. source_id=source_id
  179. )
  180. all_relationships_data.append(edge_data)
    async def _merge_graph_nodes(self, graph: nx.Graph, nodes: list[str], change: GraphChange):
        """Merge a group of duplicate graph nodes into the first one, in place.

        nodes[0] survives; nodes[1:] are removed from `graph` and their
        descriptions, source ids and incident edges are folded into nodes[0].
        `change` is updated to record every added/updated/removed node and edge.
        """
        if len(nodes) <= 1:
            return
        change.added_updated_nodes.add(nodes[0])
        change.removed_nodes.update(nodes[1:])
        nodes_set = set(nodes)
        node0_attrs = graph.nodes[nodes[0]]
        # Snapshot of nodes[0]'s neighbors BEFORE merging starts; edges added
        # below do not extend it, so "merge vs. move" is decided against the
        # original neighborhood.
        node0_neighbors = set(graph.neighbors(nodes[0]))
        for node1 in nodes[1:]:
            # Merge two nodes, keep "entity_name", "entity_type", "page_rank" unchanged.
            node1_attrs = graph.nodes[node1]
            node0_attrs["description"] += f"{GRAPH_FIELD_SEP}{node1_attrs['description']}"
            node0_attrs["source_id"] = sorted(set(node0_attrs["source_id"] + node1_attrs["source_id"]))
            for neighbor in graph.neighbors(node1):
                # Every edge touching node1 disappears with it.
                change.removed_edges.add(get_from_to(node1, neighbor))
                # Edges between members of the merge group are dropped entirely
                # (they would become self-loops on nodes[0]).
                if neighbor not in nodes_set:
                    edge1_attrs = graph.get_edge_data(node1, neighbor)
                    if neighbor in node0_neighbors:
                        # Merge two edges
                        change.added_updated_edges.add(get_from_to(nodes[0], neighbor))
                        edge0_attrs = graph.get_edge_data(nodes[0], neighbor)
                        edge0_attrs["weight"] += edge1_attrs["weight"]
                        edge0_attrs["description"] += f"{GRAPH_FIELD_SEP}{edge1_attrs['description']}"
                        for attr in ["keywords", "source_id"]:
                            edge0_attrs[attr] = sorted(set(edge0_attrs[attr] + edge1_attrs[attr]))
                        # Re-summarize the combined edge description via the LLM.
                        edge0_attrs["description"] = await self._handle_entity_relation_summary(f"({nodes[0]}, {neighbor})", edge0_attrs["description"])
                        graph.add_edge(nodes[0], neighbor, **edge0_attrs)
                    else:
                        # Re-attach node1's edge to nodes[0] unchanged.
                        graph.add_edge(nodes[0], neighbor, **edge1_attrs)
            graph.remove_node(node1)
        # Condense the accumulated description of the surviving node.
        node0_attrs["description"] = await self._handle_entity_relation_summary(nodes[0], node0_attrs["description"])
        graph.nodes[nodes[0]].update(node0_attrs)
  213. async def _handle_entity_relation_summary(
  214. self,
  215. entity_or_relation_name: str,
  216. description: str
  217. ) -> str:
  218. summary_max_tokens = 512
  219. use_description = truncate(description, summary_max_tokens)
  220. description_list=use_description.split(GRAPH_FIELD_SEP),
  221. if len(description_list) <= 12:
  222. return use_description
  223. prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
  224. context_base = dict(
  225. entity_name=entity_or_relation_name,
  226. description_list=description_list,
  227. language=self._language,
  228. )
  229. use_prompt = prompt_template.format(**context_base)
  230. logging.info(f"Trigger summary: {entity_or_relation_name}")
  231. async with chat_limiter:
  232. summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}]))
  233. return summary