You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

graph_extractor.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import logging
  8. import numbers
  9. import re
  10. import traceback
  11. from typing import Any, Callable, Mapping
  12. from dataclasses import dataclass
  13. import tiktoken
  14. from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
  15. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
  16. from rag.llm.chat_model import Base as CompletionLLM
  17. import networkx as nx
  18. from rag.utils import num_tokens_from_string
  19. from timeit import default_timer as timer
# Delimiters the extraction prompt instructs the LLM to emit: between fields
# of one tuple, between records, and as the end-of-output sentinel.
DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
# Entity categories used when the caller supplies none.
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
# Default number of follow-up "glean" rounds asking the LLM for missed entities.
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    # Merged knowledge graph: entities as nodes, relationships as edges.
    output: nx.Graph
    # Maps each document index to the source text extraction ran on.
    source_docs: dict[Any, Any]
class GraphExtractor:
    """Unipartite graph extractor class definition."""

    _llm: CompletionLLM  # chat-completion client that performs the extraction
    _join_descriptions: bool  # True: merge duplicate descriptions; False: keep the longest one
    _tuple_delimiter_key: str  # prompt-variable name for the in-record field delimiter
    _record_delimiter_key: str  # prompt-variable name for the between-record delimiter
    _entity_types_key: str  # prompt-variable name for the entity-type list
    _input_text_key: str  # prompt-variable name for the document text
    _completion_delimiter_key: str  # prompt-variable name for the end-of-output sentinel
    _entity_name_key: str  # NOTE(review): declared but never assigned in __init__ — possibly vestigial
    _input_descriptions_key: str  # NOTE(review): declared but never assigned in __init__ — possibly vestigial
    _extraction_prompt: str  # main extraction prompt template
    _summarization_prompt: str  # NOTE(review): declared but never assigned in __init__ — possibly vestigial
    _loop_args: dict[str, Any]  # chat params that force a single-token YES/NO continuation answer
    _max_gleanings: int  # number of follow-up "glean" rounds per document
    _on_error: ErrorHandlerFn  # callback invoked with (exception, traceback_str, doc info dict)
  46. def __init__(
  47. self,
  48. llm_invoker: CompletionLLM,
  49. prompt: str | None = None,
  50. tuple_delimiter_key: str | None = None,
  51. record_delimiter_key: str | None = None,
  52. input_text_key: str | None = None,
  53. entity_types_key: str | None = None,
  54. completion_delimiter_key: str | None = None,
  55. join_descriptions=True,
  56. encoding_model: str | None = None,
  57. max_gleanings: int | None = None,
  58. on_error: ErrorHandlerFn | None = None,
  59. ):
  60. """Init method definition."""
  61. # TODO: streamline construction
  62. self._llm = llm_invoker
  63. self._join_descriptions = join_descriptions
  64. self._input_text_key = input_text_key or "input_text"
  65. self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
  66. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  67. self._completion_delimiter_key = (
  68. completion_delimiter_key or "completion_delimiter"
  69. )
  70. self._entity_types_key = entity_types_key or "entity_types"
  71. self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
  72. self._max_gleanings = (
  73. max_gleanings
  74. if max_gleanings is not None
  75. else ENTITY_EXTRACTION_MAX_GLEANINGS
  76. )
  77. self._on_error = on_error or (lambda _e, _s, _d: None)
  78. self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
  79. # Construct the looping arguments
  80. encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
  81. yes = encoding.encode("YES")
  82. no = encoding.encode("NO")
  83. self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
  84. def __call__(
  85. self, texts: list[str],
  86. prompt_variables: dict[str, Any] | None = None,
  87. callback: Callable | None = None
  88. ) -> GraphExtractionResult:
  89. """Call method definition."""
  90. if prompt_variables is None:
  91. prompt_variables = {}
  92. all_records: dict[int, str] = {}
  93. source_doc_map: dict[int, str] = {}
  94. # Wire defaults into the prompt variables
  95. prompt_variables = {
  96. **prompt_variables,
  97. self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
  98. or DEFAULT_TUPLE_DELIMITER,
  99. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  100. or DEFAULT_RECORD_DELIMITER,
  101. self._completion_delimiter_key: prompt_variables.get(
  102. self._completion_delimiter_key
  103. )
  104. or DEFAULT_COMPLETION_DELIMITER,
  105. self._entity_types_key: ",".join(
  106. prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
  107. ),
  108. }
  109. st = timer()
  110. total = len(texts)
  111. total_token_count = 0
  112. for doc_index, text in enumerate(texts):
  113. try:
  114. # Invoke the entity extraction
  115. result, token_count = self._process_document(text, prompt_variables)
  116. source_doc_map[doc_index] = text
  117. all_records[doc_index] = result
  118. total_token_count += token_count
  119. if callback:
  120. callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
  121. except Exception as e:
  122. if callback:
  123. callback(msg="Knowledge graph extraction error:{}".format(str(e)))
  124. logging.exception("error extracting graph")
  125. self._on_error(
  126. e,
  127. traceback.format_exc(),
  128. {
  129. "doc_index": doc_index,
  130. "text": text,
  131. },
  132. )
  133. output = self._process_results(
  134. all_records,
  135. prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
  136. prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
  137. )
  138. return GraphExtractionResult(
  139. output=output,
  140. source_docs=source_doc_map,
  141. )
  142. def _process_document(
  143. self, text: str, prompt_variables: dict[str, str]
  144. ) -> str:
  145. variables = {
  146. **prompt_variables,
  147. self._input_text_key: text,
  148. }
  149. token_count = 0
  150. text = perform_variable_replacements(self._extraction_prompt, variables=variables)
  151. gen_conf = {"temperature": 0.3}
  152. response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
  153. if response.find("**ERROR**") >= 0:
  154. raise Exception(response)
  155. token_count = num_tokens_from_string(text + response)
  156. results = response or ""
  157. history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
  158. # Repeat to ensure we maximize entity count
  159. for i in range(self._max_gleanings):
  160. text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
  161. history.append({"role": "user", "content": text})
  162. response = self._llm.chat("", history, gen_conf)
  163. if response.find("**ERROR**") >=0:
  164. raise Exception(response)
  165. results += response or ""
  166. # if this is the final glean, don't bother updating the continuation flag
  167. if i >= self._max_gleanings - 1:
  168. break
  169. history.append({"role": "assistant", "content": response})
  170. history.append({"role": "user", "content": LOOP_PROMPT})
  171. continuation = self._llm.chat("", history, self._loop_args)
  172. if continuation != "YES":
  173. break
  174. return results, token_count
  175. def _process_results(
  176. self,
  177. results: dict[int, str],
  178. tuple_delimiter: str,
  179. record_delimiter: str,
  180. ) -> nx.Graph:
  181. """Parse the result string to create an undirected unipartite graph.
  182. Args:
  183. - results - dict of results from the extraction chain
  184. - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
  185. - record_delimiter - delimiter between records, default is '##'
  186. Returns:
  187. - output - unipartite graph in graphML format
  188. """
  189. graph = nx.Graph()
  190. for source_doc_id, extracted_data in results.items():
  191. records = [r.strip() for r in extracted_data.split(record_delimiter)]
  192. for record in records:
  193. record = re.sub(r"^\(|\)$", "", record.strip())
  194. record_attributes = record.split(tuple_delimiter)
  195. if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
  196. # add this record as a node in the G
  197. entity_name = clean_str(record_attributes[1].upper())
  198. entity_type = clean_str(record_attributes[2].upper())
  199. entity_description = clean_str(record_attributes[3])
  200. if entity_name in graph.nodes():
  201. node = graph.nodes[entity_name]
  202. if self._join_descriptions:
  203. node["description"] = "\n".join(
  204. list({
  205. *_unpack_descriptions(node),
  206. entity_description,
  207. })
  208. )
  209. else:
  210. if len(entity_description) > len(node["description"]):
  211. node["description"] = entity_description
  212. node["source_id"] = ", ".join(
  213. list({
  214. *_unpack_source_ids(node),
  215. str(source_doc_id),
  216. })
  217. )
  218. node["entity_type"] = (
  219. entity_type if entity_type != "" else node["entity_type"]
  220. )
  221. else:
  222. graph.add_node(
  223. entity_name,
  224. entity_type=entity_type,
  225. description=entity_description,
  226. source_id=str(source_doc_id),
  227. weight=1
  228. )
  229. if (
  230. record_attributes[0] == '"relationship"'
  231. and len(record_attributes) >= 5
  232. ):
  233. # add this record as edge
  234. source = clean_str(record_attributes[1].upper())
  235. target = clean_str(record_attributes[2].upper())
  236. edge_description = clean_str(record_attributes[3])
  237. edge_source_id = clean_str(str(source_doc_id))
  238. weight = (
  239. float(record_attributes[-1])
  240. if isinstance(record_attributes[-1], numbers.Number)
  241. else 1.0
  242. )
  243. if source not in graph.nodes():
  244. graph.add_node(
  245. source,
  246. entity_type="",
  247. description="",
  248. source_id=edge_source_id,
  249. weight=1
  250. )
  251. if target not in graph.nodes():
  252. graph.add_node(
  253. target,
  254. entity_type="",
  255. description="",
  256. source_id=edge_source_id,
  257. weight=1
  258. )
  259. if graph.has_edge(source, target):
  260. edge_data = graph.get_edge_data(source, target)
  261. if edge_data is not None:
  262. weight += edge_data["weight"]
  263. if self._join_descriptions:
  264. edge_description = "\n".join(
  265. list({
  266. *_unpack_descriptions(edge_data),
  267. edge_description,
  268. })
  269. )
  270. edge_source_id = ", ".join(
  271. list({
  272. *_unpack_source_ids(edge_data),
  273. str(source_doc_id),
  274. })
  275. )
  276. graph.add_edge(
  277. source,
  278. target,
  279. weight=weight,
  280. description=edge_description,
  281. source_id=edge_source_id,
  282. )
  283. for node_degree in graph.degree:
  284. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  285. return graph
  286. def _unpack_descriptions(data: Mapping) -> list[str]:
  287. value = data.get("description", None)
  288. return [] if value is None else value.split("\n")
  289. def _unpack_source_ids(data: Mapping) -> list[str]:
  290. value = data.get("source_id", None)
  291. return [] if value is None else value.split(", ")