You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

mind_map_extractor.py 6.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
import collections
import logging
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import reduce
from typing import Any

import markdown_to_json

from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string
  30. @dataclass
  31. class MindMapResult:
  32. """Unipartite Mind Graph result class definition."""
  33. output: dict
  34. class MindMapExtractor:
  35. _llm: CompletionLLM
  36. _input_text_key: str
  37. _mind_map_prompt: str
  38. _on_error: ErrorHandlerFn
  39. def __init__(
  40. self,
  41. llm_invoker: CompletionLLM,
  42. prompt: str | None = None,
  43. input_text_key: str | None = None,
  44. on_error: ErrorHandlerFn | None = None,
  45. ):
  46. """Init method definition."""
  47. # TODO: streamline construction
  48. self._llm = llm_invoker
  49. self._input_text_key = input_text_key or "input_text"
  50. self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
  51. self._on_error = on_error or (lambda _e, _s, _d: None)
  52. def _key(self, k):
  53. return re.sub(r"\*+", "", k)
  54. def _be_children(self, obj: dict, keyset: set):
  55. if isinstance(obj, str):
  56. obj = [obj]
  57. if isinstance(obj, list):
  58. for i in obj: keyset.add(i)
  59. return [{"id": re.sub(r"\*+", "", i), "children": []} for i in obj]
  60. arr = []
  61. for k, v in obj.items():
  62. k = self._key(k)
  63. if not k or k in keyset: continue
  64. keyset.add(k)
  65. arr.append({
  66. "id": k,
  67. "children": self._be_children(v, keyset)
  68. })
  69. return arr
  70. def __call__(
  71. self, sections: list[str], prompt_variables: dict[str, Any] | None = None
  72. ) -> MindMapResult:
  73. """Call method definition."""
  74. if prompt_variables is None:
  75. prompt_variables = {}
  76. try:
  77. exe = ThreadPoolExecutor(max_workers=12)
  78. threads = []
  79. token_count = max(self._llm.max_length * 0.8, self._llm.max_length-512)
  80. texts = []
  81. res = []
  82. cnt = 0
  83. for i in range(len(sections)):
  84. section_cnt = num_tokens_from_string(sections[i])
  85. if cnt + section_cnt >= token_count and texts:
  86. threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
  87. texts = []
  88. cnt = 0
  89. texts.append(sections[i])
  90. cnt += section_cnt
  91. if texts:
  92. threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
  93. for i, _ in enumerate(threads):
  94. res.append(_.result())
  95. if not res:
  96. return MindMapResult(output={"root":{}})
  97. merge_json = reduce(self._merge, res)
  98. if len(merge_json.keys()) > 1:
  99. keyset = set(
  100. [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)])
  101. merge_json = {"id": "root",
  102. "children": [{"id": self._key(k), "children": self._be_children(v, keyset)} for k, v in
  103. merge_json.items() if isinstance(v, dict) and self._key(k)]}
  104. else:
  105. k = self._key(list(merge_json.keys())[0])
  106. merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], set([k]))}
  107. except Exception as e:
  108. logging.exception("error mind graph")
  109. self._on_error(
  110. e,
  111. traceback.format_exc(), None
  112. )
  113. merge_json = {"error": str(e)}
  114. return MindMapResult(output=merge_json)
  115. def _merge(self, d1, d2):
  116. for k in d1:
  117. if k in d2:
  118. if isinstance(d1[k], dict) and isinstance(d2[k], dict):
  119. self._merge(d1[k], d2[k])
  120. elif isinstance(d1[k], list) and isinstance(d2[k], list):
  121. d2[k].extend(d1[k])
  122. else:
  123. d2[k] = d1[k]
  124. else:
  125. d2[k] = d1[k]
  126. return d2
  127. def _list_to_kv(self, data):
  128. for key, value in data.items():
  129. if isinstance(value, dict):
  130. self._list_to_kv(value)
  131. elif isinstance(value, list):
  132. new_value = {}
  133. for i in range(len(value)):
  134. if isinstance(value[i], list):
  135. new_value[value[i - 1]] = value[i][0]
  136. data[key] = new_value
  137. else:
  138. continue
  139. return data
  140. def _todict(self, layer:collections.OrderedDict):
  141. to_ret = layer
  142. if isinstance(layer, collections.OrderedDict):
  143. to_ret = dict(layer)
  144. try:
  145. for key, value in to_ret.items():
  146. to_ret[key] = self._todict(value)
  147. except AttributeError:
  148. pass
  149. return self._list_to_kv(to_ret)
  150. def _process_document(
  151. self, text: str, prompt_variables: dict[str, str]
  152. ) -> str:
  153. variables = {
  154. **prompt_variables,
  155. self._input_text_key: text,
  156. }
  157. text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
  158. gen_conf = {"temperature": 0.5}
  159. response = self._llm.chat(text, [], gen_conf)
  160. response = re.sub(r"```[^\n]*", "", response)
  161. print(response)
  162. print("---------------------------------------------------\n", self._todict(markdown_to_json.dictify(response)))
  163. return self._todict(markdown_to_json.dictify(response))