You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_extractor.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Reference:
  17. - [graphrag](https://github.com/microsoft/graphrag)
  18. """
  19. import logging
  20. import numbers
  21. import re
  22. import traceback
  23. from dataclasses import dataclass
  24. from typing import Any, Mapping
  25. import tiktoken
  26. from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
  27. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
  28. from rag.llm.chat_model import Base as CompletionLLM
  29. import networkx as nx
  30. from rag.utils import num_tokens_from_string
# Defaults matching the graphrag extraction prompt's output format.
DEFAULT_TUPLE_DELIMITER = "<|>"  # separates fields inside one extracted record
DEFAULT_RECORD_DELIMITER = "##"  # separates records in the LLM output
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"  # marks the end of the LLM output
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1  # default number of follow-up "glean" rounds
@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    # Undirected graph of extracted entities (nodes) and relationships (edges).
    output: nx.Graph
    # Maps each document index back to the original source text it came from.
    source_docs: dict[Any, Any]
  41. class GraphExtractor:
  42. """Unipartite graph extractor class definition."""
  43. _llm: CompletionLLM
  44. _join_descriptions: bool
  45. _tuple_delimiter_key: str
  46. _record_delimiter_key: str
  47. _entity_types_key: str
  48. _input_text_key: str
  49. _completion_delimiter_key: str
  50. _entity_name_key: str
  51. _input_descriptions_key: str
  52. _extraction_prompt: str
  53. _summarization_prompt: str
  54. _loop_args: dict[str, Any]
  55. _max_gleanings: int
  56. _on_error: ErrorHandlerFn
  57. def __init__(
  58. self,
  59. llm_invoker: CompletionLLM,
  60. prompt: str | None = None,
  61. tuple_delimiter_key: str | None = None,
  62. record_delimiter_key: str | None = None,
  63. input_text_key: str | None = None,
  64. entity_types_key: str | None = None,
  65. completion_delimiter_key: str | None = None,
  66. join_descriptions=True,
  67. encoding_model: str | None = None,
  68. max_gleanings: int | None = None,
  69. on_error: ErrorHandlerFn | None = None,
  70. ):
  71. """Init method definition."""
  72. # TODO: streamline construction
  73. self._llm = llm_invoker
  74. self._join_descriptions = join_descriptions
  75. self._input_text_key = input_text_key or "input_text"
  76. self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
  77. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  78. self._completion_delimiter_key = (
  79. completion_delimiter_key or "completion_delimiter"
  80. )
  81. self._entity_types_key = entity_types_key or "entity_types"
  82. self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
  83. self._max_gleanings = (
  84. max_gleanings
  85. if max_gleanings is not None
  86. else ENTITY_EXTRACTION_MAX_GLEANINGS
  87. )
  88. self._on_error = on_error or (lambda _e, _s, _d: None)
  89. self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
  90. # Construct the looping arguments
  91. encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
  92. yes = encoding.encode("YES")
  93. no = encoding.encode("NO")
  94. self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
  95. def __call__(
  96. self, texts: list[str], prompt_variables: dict[str, Any] | None = None
  97. ) -> GraphExtractionResult:
  98. """Call method definition."""
  99. if prompt_variables is None:
  100. prompt_variables = {}
  101. all_records: dict[int, str] = {}
  102. source_doc_map: dict[int, str] = {}
  103. # Wire defaults into the prompt variables
  104. prompt_variables = {
  105. **prompt_variables,
  106. self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
  107. or DEFAULT_TUPLE_DELIMITER,
  108. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  109. or DEFAULT_RECORD_DELIMITER,
  110. self._completion_delimiter_key: prompt_variables.get(
  111. self._completion_delimiter_key
  112. )
  113. or DEFAULT_COMPLETION_DELIMITER,
  114. self._entity_types_key: ",".join(
  115. prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
  116. ),
  117. }
  118. for doc_index, text in enumerate(texts):
  119. try:
  120. # Invoke the entity extraction
  121. result = self._process_document(text, prompt_variables)
  122. source_doc_map[doc_index] = text
  123. all_records[doc_index] = result
  124. except Exception as e:
  125. logging.exception("error extracting graph")
  126. self._on_error(
  127. e,
  128. traceback.format_exc(),
  129. {
  130. "doc_index": doc_index,
  131. "text": text,
  132. },
  133. )
  134. output = self._process_results(
  135. all_records,
  136. prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
  137. prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
  138. )
  139. return GraphExtractionResult(
  140. output=output,
  141. source_docs=source_doc_map,
  142. )
  143. def _process_document(
  144. self, text: str, prompt_variables: dict[str, str]
  145. ) -> str:
  146. variables = {
  147. **prompt_variables,
  148. self._input_text_key: text,
  149. }
  150. text = perform_variable_replacements(self._extraction_prompt, variables=variables)
  151. gen_conf = {"temperature": 0.5}
  152. response = self._llm.chat(text, [], gen_conf)
  153. results = response or ""
  154. history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
  155. # Repeat to ensure we maximize entity count
  156. for i in range(self._max_gleanings):
  157. text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
  158. history.append({"role": "user", "content": text})
  159. response = self._llm.chat("", history, gen_conf)
  160. results += response or ""
  161. # if this is the final glean, don't bother updating the continuation flag
  162. if i >= self._max_gleanings - 1:
  163. break
  164. history.append({"role": "assistant", "content": response})
  165. history.append({"role": "user", "content": LOOP_PROMPT})
  166. continuation = self._llm.chat("", history, self._loop_args)
  167. if continuation != "YES":
  168. break
  169. return results
  170. def _process_results(
  171. self,
  172. results: dict[int, str],
  173. tuple_delimiter: str,
  174. record_delimiter: str,
  175. ) -> nx.Graph:
  176. """Parse the result string to create an undirected unipartite graph.
  177. Args:
  178. - results - dict of results from the extraction chain
  179. - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
  180. - record_delimiter - delimiter between records, default is '##'
  181. Returns:
  182. - output - unipartite graph in graphML format
  183. """
  184. graph = nx.Graph()
  185. for source_doc_id, extracted_data in results.items():
  186. records = [r.strip() for r in extracted_data.split(record_delimiter)]
  187. for record in records:
  188. record = re.sub(r"^\(|\)$", "", record.strip())
  189. record_attributes = record.split(tuple_delimiter)
  190. if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
  191. # add this record as a node in the G
  192. entity_name = clean_str(record_attributes[1].upper())
  193. entity_type = clean_str(record_attributes[2].upper())
  194. entity_description = clean_str(record_attributes[3])
  195. if entity_name in graph.nodes():
  196. node = graph.nodes[entity_name]
  197. if self._join_descriptions:
  198. node["description"] = "\n".join(
  199. list({
  200. *_unpack_descriptions(node),
  201. entity_description,
  202. })
  203. )
  204. else:
  205. if len(entity_description) > len(node["description"]):
  206. node["description"] = entity_description
  207. node["source_id"] = ", ".join(
  208. list({
  209. *_unpack_source_ids(node),
  210. str(source_doc_id),
  211. })
  212. )
  213. node["entity_type"] = (
  214. entity_type if entity_type != "" else node["entity_type"]
  215. )
  216. else:
  217. graph.add_node(
  218. entity_name,
  219. entity_type=entity_type,
  220. description=entity_description,
  221. source_id=str(source_doc_id),
  222. weight=1
  223. )
  224. if (
  225. record_attributes[0] == '"relationship"'
  226. and len(record_attributes) >= 5
  227. ):
  228. # add this record as edge
  229. source = clean_str(record_attributes[1].upper())
  230. target = clean_str(record_attributes[2].upper())
  231. edge_description = clean_str(record_attributes[3])
  232. edge_source_id = clean_str(str(source_doc_id))
  233. weight = (
  234. float(record_attributes[-1])
  235. if isinstance(record_attributes[-1], numbers.Number)
  236. else 1.0
  237. )
  238. if source not in graph.nodes():
  239. graph.add_node(
  240. source,
  241. entity_type="",
  242. description="",
  243. source_id=edge_source_id,
  244. weight=1
  245. )
  246. if target not in graph.nodes():
  247. graph.add_node(
  248. target,
  249. entity_type="",
  250. description="",
  251. source_id=edge_source_id,
  252. weight=1
  253. )
  254. if graph.has_edge(source, target):
  255. edge_data = graph.get_edge_data(source, target)
  256. if edge_data is not None:
  257. weight += edge_data["weight"]
  258. if self._join_descriptions:
  259. edge_description = "\n".join(
  260. list({
  261. *_unpack_descriptions(edge_data),
  262. edge_description,
  263. })
  264. )
  265. edge_source_id = ", ".join(
  266. list({
  267. *_unpack_source_ids(edge_data),
  268. str(source_doc_id),
  269. })
  270. )
  271. graph.add_edge(
  272. source,
  273. target,
  274. weight=weight,
  275. description=edge_description,
  276. source_id=edge_source_id,
  277. )
  278. for node_degree in graph.degree:
  279. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  280. return graph
  281. def _unpack_descriptions(data: Mapping) -> list[str]:
  282. value = data.get("description", None)
  283. return [] if value is None else value.split("\n")
  284. def _unpack_source_ids(data: Mapping) -> list[str]:
  285. value = data.get("source_id", None)
  286. return [] if value is None else value.split(", ")