You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_extractor.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Reference:
  17. - [graphrag](https://github.com/microsoft/graphrag)
  18. """
  19. import logging
  20. import numbers
  21. import re
  22. import traceback
  23. from dataclasses import dataclass
  24. from typing import Any, Mapping, Callable
  25. import tiktoken
  26. from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
  27. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
  28. from rag.llm.chat_model import Base as CompletionLLM
  29. import networkx as nx
  30. from rag.utils import num_tokens_from_string
  31. from timeit import default_timer as timer
# Default delimiters matching the GraphRAG extraction prompt's output format:
# fields within a record, records within a response, and end-of-output marker.
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
# Entity types requested from the LLM when the caller supplies none.
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
# Maximum number of extra "gleaning" rounds asking the LLM for missed entities.
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    # Undirected graph of extracted entities (nodes) and relationships (edges).
    output: nx.Graph
    # Maps document index -> original source text the graph was extracted from.
    source_docs: dict[Any, Any]
  42. class GraphExtractor:
  43. """Unipartite graph extractor class definition."""
  44. _llm: CompletionLLM
  45. _join_descriptions: bool
  46. _tuple_delimiter_key: str
  47. _record_delimiter_key: str
  48. _entity_types_key: str
  49. _input_text_key: str
  50. _completion_delimiter_key: str
  51. _entity_name_key: str
  52. _input_descriptions_key: str
  53. _extraction_prompt: str
  54. _summarization_prompt: str
  55. _loop_args: dict[str, Any]
  56. _max_gleanings: int
  57. _on_error: ErrorHandlerFn
  58. def __init__(
  59. self,
  60. llm_invoker: CompletionLLM,
  61. prompt: str | None = None,
  62. tuple_delimiter_key: str | None = None,
  63. record_delimiter_key: str | None = None,
  64. input_text_key: str | None = None,
  65. entity_types_key: str | None = None,
  66. completion_delimiter_key: str | None = None,
  67. join_descriptions=True,
  68. encoding_model: str | None = None,
  69. max_gleanings: int | None = None,
  70. on_error: ErrorHandlerFn | None = None,
  71. ):
  72. """Init method definition."""
  73. # TODO: streamline construction
  74. self._llm = llm_invoker
  75. self._join_descriptions = join_descriptions
  76. self._input_text_key = input_text_key or "input_text"
  77. self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
  78. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  79. self._completion_delimiter_key = (
  80. completion_delimiter_key or "completion_delimiter"
  81. )
  82. self._entity_types_key = entity_types_key or "entity_types"
  83. self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
  84. self._max_gleanings = (
  85. max_gleanings
  86. if max_gleanings is not None
  87. else ENTITY_EXTRACTION_MAX_GLEANINGS
  88. )
  89. self._on_error = on_error or (lambda _e, _s, _d: None)
  90. self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
  91. # Construct the looping arguments
  92. encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
  93. yes = encoding.encode("YES")
  94. no = encoding.encode("NO")
  95. self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
  96. def __call__(
  97. self, texts: list[str],
  98. prompt_variables: dict[str, Any] | None = None,
  99. callback: Callable | None = None
  100. ) -> GraphExtractionResult:
  101. """Call method definition."""
  102. if prompt_variables is None:
  103. prompt_variables = {}
  104. all_records: dict[int, str] = {}
  105. source_doc_map: dict[int, str] = {}
  106. # Wire defaults into the prompt variables
  107. prompt_variables = {
  108. **prompt_variables,
  109. self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
  110. or DEFAULT_TUPLE_DELIMITER,
  111. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  112. or DEFAULT_RECORD_DELIMITER,
  113. self._completion_delimiter_key: prompt_variables.get(
  114. self._completion_delimiter_key
  115. )
  116. or DEFAULT_COMPLETION_DELIMITER,
  117. self._entity_types_key: ",".join(
  118. prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
  119. ),
  120. }
  121. st = timer()
  122. total = len(texts)
  123. total_token_count = 0
  124. for doc_index, text in enumerate(texts):
  125. try:
  126. # Invoke the entity extraction
  127. result, token_count = self._process_document(text, prompt_variables)
  128. source_doc_map[doc_index] = text
  129. all_records[doc_index] = result
  130. total_token_count += token_count
  131. if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
  132. except Exception as e:
  133. logging.exception("error extracting graph")
  134. self._on_error(
  135. e,
  136. traceback.format_exc(),
  137. {
  138. "doc_index": doc_index,
  139. "text": text,
  140. },
  141. )
  142. output = self._process_results(
  143. all_records,
  144. prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
  145. prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
  146. )
  147. return GraphExtractionResult(
  148. output=output,
  149. source_docs=source_doc_map,
  150. )
  151. def _process_document(
  152. self, text: str, prompt_variables: dict[str, str]
  153. ) -> str:
  154. variables = {
  155. **prompt_variables,
  156. self._input_text_key: text,
  157. }
  158. token_count = 0
  159. text = perform_variable_replacements(self._extraction_prompt, variables=variables)
  160. gen_conf = {"temperature": 0.3}
  161. response = self._llm.chat(text, [], gen_conf)
  162. token_count = num_tokens_from_string(text + response)
  163. results = response or ""
  164. history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
  165. # Repeat to ensure we maximize entity count
  166. for i in range(self._max_gleanings):
  167. text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
  168. history.append({"role": "user", "content": text})
  169. response = self._llm.chat("", history, gen_conf)
  170. results += response or ""
  171. # if this is the final glean, don't bother updating the continuation flag
  172. if i >= self._max_gleanings - 1:
  173. break
  174. history.append({"role": "assistant", "content": response})
  175. history.append({"role": "user", "content": LOOP_PROMPT})
  176. continuation = self._llm.chat("", history, self._loop_args)
  177. if continuation != "YES":
  178. break
  179. return results, token_count
  180. def _process_results(
  181. self,
  182. results: dict[int, str],
  183. tuple_delimiter: str,
  184. record_delimiter: str,
  185. ) -> nx.Graph:
  186. """Parse the result string to create an undirected unipartite graph.
  187. Args:
  188. - results - dict of results from the extraction chain
  189. - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
  190. - record_delimiter - delimiter between records, default is '##'
  191. Returns:
  192. - output - unipartite graph in graphML format
  193. """
  194. graph = nx.Graph()
  195. for source_doc_id, extracted_data in results.items():
  196. records = [r.strip() for r in extracted_data.split(record_delimiter)]
  197. for record in records:
  198. record = re.sub(r"^\(|\)$", "", record.strip())
  199. record_attributes = record.split(tuple_delimiter)
  200. if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
  201. # add this record as a node in the G
  202. entity_name = clean_str(record_attributes[1].upper())
  203. entity_type = clean_str(record_attributes[2].upper())
  204. entity_description = clean_str(record_attributes[3])
  205. if entity_name in graph.nodes():
  206. node = graph.nodes[entity_name]
  207. if self._join_descriptions:
  208. node["description"] = "\n".join(
  209. list({
  210. *_unpack_descriptions(node),
  211. entity_description,
  212. })
  213. )
  214. else:
  215. if len(entity_description) > len(node["description"]):
  216. node["description"] = entity_description
  217. node["source_id"] = ", ".join(
  218. list({
  219. *_unpack_source_ids(node),
  220. str(source_doc_id),
  221. })
  222. )
  223. node["entity_type"] = (
  224. entity_type if entity_type != "" else node["entity_type"]
  225. )
  226. else:
  227. graph.add_node(
  228. entity_name,
  229. entity_type=entity_type,
  230. description=entity_description,
  231. source_id=str(source_doc_id),
  232. weight=1
  233. )
  234. if (
  235. record_attributes[0] == '"relationship"'
  236. and len(record_attributes) >= 5
  237. ):
  238. # add this record as edge
  239. source = clean_str(record_attributes[1].upper())
  240. target = clean_str(record_attributes[2].upper())
  241. edge_description = clean_str(record_attributes[3])
  242. edge_source_id = clean_str(str(source_doc_id))
  243. weight = (
  244. float(record_attributes[-1])
  245. if isinstance(record_attributes[-1], numbers.Number)
  246. else 1.0
  247. )
  248. if source not in graph.nodes():
  249. graph.add_node(
  250. source,
  251. entity_type="",
  252. description="",
  253. source_id=edge_source_id,
  254. weight=1
  255. )
  256. if target not in graph.nodes():
  257. graph.add_node(
  258. target,
  259. entity_type="",
  260. description="",
  261. source_id=edge_source_id,
  262. weight=1
  263. )
  264. if graph.has_edge(source, target):
  265. edge_data = graph.get_edge_data(source, target)
  266. if edge_data is not None:
  267. weight += edge_data["weight"]
  268. if self._join_descriptions:
  269. edge_description = "\n".join(
  270. list({
  271. *_unpack_descriptions(edge_data),
  272. edge_description,
  273. })
  274. )
  275. edge_source_id = ", ".join(
  276. list({
  277. *_unpack_source_ids(edge_data),
  278. str(source_doc_id),
  279. })
  280. )
  281. graph.add_edge(
  282. source,
  283. target,
  284. weight=weight,
  285. description=edge_description,
  286. source_id=edge_source_id,
  287. )
  288. for node_degree in graph.degree:
  289. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  290. return graph
  291. def _unpack_descriptions(data: Mapping) -> list[str]:
  292. value = data.get("description", None)
  293. return [] if value is None else value.split("\n")
  294. def _unpack_source_ids(data: Mapping) -> list[str]:
  295. value = data.get("source_id", None)
  296. return [] if value is None else value.split(", ")