Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

graph_extractor.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # Copyright (c) 2024 Microsoft Corporation.
  2. # Licensed under the MIT License
  3. """
  4. Reference:
  5. - [graphrag](https://github.com/microsoft/graphrag)
  6. """
  7. import logging
  8. import numbers
  9. import re
  10. import traceback
  11. from dataclasses import dataclass
  12. from typing import Any, Mapping, Callable
  13. import tiktoken
  14. from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
  15. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
  16. from rag.llm.chat_model import Base as CompletionLLM
  17. import networkx as nx
  18. from rag.utils import num_tokens_from_string
  19. from timeit import default_timer as timer
  20. DEFAULT_TUPLE_DELIMITER = "<|>"
  21. DEFAULT_RECORD_DELIMITER = "##"
  22. DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
  23. DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
  24. ENTITY_EXTRACTION_MAX_GLEANINGS = 1
  25. @dataclass
  26. class GraphExtractionResult:
  27. """Unipartite graph extraction result class definition."""
  28. output: nx.Graph
  29. source_docs: dict[Any, Any]
  30. class GraphExtractor:
  31. """Unipartite graph extractor class definition."""
  32. _llm: CompletionLLM
  33. _join_descriptions: bool
  34. _tuple_delimiter_key: str
  35. _record_delimiter_key: str
  36. _entity_types_key: str
  37. _input_text_key: str
  38. _completion_delimiter_key: str
  39. _entity_name_key: str
  40. _input_descriptions_key: str
  41. _extraction_prompt: str
  42. _summarization_prompt: str
  43. _loop_args: dict[str, Any]
  44. _max_gleanings: int
  45. _on_error: ErrorHandlerFn
  46. def __init__(
  47. self,
  48. llm_invoker: CompletionLLM,
  49. prompt: str | None = None,
  50. tuple_delimiter_key: str | None = None,
  51. record_delimiter_key: str | None = None,
  52. input_text_key: str | None = None,
  53. entity_types_key: str | None = None,
  54. completion_delimiter_key: str | None = None,
  55. join_descriptions=True,
  56. encoding_model: str | None = None,
  57. max_gleanings: int | None = None,
  58. on_error: ErrorHandlerFn | None = None,
  59. ):
  60. """Init method definition."""
  61. # TODO: streamline construction
  62. self._llm = llm_invoker
  63. self._join_descriptions = join_descriptions
  64. self._input_text_key = input_text_key or "input_text"
  65. self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
  66. self._record_delimiter_key = record_delimiter_key or "record_delimiter"
  67. self._completion_delimiter_key = (
  68. completion_delimiter_key or "completion_delimiter"
  69. )
  70. self._entity_types_key = entity_types_key or "entity_types"
  71. self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
  72. self._max_gleanings = (
  73. max_gleanings
  74. if max_gleanings is not None
  75. else ENTITY_EXTRACTION_MAX_GLEANINGS
  76. )
  77. self._on_error = on_error or (lambda _e, _s, _d: None)
  78. self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)
  79. # Construct the looping arguments
  80. encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
  81. yes = encoding.encode("YES")
  82. no = encoding.encode("NO")
  83. self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
  84. def __call__(
  85. self, texts: list[str],
  86. prompt_variables: dict[str, Any] | None = None,
  87. callback: Callable | None = None
  88. ) -> GraphExtractionResult:
  89. """Call method definition."""
  90. if prompt_variables is None:
  91. prompt_variables = {}
  92. all_records: dict[int, str] = {}
  93. source_doc_map: dict[int, str] = {}
  94. # Wire defaults into the prompt variables
  95. prompt_variables = {
  96. **prompt_variables,
  97. self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
  98. or DEFAULT_TUPLE_DELIMITER,
  99. self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
  100. or DEFAULT_RECORD_DELIMITER,
  101. self._completion_delimiter_key: prompt_variables.get(
  102. self._completion_delimiter_key
  103. )
  104. or DEFAULT_COMPLETION_DELIMITER,
  105. self._entity_types_key: ",".join(
  106. prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
  107. ),
  108. }
  109. st = timer()
  110. total = len(texts)
  111. total_token_count = 0
  112. for doc_index, text in enumerate(texts):
  113. try:
  114. # Invoke the entity extraction
  115. result, token_count = self._process_document(text, prompt_variables)
  116. source_doc_map[doc_index] = text
  117. all_records[doc_index] = result
  118. total_token_count += token_count
  119. if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
  120. except Exception as e:
  121. if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e)))
  122. logging.exception("error extracting graph")
  123. self._on_error(
  124. e,
  125. traceback.format_exc(),
  126. {
  127. "doc_index": doc_index,
  128. "text": text,
  129. },
  130. )
  131. output = self._process_results(
  132. all_records,
  133. prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
  134. prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
  135. )
  136. return GraphExtractionResult(
  137. output=output,
  138. source_docs=source_doc_map,
  139. )
  140. def _process_document(
  141. self, text: str, prompt_variables: dict[str, str]
  142. ) -> str:
  143. variables = {
  144. **prompt_variables,
  145. self._input_text_key: text,
  146. }
  147. token_count = 0
  148. text = perform_variable_replacements(self._extraction_prompt, variables=variables)
  149. gen_conf = {"temperature": 0.3}
  150. response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
  151. if response.find("**ERROR**") >= 0: raise Exception(response)
  152. token_count = num_tokens_from_string(text + response)
  153. results = response or ""
  154. history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
  155. # Repeat to ensure we maximize entity count
  156. for i in range(self._max_gleanings):
  157. text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
  158. history.append({"role": "user", "content": text})
  159. response = self._llm.chat("", history, gen_conf)
  160. if response.find("**ERROR**") >=0: raise Exception(response)
  161. results += response or ""
  162. # if this is the final glean, don't bother updating the continuation flag
  163. if i >= self._max_gleanings - 1:
  164. break
  165. history.append({"role": "assistant", "content": response})
  166. history.append({"role": "user", "content": LOOP_PROMPT})
  167. continuation = self._llm.chat("", history, self._loop_args)
  168. if continuation != "YES":
  169. break
  170. return results, token_count
  171. def _process_results(
  172. self,
  173. results: dict[int, str],
  174. tuple_delimiter: str,
  175. record_delimiter: str,
  176. ) -> nx.Graph:
  177. """Parse the result string to create an undirected unipartite graph.
  178. Args:
  179. - results - dict of results from the extraction chain
  180. - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
  181. - record_delimiter - delimiter between records, default is '##'
  182. Returns:
  183. - output - unipartite graph in graphML format
  184. """
  185. graph = nx.Graph()
  186. for source_doc_id, extracted_data in results.items():
  187. records = [r.strip() for r in extracted_data.split(record_delimiter)]
  188. for record in records:
  189. record = re.sub(r"^\(|\)$", "", record.strip())
  190. record_attributes = record.split(tuple_delimiter)
  191. if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
  192. # add this record as a node in the G
  193. entity_name = clean_str(record_attributes[1].upper())
  194. entity_type = clean_str(record_attributes[2].upper())
  195. entity_description = clean_str(record_attributes[3])
  196. if entity_name in graph.nodes():
  197. node = graph.nodes[entity_name]
  198. if self._join_descriptions:
  199. node["description"] = "\n".join(
  200. list({
  201. *_unpack_descriptions(node),
  202. entity_description,
  203. })
  204. )
  205. else:
  206. if len(entity_description) > len(node["description"]):
  207. node["description"] = entity_description
  208. node["source_id"] = ", ".join(
  209. list({
  210. *_unpack_source_ids(node),
  211. str(source_doc_id),
  212. })
  213. )
  214. node["entity_type"] = (
  215. entity_type if entity_type != "" else node["entity_type"]
  216. )
  217. else:
  218. graph.add_node(
  219. entity_name,
  220. entity_type=entity_type,
  221. description=entity_description,
  222. source_id=str(source_doc_id),
  223. weight=1
  224. )
  225. if (
  226. record_attributes[0] == '"relationship"'
  227. and len(record_attributes) >= 5
  228. ):
  229. # add this record as edge
  230. source = clean_str(record_attributes[1].upper())
  231. target = clean_str(record_attributes[2].upper())
  232. edge_description = clean_str(record_attributes[3])
  233. edge_source_id = clean_str(str(source_doc_id))
  234. weight = (
  235. float(record_attributes[-1])
  236. if isinstance(record_attributes[-1], numbers.Number)
  237. else 1.0
  238. )
  239. if source not in graph.nodes():
  240. graph.add_node(
  241. source,
  242. entity_type="",
  243. description="",
  244. source_id=edge_source_id,
  245. weight=1
  246. )
  247. if target not in graph.nodes():
  248. graph.add_node(
  249. target,
  250. entity_type="",
  251. description="",
  252. source_id=edge_source_id,
  253. weight=1
  254. )
  255. if graph.has_edge(source, target):
  256. edge_data = graph.get_edge_data(source, target)
  257. if edge_data is not None:
  258. weight += edge_data["weight"]
  259. if self._join_descriptions:
  260. edge_description = "\n".join(
  261. list({
  262. *_unpack_descriptions(edge_data),
  263. edge_description,
  264. })
  265. )
  266. edge_source_id = ", ".join(
  267. list({
  268. *_unpack_source_ids(edge_data),
  269. str(source_doc_id),
  270. })
  271. )
  272. graph.add_edge(
  273. source,
  274. target,
  275. weight=weight,
  276. description=edge_description,
  277. source_id=edge_source_id,
  278. )
  279. for node_degree in graph.degree:
  280. graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
  281. return graph
  282. def _unpack_descriptions(data: Mapping) -> list[str]:
  283. value = data.get("description", None)
  284. return [] if value is None else value.split("\n")
  285. def _unpack_source_ids(data: Mapping) -> list[str]:
  286. value = data.get("source_id", None)
  287. return [] if value is None else value.split(", ")