
mind_map_extractor.py 6.8KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import logging
import os
import re
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import reduce
from typing import Any

import markdown_to_json

from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from rag.llm.chat_model import Base as CompletionLLM
from rag.utils import num_tokens_from_string
from api.utils.log_utils import logger


@dataclass
class MindMapResult:
    """Unipartite mind graph result class definition."""

    output: dict


class MindMapExtractor:
    _llm: CompletionLLM
    _input_text_key: str
    _mind_map_prompt: str
    _on_error: ErrorHandlerFn

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        prompt: str | None = None,
        input_text_key: str | None = None,
        on_error: ErrorHandlerFn | None = None,
    ):
        """Init method definition."""
        # TODO: streamline construction
        self._llm = llm_invoker
        self._input_text_key = input_text_key or "input_text"
        self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
        self._on_error = on_error or (lambda _e, _s, _d: None)

    def _key(self, k):
        # Strip markdown emphasis markers ("*") from a heading key.
        return re.sub(r"\*+", "", k)

    def _be_children(self, obj: dict, keyset: set):
        # Normalize a parsed markdown value (str, list or dict) into a list of
        # {"id": ..., "children": [...]} nodes, skipping keys already seen.
        if isinstance(obj, str):
            obj = [obj]
        if isinstance(obj, list):
            keyset.update(obj)
            obj = [re.sub(r"\*+", "", i) for i in obj]
            return [{"id": i, "children": []} for i in obj if i]
        arr = []
        for k, v in obj.items():
            k = self._key(k)
            if k and k not in keyset:
                keyset.add(k)
                arr.append(
                    {
                        "id": k,
                        "children": self._be_children(v, keyset)
                    }
                )
        return arr

    def __call__(
        self, sections: list[str], prompt_variables: dict[str, Any] | None = None
    ) -> MindMapResult:
        """Call method definition."""
        if prompt_variables is None:
            prompt_variables = {}

        try:
            max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
            exe = ThreadPoolExecutor(max_workers=max_workers)
            threads = []
            # Leave headroom below the model's context limit for the prompt itself.
            token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
            texts = []
            res = []
            cnt = 0
            # Pack sections into batches that fit the token budget and extract
            # a mind map from each batch in parallel.
            for i in range(len(sections)):
                section_cnt = num_tokens_from_string(sections[i])
                if cnt + section_cnt >= token_count and texts:
                    threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
                    texts = []
                    cnt = 0
                texts.append(sections[i])
                cnt += section_cnt
            if texts:
                threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))

            for i, _ in enumerate(threads):
                res.append(_.result())

            if not res:
                return MindMapResult(output={"id": "root", "children": []})

            # Merge the per-batch trees and wrap them under a single root node.
            merge_json = reduce(self._merge, res)
            if len(merge_json) > 1:
                keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
                keyset = set(i for i in keys if i)
                merge_json = {
                    "id": "root",
                    "children": [
                        {
                            "id": self._key(k),
                            "children": self._be_children(v, keyset)
                        }
                        for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
                    ]
                }
            else:
                k = self._key(list(merge_json.keys())[0])
                merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
        except Exception as e:
            logging.exception("error mind graph")
            self._on_error(
                e,
                traceback.format_exc(), None
            )
            merge_json = {"error": str(e)}

        return MindMapResult(output=merge_json)

    def _merge(self, d1, d2):
        # Recursively merge d1 into d2: nested dicts are merged, lists are
        # concatenated, and other values from d1 overwrite those in d2.
        for k in d1:
            if k in d2:
                if isinstance(d1[k], dict) and isinstance(d2[k], dict):
                    self._merge(d1[k], d2[k])
                elif isinstance(d1[k], list) and isinstance(d2[k], list):
                    d2[k].extend(d1[k])
                else:
                    d2[k] = d1[k]
            else:
                d2[k] = d1[k]
        return d2

    def _list_to_kv(self, data):
        # Convert alternating [key, [value]] pairs produced by the markdown
        # parser into key/value dicts, recursing into nested dicts.
        for key, value in data.items():
            if isinstance(value, dict):
                self._list_to_kv(value)
            elif isinstance(value, list):
                new_value = {}
                for i in range(len(value)):
                    if isinstance(value[i], list) and i > 0:
                        new_value[value[i - 1]] = value[i][0]
                data[key] = new_value
            else:
                continue
        return data

    def _todict(self, layer: collections.OrderedDict):
        # Recursively convert OrderedDicts to plain dicts, then normalize lists.
        to_ret = layer
        if isinstance(layer, collections.OrderedDict):
            to_ret = dict(layer)

        try:
            for key, value in to_ret.items():
                to_ret[key] = self._todict(value)
        except AttributeError:
            pass

        return self._list_to_kv(to_ret)

    def _process_document(
        self, text: str, prompt_variables: dict[str, str]
    ) -> str:
        # Fill the mind-map prompt with the batch text, query the LLM, strip
        # any code fences from the reply and parse the markdown into a dict.
        variables = {
            **prompt_variables,
            self._input_text_key: text,
        }
        text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
        gen_conf = {"temperature": 0.5}
        response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
        response = re.sub(r"```[^\n]*", "", response)
        logger.info(response)
        logger.info(self._todict(markdown_to_json.dictify(response)))
        return self._todict(markdown_to_json.dictify(response))
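
For orientation, a minimal usage sketch follows. It is not part of mind_map_extractor.py: the helper name build_mind_map is invented for illustration, and the CompletionLLM client and pre-chunked sections are assumed to come from the surrounding RAGFlow pipeline.

# --- Hypothetical usage sketch (not part of the module above) ---
def build_mind_map(llm: CompletionLLM, sections: list[str]) -> dict:
    # The extractor batches the sections, queries the LLM per batch and merges
    # the resulting markdown outlines into one tree.
    extractor = MindMapExtractor(llm)
    result = extractor(sections)   # MindMapResult holding a nested {"id", "children"} tree
    return result.output           # e.g. {"id": "root", "children": [...]} or {"error": "..."}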