# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
"""
import logging
import re
import traceback
from typing import Any, Callable, Mapping
from dataclasses import dataclass
from timeit import default_timer as timer

import tiktoken
import networkx as nx

from graphrag.extractor import Extractor
from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string

DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
DEFAULT_ENTITY_TYPES = ["organization", "person", "location", "event", "time"]
ENTITY_EXTRACTION_MAX_GLEANINGS = 1
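
# For orientation: with the default delimiters above, the extraction prompt
# asks the model for records shaped as follows (values are illustrative; the
# field order matches the parsing in GraphExtractor._process_results):
#
#   ("entity"<|>"NAME"<|>"TYPE"<|>"description")##
#   ("relationship"<|>"SOURCE"<|>"TARGET"<|>"description"<|>weight)##
#   <|COMPLETE|>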


@dataclass
class GraphExtractionResult:
    """Unipartite graph extraction result class definition."""

    output: nx.Graph
    source_docs: dict[Any, Any]


class GraphExtractor(Extractor):
    """Unipartite graph extractor class definition."""

    _join_descriptions: bool
    _tuple_delimiter_key: str
    _record_delimiter_key: str
    _entity_types_key: str
    _input_text_key: str
    _completion_delimiter_key: str
    _entity_name_key: str
    _input_descriptions_key: str
    _extraction_prompt: str
    _summarization_prompt: str
    _loop_args: dict[str, Any]
    _max_gleanings: int
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        tuple_delimiter_key: str | None = None,
        record_delimiter_key: str | None = None,
        input_text_key: str | None = None,
        entity_types_key: str | None = None,
        completion_delimiter_key: str | None = None,
        join_descriptions=True,
        encoding_model: str | None = None,
        max_gleanings: int | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._join_descriptions = join_descriptions
        self._input_text_key = input_text_key or "input_text"
        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
        self._completion_delimiter_key = (
            completion_delimiter_key or "completion_delimiter"
        )
        self._entity_types_key = entity_types_key or "entity_types"
        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
        self._max_gleanings = (
            max_gleanings
            if max_gleanings is not None
            else ENTITY_EXTRACTION_MAX_GLEANINGS
        )
        self._on_error = on_error or (lambda _e, _s, _d: None)
        self.prompt_token_count = num_tokens_from_string(self._extraction_prompt)

        # Construct the looping arguments: bias the continuation check toward
        # a bare one-token "YES"/"NO" answer (logit_bias boosts both tokens,
        # max_tokens=1 truncates the reply to a single token).
        encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
        yes = encoding.encode("YES")
        no = encoding.encode("NO")
        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

    def __call__(
        self,
        texts: list[str],
        prompt_variables: dict[str, Any] | None = None,
        callback: Callable | None = None,
    ) -> GraphExtractionResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}
        all_records: dict[int, str] = {}
        source_doc_map: dict[int, str] = {}

        # Wire defaults into the prompt variables
        prompt_variables = {
            **prompt_variables,
            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
            or DEFAULT_TUPLE_DELIMITER,
            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
            or DEFAULT_RECORD_DELIMITER,
            self._completion_delimiter_key: prompt_variables.get(
                self._completion_delimiter_key
            )
            or DEFAULT_COMPLETION_DELIMITER,
            self._entity_types_key: ",".join(
                prompt_variables.get(self._entity_types_key) or DEFAULT_ENTITY_TYPES
            ),
        }

        st = timer()
        total = len(texts)
        total_token_count = 0
        for doc_index, text in enumerate(texts):
            try:
                # Invoke the entity extraction
                result, token_count = self._process_document(text, prompt_variables)
                source_doc_map[doc_index] = text
                all_records[doc_index] = result
                total_token_count += token_count
                if callback:
                    callback(msg=f"{doc_index + 1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
            except Exception as e:
                if callback:
                    callback(msg=f"Knowledge graph extraction error: {e}")
                logging.exception("error extracting graph")
                self._on_error(
                    e,
                    traceback.format_exc(),
                    {
                        "doc_index": doc_index,
                        "text": text,
                    },
                )

        output = self._process_results(
            all_records,
            prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
            prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
        )
        return GraphExtractionResult(
            output=output,
            source_docs=source_doc_map,
        )

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> tuple[str, int]:
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._extraction_prompt, variables=variables)
        gen_conf = {"temperature": 0.3}
        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        token_count = num_tokens_from_string(text + response)
        results = response or ""
        history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]

        # Repeat to ensure we maximize entity count
        for i in range(self._max_gleanings):
            text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
            history.append({"role": "user", "content": text})
            response = self._chat("", history, gen_conf)
            results += response or ""

            # if this is the final glean, don't bother updating the continuation flag
            if i >= self._max_gleanings - 1:
                break
            history.append({"role": "assistant", "content": response})
            history.append({"role": "user", "content": LOOP_PROMPT})
            continuation = self._chat("", history, self._loop_args)
            if continuation != "YES":
                break

        return results, token_count
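
    # Shape of the conversation assembled above (message content is
    # illustrative, not the literal prompts):
    #   system:    extraction prompt with {input_text} substituted
    #   assistant: ("entity"<|>...)##("relationship"<|>...)...
    #   user:      CONTINUE_PROMPT  -> assistant appends any missed records
    #   user:      LOOP_PROMPT      -> one-token "YES"/"NO", forced by the
    #              logit_bias/max_tokens settings in self._loop_args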

    def _process_results(
        self,
        results: dict[int, str],
        tuple_delimiter: str,
        record_delimiter: str,
    ) -> nx.Graph:
        """Parse the result string to create an undirected unipartite graph.

        Args:
            - results - dict of results from the extraction chain
            - tuple_delimiter - delimiter between tuples in an output record, default is '<|>'
            - record_delimiter - delimiter between records, default is '##'
        Returns:
            - output - undirected unipartite graph as an nx.Graph
        """
        graph = nx.Graph()
        for source_doc_id, extracted_data in results.items():
            records = [r.strip() for r in extracted_data.split(record_delimiter)]
            for record in records:
                record = re.sub(r"^\(|\)$", "", record.strip())
                record_attributes = record.split(tuple_delimiter)

                if record_attributes[0] == '"entity"' and len(record_attributes) >= 4:
                    # add this record as a node in the G
                    entity_name = clean_str(record_attributes[1].upper())
                    entity_type = clean_str(record_attributes[2].upper())
                    entity_description = clean_str(record_attributes[3])

                    if entity_name in graph.nodes():
                        node = graph.nodes[entity_name]
                        if self._join_descriptions:
                            node["description"] = "\n".join(
                                list({
                                    *_unpack_descriptions(node),
                                    entity_description,
                                })
                            )
                        else:
                            if len(entity_description) > len(node["description"]):
                                node["description"] = entity_description
                        node["source_id"] = ", ".join(
                            list({
                                *_unpack_source_ids(node),
                                str(source_doc_id),
                            })
                        )
                        node["entity_type"] = (
                            entity_type if entity_type != "" else node["entity_type"]
                        )
                    else:
                        graph.add_node(
                            entity_name,
                            entity_type=entity_type,
                            description=entity_description,
                            source_id=str(source_doc_id),
                            weight=1,
                        )

                if (
                    record_attributes[0] == '"relationship"'
                    and len(record_attributes) >= 5
                ):
                    # add this record as edge
                    source = clean_str(record_attributes[1].upper())
                    target = clean_str(record_attributes[2].upper())
                    edge_description = clean_str(record_attributes[3])
                    edge_source_id = clean_str(str(source_doc_id))
                    # The weight arrives as the last tuple field, still a
                    # string; fall back to 1.0 when it is not a valid number.
                    try:
                        weight = float(record_attributes[-1])
                    except ValueError:
                        weight = 1.0

                    if source not in graph.nodes():
                        graph.add_node(
                            source,
                            entity_type="",
                            description="",
                            source_id=edge_source_id,
                            weight=1,
                        )
                    if target not in graph.nodes():
                        graph.add_node(
                            target,
                            entity_type="",
                            description="",
                            source_id=edge_source_id,
                            weight=1,
                        )
                    if graph.has_edge(source, target):
                        edge_data = graph.get_edge_data(source, target)
                        if edge_data is not None:
                            weight += edge_data["weight"]
                            if self._join_descriptions:
                                edge_description = "\n".join(
                                    list({
                                        *_unpack_descriptions(edge_data),
                                        edge_description,
                                    })
                                )
                            edge_source_id = ", ".join(
                                list({
                                    *_unpack_source_ids(edge_data),
                                    str(source_doc_id),
                                })
                            )
                    graph.add_edge(
                        source,
                        target,
                        weight=weight,
                        description=edge_description,
                        source_id=edge_source_id,
                    )

        # Rank every node by its degree so downstream consumers can weight
        # entities by connectivity.
        for node_degree in graph.degree:
            graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
        return graph


def _unpack_descriptions(data: Mapping) -> list[str]:
    """Split a joined description attribute back into its parts."""
    value = data.get("description", None)
    return [] if value is None else value.split("\n")


def _unpack_source_ids(data: Mapping) -> list[str]:
    """Split a joined source_id attribute back into its parts."""
    value = data.get("source_id", None)
    return [] if value is None else value.split(", ")
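

if __name__ == "__main__":
    # Minimal sketch exercising the result parser only. Assumptions: the
    # module's imports resolve, and constructing with llm_invoker=None is
    # safe here because __init__ merely stores the invoker and
    # _process_results never calls the LLM. The sample string is
    # hypothetical LLM output in the default delimiter format.
    sample = (
        '("entity"<|>"MICROSOFT"<|>"ORGANIZATION"<|>"A technology company")##'
        '("entity"<|>"SATYA NADELLA"<|>"PERSON"<|>"Chief executive of Microsoft")##'
        '("relationship"<|>"MICROSOFT"<|>"SATYA NADELLA"<|>"Satya Nadella leads Microsoft"<|>2)##'
        '<|COMPLETE|>'
    )
    extractor = GraphExtractor(llm_invoker=None)
    graph = extractor._process_results(
        {0: sample}, DEFAULT_TUPLE_DELIMITER, DEFAULT_RECORD_DELIMITER
    )
    # Expect two nodes joined by one weighted edge, each carrying
    # description, source_id, and degree-based rank attributes.
    for name, attrs in graph.nodes(data=True):
        print(name, attrs)
    for source, target, attrs in graph.edges(data=True):
        print(source, "--", target, attrs)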