You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mind_map_extractor.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import traceback
  18. from concurrent.futures import ThreadPoolExecutor
  19. from dataclasses import dataclass
  20. from typing import Any
  21. from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
  22. from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
  23. from rag.llm.chat_model import Base as CompletionLLM
  24. import markdown_to_json
  25. from functools import reduce
  26. from rag.utils import num_tokens_from_string
@dataclass
class MindMapResult:
    """Unipartite Mind Graph result class definition."""

    # Nested mind-map dict (merged across all processed batches);
    # may be empty when extraction fails.
    output: dict
  31. class MindMapExtractor:
  32. _llm: CompletionLLM
  33. _input_text_key: str
  34. _mind_map_prompt: str
  35. _on_error: ErrorHandlerFn
  36. def __init__(
  37. self,
  38. llm_invoker: CompletionLLM,
  39. prompt: str | None = None,
  40. input_text_key: str | None = None,
  41. on_error: ErrorHandlerFn | None = None,
  42. ):
  43. """Init method definition."""
  44. # TODO: streamline construction
  45. self._llm = llm_invoker
  46. self._input_text_key = input_text_key or "input_text"
  47. self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
  48. self._on_error = on_error or (lambda _e, _s, _d: None)
  49. def __call__(
  50. self, sections: list[str], prompt_variables: dict[str, Any] | None = None
  51. ) -> MindMapResult:
  52. """Call method definition."""
  53. if prompt_variables is None:
  54. prompt_variables = {}
  55. try:
  56. exe = ThreadPoolExecutor(max_workers=12)
  57. threads = []
  58. token_count = self._llm.max_length * 0.7
  59. texts = []
  60. res = []
  61. cnt = 0
  62. for i in range(len(sections)):
  63. section_cnt = num_tokens_from_string(sections[i])
  64. if cnt + section_cnt >= token_count and texts:
  65. threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
  66. texts = []
  67. cnt = 0
  68. texts.append(sections[i])
  69. cnt += section_cnt
  70. if texts:
  71. threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
  72. for i, _ in enumerate(threads):
  73. res.append(_.result())
  74. merge_json = reduce(self._merge, res)
  75. merge_json = self._list_to_kv(merge_json)
  76. except Exception as e:
  77. logging.exception("error mind graph")
  78. self._on_error(
  79. e,
  80. traceback.format_exc(), None
  81. )
  82. return MindMapResult(output=merge_json)
  83. def _merge(self, d1, d2):
  84. for k in d1:
  85. if k in d2:
  86. if isinstance(d1[k], dict) and isinstance(d2[k], dict):
  87. self._merge(d1[k], d2[k])
  88. elif isinstance(d1[k], list) and isinstance(d2[k], list):
  89. d2[k].extend(d1[k])
  90. else:
  91. d2[k] = d1[k]
  92. else:
  93. d2[k] = d1[k]
  94. return d2
  95. def _list_to_kv(self, data):
  96. for key, value in data.items():
  97. if isinstance(value, dict):
  98. self._list_to_kv(value)
  99. elif isinstance(value, list):
  100. new_value = {}
  101. for i in range(len(value)):
  102. if isinstance(value[i], list):
  103. new_value[value[i - 1]] = value[i][0]
  104. data[key] = new_value
  105. else:
  106. continue
  107. return data
  108. def _process_document(
  109. self, text: str, prompt_variables: dict[str, str]
  110. ) -> str:
  111. variables = {
  112. **prompt_variables,
  113. self._input_text_key: text,
  114. }
  115. text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
  116. gen_conf = {"temperature": 0.5}
  117. response = self._llm.chat(text, [], gen_conf)
  118. print(response)
  119. print("---------------------------------------------------\n", markdown_to_json.dictify(response))
  120. return dict(markdown_to_json.dictify(response))