You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_extractor.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Reference:
  17. - [graphrag](https://github.com/microsoft/graphrag)
  18. """
  19. import logging
  20. import numbers
  21. import re
  22. import traceback
  23. from dataclasses import dataclass
  24. from typing import Any, Mapping, Callable
  25. import tiktoken
  26. from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
  27. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
  28. from rag.llm.chat_model import Base as CompletionLLM
  29. import networkx as nx
  30. from rag.utils import num_tokens_from_string
  31. from timeit import default_timer as timer
  32. DEFAULT_TUPLE_DELIMITER = "<|>"
  33. DEFAULT_RECORD_DELIMITER = "##"
  34. DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
  35. DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
  36. ENTITY_EXTRACTION_MAX_GLEANINGS = 1
@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    # The extracted entity/relationship graph (nodes = entities, edges = relationships).
    output: nx.Graph
    # Maps each document index to the source text that produced it.
    source_docs: dict[Any, Any]
  42. class GraphExtractor:
  43. """Unipartite graph extractor class definition."""
  44. _llm: CompletionLLM
  45. _join_descriptions: bool
  46. _tuple_delimiter_key: str
  47. _record_delimiter_key: str
  48. _entity_types_key: str
  49. _input_text_key: str
  50. _completion_delimiter_key: str
  51. _entity_name_key: str
  52. _input_descriptions_key: str
  53. _extraction_prompt: str
  54. _summarization_prompt: str
  55. _loop_args: dict[str, Any]
  56. _max_gleanings: int
  57. _on_error: ErrorHandlerFn
  58. def __init__(
  59. self,
  60. llm_invoker: CompletionLLM,
  61. prompt: str | None = None,
  62. tuple_delimiter_key: str | None = None,
  63. record_delimiter_key: str | None = None,
  64. input_text_key: str | None = None,
  65. entity_types_key: str | None = None,
  66. completion_delimiter_key: str | None = None,
  67. join_descriptions=True,
  68. encoding_model: str | None = None,
  69. max_gleanings: int | None = None,
  70. on_error: ErrorHandlerFn | None = None,
  71. ):
  72. """Init method definition."""
  73. # TODO: streamline construction
  74. self._llm = llm_invoker
  75. self._join_descriptions = join_descriptions
  76. self._input_text_key = input_text_key or "input_text"
  77. self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
  78. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  79. self._completion_delimiter_key = (
  80. completion_delimiter_key or "completion_delimiter"
  81. )
  82. self._entity_types_key = entity_types_key or "entity_types"
  83. self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
  84. self._max_gleanings = (
  85. max_gleanings
  86. if max_gleanings is not None
  87. else ENTITY_EXTRACTION_MAX_GLEANINGS
  88. )
  89. self._on_error = on_error or (lambda _e, _s, _d: None)
  90. self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
  91. # Construct the looping arguments
  92. encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
  93. yes = encoding.encode("YES")
  94. no = encoding.encode("NO")
  95. self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
  96. def __call__(
  97. self, texts: list[str],
  98. prompt_variables: dict[str, Any] | None = None,
  99. callback: Callable | None = None
  100. ) -> GraphExtractionResult:
  101. """Call method definition."""
  102. if prompt_variables is None:
  103. prompt_variables = {}
  104. all_records: dict[int, str] = {}
  105. source_doc_map: dict[int, str] = {}
  106. # Wire defaults into the prompt variables
  107. prompt_variables = {
  108. **prompt_variables,
  109. self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
  110. or DEFAULT_TUPLE_DELIMITER,
  111. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  112. or DEFAULT_RECORD_DELIMITER,
  113. self._completion_delimiter_key: prompt_variables.get(
  114. self._completion_delimiter_key
  115. )
  116. or DEFAULT_COMPLETION_DELIMITER,
  117. self._entity_types_key: ",".join(
  118. prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
  119. ),
  120. }
  121. st = timer()
  122. total = len(texts)
  123. total_token_count = 0
  124. for doc_index, text in enumerate(texts):
  125. try:
  126. # Invoke the entity extraction
  127. result, token_count = self._process_document(text, prompt_variables)
  128. source_doc_map[doc_index] = text
  129. all_records[doc_index] = result
  130. total_token_count += token_count
  131. if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
  132. except Exception as e:
  133. if callback: callback("Knowledge graph extraction error:{}".format(str(e)))
  134. logging.exception("error extracting graph")
  135. self._on_error(
  136. e,
  137. traceback.format_exc(),
  138. {
  139. "doc_index": doc_index,
  140. "text": text,
  141. },
  142. )
  143. output = self._process_results(
  144. all_records,
  145. prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
  146. prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
  147. )
  148. return GraphExtractionResult(
  149. output=output,
  150. source_docs=source_doc_map,
  151. )
  152. def _process_document(
  153. self, text: str, prompt_variables: dict[str, str]
  154. ) -> str:
  155. variables = {
  156. **prompt_variables,
  157. self._input_text_key: text,
  158. }
  159. token_count = 0
  160. text = perform_variable_replacements(self._extraction_prompt, variables=variables)
  161. gen_conf = {"temperature": 0.3}
  162. response = self._llm.chat(text, [], gen_conf)
  163. token_count = num_tokens_from_string(text + response)
  164. results = response or ""
  165. history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
  166. # Repeat to ensure we maximize entity count
  167. for i in range(self._max_gleanings):
  168. text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
  169. history.append({"role": "user", "content": text})
  170. response = self._llm.chat("", history, gen_conf)
  171. if response.find("**ERROR**") >=0: raise Exception(response)
  172. results += response or ""
  173. # if this is the final glean, don't bother updating the continuation flag
  174. if i >= self._max_gleanings - 1:
  175. break
  176. history.append({"role": "assistant", "content": response})
  177. history.append({"role": "user", "content": LOOP_PROMPT})
  178. continuation = self._llm.chat("", history, self._loop_args)
  179. if continuation != "YES":
  180. break
  181. return results, token_count
  182. def _process_results(
  183. self,
  184. results: dict[int, str],
  185. tuple_delimiter: str,
  186. record_delimiter: str,
  187. ) -> nx.Graph:
  188. """Parse the result string to create an undirected unipartite graph.
  189. Args:
  190. - results - dict of results from the extraction chain
  191. - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
  192. - record_delimiter - delimiter between records, default is '##'
  193. Returns:
  194. - output - unipartite graph in graphML format
  195. """
  196. graph = nx.Graph()
  197. for source_doc_id, extracted_data in results.items():
  198. records = [r.strip() for r in extracted_data.split(record_delimiter)]
  199. for record in records:
  200. record = re.sub(r"^\(|\)$", "", record.strip())
  201. record_attributes = record.split(tuple_delimiter)
  202. if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
  203. # add this record as a node in the G
  204. entity_name = clean_str(record_attributes[1].upper())
  205. entity_type = clean_str(record_attributes[2].upper())
  206. entity_description = clean_str(record_attributes[3])
  207. if entity_name in graph.nodes():
  208. node = graph.nodes[entity_name]
  209. if self._join_descriptions:
  210. node["description"] = "\n".join(
  211. list({
  212. *_unpack_descriptions(node),
  213. entity_description,
  214. })
  215. )
  216. else:
  217. if len(entity_description) > len(node["description"]):
  218. node["description"] = entity_description
  219. node["source_id"] = ", ".join(
  220. list({
  221. *_unpack_source_ids(node),
  222. str(source_doc_id),
  223. })
  224. )
  225. node["entity_type"] = (
  226. entity_type if entity_type != "" else node["entity_type"]
  227. )
  228. else:
  229. graph.add_node(
  230. entity_name,
  231. entity_type=entity_type,
  232. description=entity_description,
  233. source_id=str(source_doc_id),
  234. weight=1
  235. )
  236. if (
  237. record_attributes[0] == '"relationship"'
  238. and len(record_attributes) >= 5
  239. ):
  240. # add this record as edge
  241. source = clean_str(record_attributes[1].upper())
  242. target = clean_str(record_attributes[2].upper())
  243. edge_description = clean_str(record_attributes[3])
  244. edge_source_id = clean_str(str(source_doc_id))
  245. weight = (
  246. float(record_attributes[-1])
  247. if isinstance(record_attributes[-1], numbers.Number)
  248. else 1.0
  249. )
  250. if source not in graph.nodes():
  251. graph.add_node(
  252. source,
  253. entity_type="",
  254. description="",
  255. source_id=edge_source_id,
  256. weight=1
  257. )
  258. if target not in graph.nodes():
  259. graph.add_node(
  260. target,
  261. entity_type="",
  262. description="",
  263. source_id=edge_source_id,
  264. weight=1
  265. )
  266. if graph.has_edge(source, target):
  267. edge_data = graph.get_edge_data(source, target)
  268. if edge_data is not None:
  269. weight += edge_data["weight"]
  270. if self._join_descriptions:
  271. edge_description = "\n".join(
  272. list({
  273. *_unpack_descriptions(edge_data),
  274. edge_description,
  275. })
  276. )
  277. edge_source_id = ", ".join(
  278. list({
  279. *_unpack_source_ids(edge_data),
  280. str(source_doc_id),
  281. })
  282. )
  283. graph.add_edge(
  284. source,
  285. target,
  286. weight=weight,
  287. description=edge_description,
  288. source_id=edge_source_id,
  289. )
  290. for node_degree in graph.degree:
  291. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  292. return graph
  293. def _unpack_descriptions(data: Mapping) -> list[str]:
  294. value = data.get("description", None)
  295. return [] if value is None else value.split("\n")
  296. def _unpack_source_ids(data: Mapping) -> list[str]:
  297. value = data.get("source_id", None)
  298. return [] if value is None else value.split(", ")