
mind_map_extractor.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import collections
import os
import re
import traceback
from typing import Any
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import reduce

import markdown_to_json

from graphrag.extractor import Extractor
from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string


@dataclass
class MindMapResult:
    """Unipartite Mind Graph result class definition."""

    output: dict


class MindMapExtractor(Extractor):
    """Build a hierarchical mind map from document sections via an LLM."""

    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k):
        # Strip markdown emphasis markers ("*") from a node label.
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict, keyset: set):
        # Turn a parsed markdown subtree into a list of {"id", "children"} nodes,
        # skipping labels that have already been used elsewhere in the tree.
        if isinstance(obj, str):
            obj = [obj]
        if isinstance(obj, list):
            keyset.update(obj)
            obj = [re.sub(r"\*+", "", i) for i in obj]
            return [{"id": i, "children": []} for i in obj if i]
        arr = []
        for k, v in obj.items():
            k = self._key(k)
            if k and k not in keyset:
                keyset.add(k)
                arr.append(
                    {
                        "id": k,
                        "children": self._be_children(v, keyset)
                    }
                )
        return arr

    def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}

        try:
            res = []
            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
            with ThreadPoolExecutor(max_workers=max_workers) as exe:
                threads = []
                # Batch sections so each LLM call stays within roughly 80% of the
                # model's context window (or the window minus 512 tokens, if larger).
                token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
                texts = []
                cnt = 0
                for i in range(len(sections)):
                    section_cnt = num_tokens_from_string(sections[i])
                    if cnt + section_cnt >= token_count and texts:
                        threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
                        texts = []
                        cnt = 0
                    texts.append(sections[i])
                    cnt += section_cnt
                if texts:
                    threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))

                for i, _ in enumerate(threads):
                    res.append(_.result())

            if not res:
                return MindMapResult(output={"id": "root", "children": []})

            # Merge the per-batch trees and normalize them into {"id", "children"} form.
            merge_json = reduce(self._merge, res)
            if len(merge_json) > 1:
                keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
                keyset = set(i for i in keys if i)
                merge_json = {
                    "id": "root",
                    "children": [
                        {
                            "id": self._key(k),
                            "children": self._be_children(v, keyset)
                        }
                        for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
                    ]
                }
            else:
                k = self._key(list(merge_json.keys())[0])
                merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
        except Exception as e:
            logging.exception("error mind graph")
            self._on_error(
                e,
                traceback.format_exc(), None
            )
            merge_json = {"error": str(e)}

        return MindMapResult(output=merge_json)

    def _merge(self, d1, d2):
        # Recursively merge d1 into d2: nested dicts are merged, lists are
        # concatenated, and other conflicting values are overwritten by d1.
        for k in d1:
            if k in d2:
                if isinstance(d1[k], dict) and isinstance(d2[k], dict):
                    self._merge(d1[k], d2[k])
                elif isinstance(d1[k], list) and isinstance(d2[k], list):
                    d2[k].extend(d1[k])
                else:
                    d2[k] = d1[k]
            else:
                d2[k] = d1[k]
        return d2

    def _list_to_kv(self, data):
        # Fold nested lists emitted by markdown_to_json into key/value pairs:
        # a sub-list's first element becomes the value keyed by the item before it.
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for i in range(len(value)):
                    if isinstance(value[i], list) and i > 0:
                        new_value[value[i - 1]] = value[i][0]
                data[key] = new_value
            else:
                continue
        return data

    def _todict(self, layer: collections.OrderedDict):
        # Recursively convert OrderedDicts from markdown_to_json into plain dicts,
        # then normalize any list values.
        to_ret = layer
        if isinstance(layer, collections.OrderedDict):
            to_ret = dict(layer)

        try:
            for key, value in to_ret.items():
                to_ret[key] = self._todict(value)
        except AttributeError:
            pass

        return self._list_to_kv(to_ret)

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> dict:
        # Fill the mind-map prompt with the batched text, ask the LLM for a markdown
        # outline, strip code fences, and parse the outline into a nested dict.
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        response = re.sub(r"```[^\n]*", "", response)
        logging.debug(response)
        logging.debug(self._todict(markdown_to_json.dictify(response)))
        return self._todict(markdown_to_json.dictify(response))
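
For orientation, a minimal usage sketch follows. It is not part of mind_map_extractor.py and not an officially documented entry point: it assumes the module path graphrag.mind_map_extractor implied by the imports above, that `chat_model` is an already-configured client compatible with rag.llm.chat_model.Base, and that the section strings are placeholders for your own document chunks.

# Usage sketch (assumptions: `chat_model` is an existing, configured client
# compatible with rag.llm.chat_model.Base; sections come from your own chunker).
from graphrag.mind_map_extractor import MindMapExtractor

sections = [
    "# Chapter 1\nIntroduction to retrieval-augmented generation ...",
    "# Chapter 2\nChunking, embedding and indexing ...",
]

extractor = MindMapExtractor(llm_invoker=chat_model)  # default prompt, input key, and error handler
result = extractor(sections)  # concurrency is capped by MINDMAP_EXTRACTOR_MAX_WORKERS (default 12)
print(result.output)  # {"id": "root", "children": [...]} on success, {"error": "..."} on failure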