
extractor.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import os
import re
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from typing import Callable

from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import truncate

GRAPH_FIELD_SEP = "<SEP>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 2
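
# LLM-driven entity/relation extraction over text chunks. Storage access is
# injected through the get_/set_ callables, keeping the class agnostic of the
# backing store. Note that __call__ submits self._process_single_content,
# which is not defined in this file; a subclass is expected to provide it.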
class Extractor:
    _llm: CompletionLLM

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        language: str | None = "English",
        entity_types: list[str] | None = None,
        get_entity: Callable | None = None,
        set_entity: Callable | None = None,
        get_relation: Callable | None = None,
        set_relation: Callable | None = None,
    ):
        self._llm = llm_invoker
        self._language = language
        self._entity_types = entity_types or DEFAULT_ENTITY_TYPES
        self._get_entity_ = get_entity
        self._set_entity_ = set_entity
        self._get_relation_ = get_relation
        self._set_relation_ = set_relation
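
    # LLM round-trip with a read-through cache: reuse a cached response when
    # available; otherwise call the model, strip any <think>...</think>
    # reasoning block, surface "**ERROR**" responses as exceptions, and cache
    # the cleaned result.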
    def _chat(self, system, history, gen_conf):
        hist = deepcopy(history)
        conf = deepcopy(gen_conf)
        response = get_llm_cache(self._llm.llm_name, system, hist, conf)
        if response:
            return response
        response = self._llm.chat(system, hist, conf)
        response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
        return response
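
    # Split each prompt-formatted record on the tuple delimiter and try to
    # interpret it first as an entity (kept only if its type is one of the
    # configured entity types), then as a relationship. Returns candidate
    # nodes keyed by entity name and candidate edges keyed by (src_id, tgt_id).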
    def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        ent_types = [t.lower() for t in self._entity_types]
        for record in records:
            record_attributes = split_string_by_multi_markers(
                record, [tuple_delimiter]
            )
            if_entities = handle_single_entity_extraction(
                record_attributes, chunk_key
            )
            if if_entities is not None and if_entities.get("entity_type", "unknown").lower() in ent_types:
                maybe_nodes[if_entities["entity_name"]].append(if_entities)
                continue

            if_relation = handle_single_relationship_extraction(
                record_attributes, chunk_key
            )
            if if_relation is not None:
                maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
                    if_relation
                )
        return dict(maybe_nodes), dict(maybe_edges)
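
    # Extraction pipeline: truncate each chunk to ~80% of the model's context
    # budget, process chunks concurrently, pool the per-chunk node/edge
    # candidates (edge keys sorted so (a, b) and (b, a) collapse together),
    # and merge duplicates into storage via _merge_nodes/_merge_edges.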
    def __call__(
        self, chunks: list[tuple[str, str]],
        callback: Callable | None = None
    ):
        results = []
        max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 50))
        with ThreadPoolExecutor(max_workers=max_workers) as exe:
            threads = []
            for i, (cid, ck) in enumerate(chunks):
                ck = truncate(ck, int(self._llm.max_length * 0.8))
                threads.append(
                    exe.submit(self._process_single_content, (cid, ck)))

            for i, _ in enumerate(threads):
                n, r, tc = _.result()
                if not isinstance(n, Exception):
                    results.append((n, r))
                    if callback:
                        callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
                elif callback:
                    callback(msg="Knowledge graph extraction error:{}".format(str(n)))

        maybe_nodes = defaultdict(list)
        maybe_edges = defaultdict(list)
        for m_nodes, m_edges in results:
            for k, v in m_nodes.items():
                maybe_nodes[k].extend(v)
            for k, v in m_edges.items():
                maybe_edges[tuple(sorted(k))].extend(v)

        logging.info("Inserting entities into storage...")
        all_entities_data = []
        for en_nm, ents in maybe_nodes.items():
            all_entities_data.append(self._merge_nodes(en_nm, ents))

        logging.info("Inserting relationships into storage...")
        all_relationships_data = []
        for (src, tgt), rels in maybe_edges.items():
            all_relationships_data.append(self._merge_edges(src, tgt, rels))

        if not len(all_entities_data) and not len(all_relationships_data):
            logging.warning(
                "Didn't extract any entities and relationships, maybe your LLM is not working"
            )
        if not len(all_entities_data):
            logging.warning("Didn't extract any entities")
        if not len(all_relationships_data):
            logging.warning("Didn't extract any relationships")

        return all_entities_data, all_relationships_data
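
    # Collapse all mentions of one entity, together with any already-stored
    # node, into a single record: the most frequent entity_type wins, and the
    # deduplicated descriptions are joined with GRAPH_FIELD_SEP and condensed
    # by the LLM summarizer.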
    def _merge_nodes(self, entity_name: str, entities: list[dict]):
        if not entities:
            return
        already_entity_types = []
        already_source_ids = []
        already_description = []

        already_node = self._get_entity_(entity_name)
        if already_node:
            already_entity_types.append(already_node["entity_type"])
            already_source_ids.extend(already_node["source_id"])
            already_description.append(already_node["description"])

        entity_type = sorted(
            Counter(
                [dp["entity_type"] for dp in entities] + already_entity_types
            ).items(),
            key=lambda x: x[1],
            reverse=True,
        )[0][0]
        description = GRAPH_FIELD_SEP.join(
            sorted(set([dp["description"] for dp in entities] + already_description))
        )
        already_source_ids = flat_uniq_list(entities, "source_id")
        description = self._handle_entity_relation_summary(
            entity_name, description
        )
        node_data = dict(
            entity_type=entity_type,
            description=description,
            source_id=already_source_ids,
        )
        node_data["entity_name"] = entity_name
        self._set_entity_(entity_name, node_data)
        return node_data
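
    # Collapse duplicate edges between src_id and tgt_id, together with any
    # stored relation: weights are summed; descriptions, keywords, and source
    # ids are deduplicated; endpoints missing from storage get placeholder
    # "UNKNOWN" entities before the merged edge is written back.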
    def _merge_edges(
        self,
        src_id: str,
        tgt_id: str,
        edges_data: list[dict]
    ):
        if not edges_data:
            return
        already_weights = []
        already_source_ids = []
        already_description = []
        already_keywords = []

        relation = self._get_relation_(src_id, tgt_id)
        if relation:
            already_weights = [relation["weight"]]
            already_source_ids = relation["source_id"]
            already_description = [relation["description"]]
            already_keywords = relation["keywords"]

        weight = sum([dp["weight"] for dp in edges_data] + already_weights)
        description = GRAPH_FIELD_SEP.join(
            sorted(set([dp["description"] for dp in edges_data] + already_description))
        )
        keywords = flat_uniq_list(edges_data, "keywords") + already_keywords
        source_id = flat_uniq_list(edges_data, "source_id") + already_source_ids

        for need_insert_id in [src_id, tgt_id]:
            if self._get_entity_(need_insert_id):
                continue
            self._set_entity_(need_insert_id, {
                "source_id": source_id,
                "description": description,
                "entity_type": 'UNKNOWN'
            })
        description = self._handle_entity_relation_summary(
            f"({src_id}, {tgt_id})", description
        )
        edge_data = dict(
            src_id=src_id,
            tgt_id=tgt_id,
            description=description,
            keywords=keywords,
            weight=weight,
            source_id=source_id
        )
        self._set_relation_(src_id, tgt_id, edge_data)
        return edge_data
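
    # Condense a GRAPH_FIELD_SEP-joined description into a single summary via
    # the LLM, truncating the input to a fixed token budget first.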
    def _handle_entity_relation_summary(
        self,
        entity_or_relation_name: str,
        description: str
    ) -> str:
        summary_max_tokens = 512
        use_description = truncate(description, summary_max_tokens)
        prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
        context_base = dict(
            entity_name=entity_or_relation_name,
            description_list=use_description.split(GRAPH_FIELD_SEP),
            language=self._language,
        )
        use_prompt = prompt_template.format(**context_base)
        logging.info(f"Trigger summary: {entity_or_relation_name}")
        summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
        return summary
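
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It
# assumes a hypothetical subclass `MyExtractor` that implements
# _process_single_content, which __call__ expects to return a
# (nodes, edges, token_count) tuple per chunk, and a hypothetical `my_llm`
# (a rag.llm.chat_model.Base instance), with plain dicts as storage:
#
#   entities: dict = {}
#   relations: dict = {}
#
#   extractor = MyExtractor(
#       llm_invoker=my_llm,
#       entity_types=["organization", "person"],
#       get_entity=entities.get,
#       set_entity=entities.__setitem__,
#       get_relation=lambda s, t: relations.get((s, t)),
#       set_relation=lambda s, t, d: relations.__setitem__((s, t), d),
#   )
#   ents, rels = extractor([("chunk-0", "Some document text ...")])
# ---------------------------------------------------------------------------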