Du kannst nicht mehr als 25 Themen auswählen Themen müssen mit entweder einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

extractor.py 9.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. import re
  19. from collections import defaultdict, Counter
  20. from concurrent.futures import ThreadPoolExecutor
  21. from copy import deepcopy
  22. from typing import Callable
  23. from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
  24. from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
  25. handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
  26. from rag.llm.chat_model import Base as CompletionLLM
  27. from rag.utils import truncate
# Separator used when several description fragments are stored on one node/edge.
GRAPH_FIELD_SEP = "<SEP>"
# Entity types extracted when the caller does not supply its own whitelist.
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
# NOTE(review): not referenced in this chunk — presumably caps the number of
# follow-up "gleaning" rounds during extraction; confirm against the subclasses.
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
class Extractor:
    """Extract entities and relationships from text chunks with an LLM and
    merge them into graph storage through caller-supplied get/set callbacks."""

    # LLM client used for every chat call (annotation only; assigned in __init__).
    _llm: CompletionLLM
  33. def __init__(
  34. self,
  35. llm_invoker: CompletionLLM,
  36. language: str | None = "English",
  37. entity_types: list[str] | None = None,
  38. get_entity: Callable | None = None,
  39. set_entity: Callable | None = None,
  40. get_relation: Callable | None = None,
  41. set_relation: Callable | None = None,
  42. ):
  43. self._llm = llm_invoker
  44. self._language = language
  45. self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
  46. self._get_entity_ = get_entity
  47. self._set_entity_ = set_entity
  48. self._get_relation_ = get_relation
  49. self._set_relation_ = set_relation
  50. def _chat(self, system, history, gen_conf):
  51. hist = deepcopy(history)
  52. conf = deepcopy(gen_conf)
  53. response = get_llm_cache(self._llm.llm_name, system, hist, conf)
  54. if response:
  55. return response
  56. response = self._llm.chat(system, hist, conf)
  57. response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
  58. if response.find("**ERROR**") >= 0:
  59. raise Exception(response)
  60. set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
  61. return response
  62. def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
  63. maybe_nodes = defaultdict(list)
  64. maybe_edges = defaultdict(list)
  65. ent_types = [t.lower() for t in self._entity_types]
  66. for record in records:
  67. record_attributes = split_string_by_multi_markers(
  68. record, [tuple_delimiter]
  69. )
  70. if_entities = handle_single_entity_extraction(
  71. record_attributes, chunk_key
  72. )
  73. if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
  74. maybe_nodes[if_entities["entity_name"]].append(if_entities)
  75. continue
  76. if_relation = handle_single_relationship_extraction(
  77. record_attributes, chunk_key
  78. )
  79. if if_relation is not None:
  80. maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
  81. if_relation
  82. )
  83. return dict(maybe_nodes), dict(maybe_edges)
  84. def __call__(
  85. self, chunks: list[tuple[str, str]],
  86. callback: Callable | None = None
  87. ):
  88. results = []
  89. max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
  90. with ThreadPoolExecutor(max_workers=max_workers) as exe:
  91. threads = []
  92. for i, (cid, ck) in enumerate(chunks):
  93. ck = truncate(ck, int(self._llm.max_length*0.8))
  94. threads.append(
  95. exe.submit(self._process_single_content, (cid, ck)))
  96. for i, _ in enumerate(threads):
  97. n, r, tc = _.result()
  98. if not isinstance(n, Exception):
  99. results.append((n, r))
  100. if callback:
  101. callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
  102. elif callback:
  103. callback(msg="Knowledge graph extraction error:{}".format(str(n)))
  104. maybe_nodes = defaultdict(list)
  105. maybe_edges = defaultdict(list)
  106. for m_nodes, m_edges in results:
  107. for k, v in m_nodes.items():
  108. maybe_nodes[k].extend(v)
  109. for k, v in m_edges.items():
  110. maybe_edges[tuple(sorted(k))].extend(v)
  111. logging.info("Inserting entities into storage...")
  112. all_entities_data = []
  113. with ThreadPoolExecutor(max_workers=max_workers) as exe:
  114. threads = []
  115. for en_nm, ents in maybe_nodes.items():
  116. threads.append(
  117. exe.submit(self._merge_nodes, en_nm, ents))
  118. for t in threads:
  119. n = t.result()
  120. if not isinstance(n, Exception):
  121. all_entities_data.append(n)
  122. elif callback:
  123. callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
  124. logging.info("Inserting relationships into storage...")
  125. all_relationships_data = []
  126. for (src, tgt), rels in maybe_edges.items():
  127. all_relationships_data.append(self._merge_edges(src, tgt, rels))
  128. if not len(all_entities_data) and not len(all_relationships_data):
  129. logging.warning(
  130. "Didn't extract any entities and relationships, maybe your LLM is not working"
  131. )
  132. if not len(all_entities_data):
  133. logging.warning("Didn't extract any entities")
  134. if not len(all_relationships_data):
  135. logging.warning("Didn't extract any relationships")
  136. return all_entities_data, all_relationships_data
  137. def _merge_nodes(self, entity_name: str, entities: list[dict]):
  138. if not entities:
  139. return
  140. already_entity_types = []
  141. already_source_ids = []
  142. already_description = []
  143. already_node = self._get_entity_(entity_name)
  144. if already_node:
  145. already_entity_types.append(already_node["entity_type"])
  146. already_source_ids.extend(already_node["source_id"])
  147. already_description.append(already_node["description"])
  148. entity_type = sorted(
  149. Counter(
  150. [dp["entity_type"] for dp in entities] + already_entity_types
  151. ).items(),
  152. key=lambda x: x[1],
  153. reverse=True,
  154. )[0][0]
  155. description = GRAPH_FIELD_SEP.join(
  156. sorted(set([dp["description"] for dp in entities] + already_description))
  157. )
  158. already_source_ids = flat_uniq_list(entities, "source_id")
  159. try:
  160. description = self._handle_entity_relation_summary(
  161. entity_name, description
  162. )
  163. node_data = dict(
  164. entity_type=entity_type,
  165. description=description,
  166. source_id=already_source_ids,
  167. )
  168. node_data["entity_name"] = entity_name
  169. self._set_entity_(entity_name, node_data)
  170. return node_data
  171. except Exception as e:
  172. return e
    def _merge_edges(
        self,
        src_id: str,
        tgt_id: str,
        edges_data: list[dict]
    ):
        """Merge extracted relation records with any stored relation for
        (src_id, tgt_id), persist the result, and return the edge dict.

        Weights are summed; descriptions are joined and then LLM-summarised;
        keywords and source ids are concatenated with the stored values.
        Returns None when `edges_data` is empty.
        NOTE(review): unlike _merge_nodes, exceptions here propagate to the
        caller instead of being returned — confirm that is intended.
        """
        if not edges_data:
            return
        already_weights = []
        already_source_ids = []
        already_description = []
        already_keywords = []
        relation = self._get_relation_(src_id, tgt_id)
        if relation:
            already_weights = [relation["weight"]]
            already_source_ids = relation["source_id"]
            already_description = [relation["description"]]
            already_keywords = relation["keywords"]
        # Combine the stored values with the newly extracted records.
        weight = sum([dp["weight"] for dp in edges_data] + already_weights)
        description = GRAPH_FIELD_SEP.join(
            sorted(set([dp["description"] for dp in edges_data] + already_description))
        )
        keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
        source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids
        # Ensure both endpoints exist as nodes before storing the edge;
        # missing endpoints get a placeholder 'UNKNOWN' entity.
        for need_insert_id in [src_id, tgt_id]:
            if self._get_entity_(need_insert_id):
                continue
            self._set_entity_(need_insert_id, {
                "source_id": source_id,
                "description": description,
                "entity_type": 'UNKNOWN'
            })
        # Condense the joined description via the LLM before persisting.
        description = self._handle_entity_relation_summary(
            f"({src_id}, {tgt_id})", description
        )
        edge_data = dict(
            src_id=src_id,
            tgt_id=tgt_id,
            description=description,
            keywords=keywords,
            weight=weight,
            source_id=source_id
        )
        self._set_relation_(src_id, tgt_id, edge_data)
        return edge_data
  218. def _handle_entity_relation_summary(
  219. self,
  220. entity_or_relation_name: str,
  221. description: str
  222. ) -> str:
  223. summary_max_tokens = 512
  224. use_description = truncate(description, summary_max_tokens)
  225. prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
  226. context_base = dict(
  227. entity_name=entity_or_relation_name,
  228. description_list=use_description.split(GRAPH_FIELD_SEP),
  229. language=self._language,
  230. )
  231. use_prompt = prompt_template.format(**context_base)
  232. logging.info(f"Trigger summary: {entity_or_relation_name}")
  233. summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
  234. return summary