You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mind_map_extractor.py 6.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import collections
  18. import re
  19. from typing import Any
  20. from dataclasses import dataclass
  21. import trio
  22. from graphrag.general.extractor import Extractor
  23. from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
  24. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
  25. from rag.llm.chat_model import Base as CompletionLLM
  26. import markdown_to_json
  27. from functools import reduce
  28. from rag.utils import num_tokens_from_string
  29. @dataclass
  30. class MindMapResult:
  31. """Unipartite Mind Graph result class definition."""
  32. output: dict
class MindMapExtractor(Extractor):
    """Builds a hierarchical mind map (topic tree) from document sections.

    Sections are packed into batches that fit the LLM's token budget, each
    batch is sent to the LLM concurrently via a trio nursery, the markdown
    answers are parsed with ``markdown_to_json``, merged into one dict, and
    normalised into a single ``{"id": ..., "children": [...]}`` tree wrapped
    in a ``MindMapResult``.
    """

    # Prompt-template variable name that receives the batch text.
    _input_text_key: str
    # Prompt template; defaults to MIND_MAP_EXTRACTION_PROMPT.
    _mind_map_prompt: str
    # Error callback; stored but not invoked anywhere in this class.
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        # Default error handler is a no-op.
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k: str) -> str:
        # Strip markdown emphasis markers ('*' runs) from a node label.
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict, keyset: set):
        """Recursively convert a parsed dict/list/str into child-node dicts.

        ``keyset`` accumulates labels already emitted so that a key appearing
        again deeper in the tree is skipped rather than duplicated.
        """
        # A bare string leaf is treated as a one-element list.
        if isinstance(obj, str):
            obj = [obj]
        if isinstance(obj, list):
            # NOTE(review): the raw (un-stripped) items go into keyset while
            # the emitted ids have '*' stripped — a starred duplicate seen
            # later under its stripped form would not be filtered.
            keyset.update(obj)
            obj = [re.sub(r"\*+", "", i) for i in obj]
            # Drop labels that are empty after stripping.
            return [{"id": i, "children": []} for i in obj if i]
        arr = []
        for k, v in obj.items():
            k = self._key(k)
            # Skip empty labels and labels already present in the tree.
            if k and k not in keyset:
                keyset.add(k)
                arr.append(
                    {
                        "id": k,
                        "children": self._be_children(v, keyset)
                    }
                )
        return arr

    async def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}
        res = []
        # Per-batch token budget: whichever of 80% of the context window or
        # window-minus-512 is larger (for large windows this keeps a fixed
        # 512-token headroom).
        token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
        texts = []
        cnt = 0
        # Fan out one _process_document task per batch; the nursery blocks at
        # exit until every task has finished appending into `res`.
        async with trio.open_nursery() as nursery:
            for i in range(len(sections)):
                section_cnt = num_tokens_from_string(sections[i])
                # Flush the current batch before adding a section would push
                # it past the budget.
                if cnt + section_cnt >= token_count and texts:
                    nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
                    texts = []
                    cnt = 0
                texts.append(sections[i])
                cnt += section_cnt
            # Flush the final partial batch.
            if texts:
                nursery.start_soon(self._process_document, "".join(texts), prompt_variables, res)
        # All tasks are done here; `res` holds one parsed dict per batch.
        if not res:
            return MindMapResult(output={"id": "root", "children": []})
        # Fold every per-batch dict into a single merged dict.
        merge_json = reduce(self._merge, res)
        if len(merge_json) > 1:
            # Several top-level topics: hang them under a synthetic root.
            # Pre-seed keyset with all top-level keys so they are not
            # re-emitted as children deeper down.
            keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
            keyset = set(i for i in keys if i)
            merge_json = {
                "id": "root",
                "children": [
                    {
                        "id": self._key(k),
                        "children": self._be_children(v, keyset)
                    }
                    for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
                ]
            }
        else:
            # Exactly one top-level topic: it becomes the root itself.
            k = self._key(list(merge_json.keys())[0])
            merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
        return MindMapResult(output=merge_json)

    def _merge(self, d1, d2):
        """Recursively merge ``d1`` into ``d2``; ``d2`` is mutated and returned."""
        for k in d1:
            if k in d2:
                if isinstance(d1[k], dict) and isinstance(d2[k], dict):
                    self._merge(d1[k], d2[k])
                elif isinstance(d1[k], list) and isinstance(d2[k], list):
                    d2[k].extend(d1[k])
                else:
                    # Type mismatch between the two values: d1's value wins.
                    d2[k] = d1[k]
            else:
                d2[k] = d1[k]
        return d2

    def _list_to_kv(self, data):
        """Rewrite list values into key/value dicts, in place; returns ``data``."""
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for i in range(len(value)):
                    # Pair each nested list with the element preceding it
                    # (presumably a heading label emitted by markdown_to_json
                    # — TODO confirm); only the nested list's first item is
                    # kept as the value.
                    if isinstance(value[i], list) and i > 0:
                        new_value[value[i - 1]] = value[i][0]
                data[key] = new_value
            else:
                continue
        return data

    def _todict(self, layer: collections.OrderedDict):
        """Recursively convert markdown_to_json's OrderedDicts to plain dicts."""
        to_ret = layer
        if isinstance(to_ret, collections.OrderedDict):
            to_ret = dict(to_ret)
        try:
            for key, value in to_ret.items():
                to_ret[key] = self._todict(value)
        except AttributeError:
            # Non-dict layer (no .items()) — leave it untouched.
            # NOTE(review): an AttributeError raised by _list_to_kv on a leaf
            # inside the recursive call above is also caught here, which
            # aborts this loop early; later keys are then handled only by
            # _list_to_kv's own recursion below.
            pass
        return self._list_to_kv(to_ret)

    async def _process_document(
        self, text: str, prompt_variables: dict[str, str], out_res
    ) -> None:
        """Run one batch through the LLM and append the parsed tree to ``out_res``."""
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        # chat_limiter bounds how many LLM calls run concurrently; the
        # blocking chat call runs in a worker thread so the trio event loop
        # stays responsive.
        async with chat_limiter:
            response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
        # Strip markdown code-fence marker lines before parsing.
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        logging.debug(self._todict(markdown_to_json.dictify(response)))
        out_res.append(self._todict(markdown_to_json.dictify(response)))