mind_map_extractor.py (6.3 KB)
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
import collections
import logging
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import reduce
from typing import Any

import markdown_to_json

from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string
  30. @dataclass
  31. class MindMapResult:
  32. """Unipartite Mind Graph result class definition."""
  33. output: dict
  34. class MindMapExtractor:
  35. _llm: CompletionLLM
  36. _input_text_key: str
  37. _mind_map_prompt: str
  38. _on_error: ErrorHandlerFn
  39. def __init__(
  40. self,
  41. llm_invoker: CompletionLLM,
  42. prompt: str | None = None,
  43. input_text_key: str | None = None,
  44. on_error: ErrorHandlerFn | None = None,
  45. ):
  46. """Init method definition."""
  47. # TODO: streamline construction
  48. self._llm = llm_invoker
  49. self._input_text_key = input_text_key or "input_text"
  50. self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
  51. self._on_error = on_error or (lambda _e, _s, _d: None)
  52. def _key(self, k):
  53. return re.sub(r"\*+", "", k)
  54. def _be_children(self, obj: dict, keyset: set):
  55. if isinstance(obj, str):
  56. obj = [obj]
  57. if isinstance(obj, list):
  58. for i in obj: keyset.add(i)
  59. return [{"id": re.sub(r"\*+", "", i), "children": []} for i in obj]
  60. arr = []
  61. for k, v in obj.items():
  62. k = self._key(k)
  63. if not k or k in keyset: continue
  64. keyset.add(k)
  65. arr.append({
  66. "id": k,
  67. "children": self._be_children(v, keyset)
  68. })
  69. return arr
  70. def __call__(
  71. self, sections: list[str], prompt_variables: dict[str, Any] | None = None
  72. ) -> MindMapResult:
  73. """Call method definition."""
  74. if prompt_variables is None:
  75. prompt_variables = {}
  76. try:
  77. exe = ThreadPoolExecutor(max_workers=12)
  78. threads = []
  79. token_count = max(self._llm.max_length * 0.8, self._llm.max_length-512)
  80. texts = []
  81. res = []
  82. cnt = 0
  83. for i in range(len(sections)):
  84. section_cnt = num_tokens_from_string(sections[i])
  85. if cnt + section_cnt >= token_count and texts:
  86. threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
  87. texts = []
  88. cnt = 0
  89. texts.append(sections[i])
  90. cnt += section_cnt
  91. if texts:
  92. threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
  93. for i, _ in enumerate(threads):
  94. res.append(_.result())
  95. merge_json = reduce(self._merge, res)
  96. if len(merge_json.keys()) > 1:
  97. keyset = set(
  98. [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)])
  99. merge_json = {"id": "root",
  100. "children": [{"id": self._key(k), "children": self._be_children(v, keyset)} for k, v in
  101. merge_json.items() if isinstance(v, dict) and self._key(k)]}
  102. else:
  103. k = self._key(list(merge_json.keys())[0])
  104. merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], set([k]))}
  105. except Exception as e:
  106. logging.exception("error mind graph")
  107. self._on_error(
  108. e,
  109. traceback.format_exc(), None
  110. )
  111. merge_json = {"error": str(e)}
  112. return MindMapResult(output=merge_json)
  113. def _merge(self, d1, d2):
  114. for k in d1:
  115. if k in d2:
  116. if isinstance(d1[k], dict) and isinstance(d2[k], dict):
  117. self._merge(d1[k], d2[k])
  118. elif isinstance(d1[k], list) and isinstance(d2[k], list):
  119. d2[k].extend(d1[k])
  120. else:
  121. d2[k] = d1[k]
  122. else:
  123. d2[k] = d1[k]
  124. return d2
  125. def _list_to_kv(self, data):
  126. for key, value in data.items():
  127. if isinstance(value, dict):
  128. self._list_to_kv(value)
  129. elif isinstance(value, list):
  130. new_value = {}
  131. for i in range(len(value)):
  132. if isinstance(value[i], list):
  133. new_value[value[i - 1]] = value[i][0]
  134. data[key] = new_value
  135. else:
  136. continue
  137. return data
  138. def _todict(self, layer:collections.OrderedDict):
  139. to_ret = layer
  140. if isinstance(layer, collections.OrderedDict):
  141. to_ret = dict(layer)
  142. try:
  143. for key, value in to_ret.items():
  144. to_ret[key] = self._todict(value)
  145. except AttributeError:
  146. pass
  147. return self._list_to_kv(to_ret)
  148. def _process_document(
  149. self, text: str, prompt_variables: dict[str, str]
  150. ) -> str:
  151. variables = {
  152. **prompt_variables,
  153. self._input_text_key: text,
  154. }
  155. text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
  156. gen_conf = {"temperature": 0.5}
  157. response = self._llm.chat(text, [], gen_conf)
  158. response = re.sub(r"```[^\n]*", "", response)
  159. print(response)
  160. print("---------------------------------------------------\n", self._todict(markdown_to_json.dictify(response)))
  161. return self._todict(markdown_to_json.dictify(response))