You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_extractor.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import logging
  8. import numbers
  9. import re
  10. import traceback
  11. from dataclasses import dataclass
  12. from typing import Any, Mapping, Callable
  13. import tiktoken
  14. from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
  15. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
  16. from rag.llm.chat_model import Base as CompletionLLM
  17. import networkx as nx
  18. from rag.utils import num_tokens_from_string
  19. from timeit import default_timer as timer
# Default delimiters the LLM is instructed to use in its structured output.
DEFAULT_TUPLE_DELIMITER = "<|>"          # separates fields within one record
DEFAULT_RECORD_DELIMITER = "##"          # separates records from each other
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"  # marks the end of the full output
# Entity categories requested when the caller supplies none.
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
# Default number of follow-up "gleaning" rounds asking the LLM for missed entities.
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    # Undirected graph of extracted entities (nodes) and relationships (edges).
    output: nx.Graph
    # Maps each document index to the original source text it was extracted from.
    source_docs: dict[Any, Any]
  30. class GraphExtractor:
  31. """Unipartite graph extractor class definition."""
  32. _llm: CompletionLLM
  33. _join_descriptions: bool
  34. _tuple_delimiter_key: str
  35. _record_delimiter_key: str
  36. _entity_types_key: str
  37. _input_text_key: str
  38. _completion_delimiter_key: str
  39. _entity_name_key: str
  40. _input_descriptions_key: str
  41. _extraction_prompt: str
  42. _summarization_prompt: str
  43. _loop_args: dict[str, Any]
  44. _max_gleanings: int
  45. _on_error: ErrorHandlerFn
  46. def __init__(
  47. self,
  48. llm_invoker: CompletionLLM,
  49. prompt: str | None = None,
  50. tuple_delimiter_key: str | None = None,
  51. record_delimiter_key: str | None = None,
  52. input_text_key: str | None = None,
  53. entity_types_key: str | None = None,
  54. completion_delimiter_key: str | None = None,
  55. join_descriptions=True,
  56. encoding_model: str | None = None,
  57. max_gleanings: int | None = None,
  58. on_error: ErrorHandlerFn | None = None,
  59. ):
  60. """Init method definition."""
  61. # TODO: streamline construction
  62. self._llm = llm_invoker
  63. self._join_descriptions = join_descriptions
  64. self._input_text_key = input_text_key or "input_text"
  65. self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
  66. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  67. self._completion_delimiter_key = (
  68. completion_delimiter_key or "completion_delimiter"
  69. )
  70. self._entity_types_key = entity_types_key or "entity_types"
  71. self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
  72. self._max_gleanings = (
  73. max_gleanings
  74. if max_gleanings is not None
  75. else ENTITY_EXTRACTION_MAX_GLEANINGS
  76. )
  77. self._on_error = on_error or (lambda _e, _s, _d: None)
  78. self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
  79. # Construct the looping arguments
  80. encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
  81. yes = encoding.encode("YES")
  82. no = encoding.encode("NO")
  83. self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
  84. def __call__(
  85. self, texts: list[str],
  86. prompt_variables: dict[str, Any] | None = None,
  87. callback: Callable | None = None
  88. ) -> GraphExtractionResult:
  89. """Call method definition."""
  90. if prompt_variables is None:
  91. prompt_variables = {}
  92. all_records: dict[int, str] = {}
  93. source_doc_map: dict[int, str] = {}
  94. # Wire defaults into the prompt variables
  95. prompt_variables = {
  96. **prompt_variables,
  97. self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
  98. or DEFAULT_TUPLE_DELIMITER,
  99. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  100. or DEFAULT_RECORD_DELIMITER,
  101. self._completion_delimiter_key: prompt_variables.get(
  102. self._completion_delimiter_key
  103. )
  104. or DEFAULT_COMPLETION_DELIMITER,
  105. self._entity_types_key: ",".join(
  106. prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
  107. ),
  108. }
  109. st = timer()
  110. total = len(texts)
  111. total_token_count = 0
  112. for doc_index, text in enumerate(texts):
  113. try:
  114. # Invoke the entity extraction
  115. result, token_count = self._process_document(text, prompt_variables)
  116. source_doc_map[doc_index] = text
  117. all_records[doc_index] = result
  118. total_token_count += token_count
  119. if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
  120. except Exception as e:
  121. if callback: callback("Knowledge graph extraction error:{}".format(str(e)))
  122. logging.exception("error extracting graph")
  123. self._on_error(
  124. e,
  125. traceback.format_exc(),
  126. {
  127. "doc_index": doc_index,
  128. "text": text,
  129. },
  130. )
  131. output = self._process_results(
  132. all_records,
  133. prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
  134. prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
  135. )
  136. return GraphExtractionResult(
  137. output=output,
  138. source_docs=source_doc_map,
  139. )
  140. def _process_document(
  141. self, text: str, prompt_variables: dict[str, str]
  142. ) -> str:
  143. variables = {
  144. **prompt_variables,
  145. self._input_text_key: text,
  146. }
  147. token_count = 0
  148. text = perform_variable_replacements(self._extraction_prompt, variables=variables)
  149. gen_conf = {"temperature": 0.3}
  150. response = self._llm.chat(text, [], gen_conf)
  151. token_count = num_tokens_from_string(text + response)
  152. results = response or ""
  153. history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
  154. # Repeat to ensure we maximize entity count
  155. for i in range(self._max_gleanings):
  156. text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
  157. history.append({"role": "user", "content": text})
  158. response = self._llm.chat("", history, gen_conf)
  159. if response.find("**ERROR**") >=0: raise Exception(response)
  160. results += response or ""
  161. # if this is the final glean, don't bother updating the continuation flag
  162. if i >= self._max_gleanings - 1:
  163. break
  164. history.append({"role": "assistant", "content": response})
  165. history.append({"role": "user", "content": LOOP_PROMPT})
  166. continuation = self._llm.chat("", history, self._loop_args)
  167. if continuation != "YES":
  168. break
  169. return results, token_count
  170. def _process_results(
  171. self,
  172. results: dict[int, str],
  173. tuple_delimiter: str,
  174. record_delimiter: str,
  175. ) -> nx.Graph:
  176. """Parse the result string to create an undirected unipartite graph.
  177. Args:
  178. - results - dict of results from the extraction chain
  179. - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
  180. - record_delimiter - delimiter between records, default is '##'
  181. Returns:
  182. - output - unipartite graph in graphML format
  183. """
  184. graph = nx.Graph()
  185. for source_doc_id, extracted_data in results.items():
  186. records = [r.strip() for r in extracted_data.split(record_delimiter)]
  187. for record in records:
  188. record = re.sub(r"^\(|\)$", "", record.strip())
  189. record_attributes = record.split(tuple_delimiter)
  190. if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
  191. # add this record as a node in the G
  192. entity_name = clean_str(record_attributes[1].upper())
  193. entity_type = clean_str(record_attributes[2].upper())
  194. entity_description = clean_str(record_attributes[3])
  195. if entity_name in graph.nodes():
  196. node = graph.nodes[entity_name]
  197. if self._join_descriptions:
  198. node["description"] = "\n".join(
  199. list({
  200. *_unpack_descriptions(node),
  201. entity_description,
  202. })
  203. )
  204. else:
  205. if len(entity_description) > len(node["description"]):
  206. node["description"] = entity_description
  207. node["source_id"] = ", ".join(
  208. list({
  209. *_unpack_source_ids(node),
  210. str(source_doc_id),
  211. })
  212. )
  213. node["entity_type"] = (
  214. entity_type if entity_type != "" else node["entity_type"]
  215. )
  216. else:
  217. graph.add_node(
  218. entity_name,
  219. entity_type=entity_type,
  220. description=entity_description,
  221. source_id=str(source_doc_id),
  222. weight=1
  223. )
  224. if (
  225. record_attributes[0] == '"relationship"'
  226. and len(record_attributes) >= 5
  227. ):
  228. # add this record as edge
  229. source = clean_str(record_attributes[1].upper())
  230. target = clean_str(record_attributes[2].upper())
  231. edge_description = clean_str(record_attributes[3])
  232. edge_source_id = clean_str(str(source_doc_id))
  233. weight = (
  234. float(record_attributes[-1])
  235. if isinstance(record_attributes[-1], numbers.Number)
  236. else 1.0
  237. )
  238. if source not in graph.nodes():
  239. graph.add_node(
  240. source,
  241. entity_type="",
  242. description="",
  243. source_id=edge_source_id,
  244. weight=1
  245. )
  246. if target not in graph.nodes():
  247. graph.add_node(
  248. target,
  249. entity_type="",
  250. description="",
  251. source_id=edge_source_id,
  252. weight=1
  253. )
  254. if graph.has_edge(source, target):
  255. edge_data = graph.get_edge_data(source, target)
  256. if edge_data is not None:
  257. weight += edge_data["weight"]
  258. if self._join_descriptions:
  259. edge_description = "\n".join(
  260. list({
  261. *_unpack_descriptions(edge_data),
  262. edge_description,
  263. })
  264. )
  265. edge_source_id = ", ".join(
  266. list({
  267. *_unpack_source_ids(edge_data),
  268. str(source_doc_id),
  269. })
  270. )
  271. graph.add_edge(
  272. source,
  273. target,
  274. weight=weight,
  275. description=edge_description,
  276. source_id=edge_source_id,
  277. )
  278. for node_degree in graph.degree:
  279. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  280. return graph
  281. def _unpack_descriptions(data: Mapping) -> list[str]:
  282. value = data.get("description", None)
  283. return [] if value is None else value.split("\n")
  284. def _unpack_source_ids(data: Mapping) -> list[str]:
  285. value = data.get("source_id", None)
  286. return [] if value is None else value.split(", ")