| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 | 
							- #
 - #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 - #
 - #  Licensed under the Apache License, Version 2.0 (the "License");
 - #  you may not use this file except in compliance with the License.
 - #  You may obtain a copy of the License at
 - #
 - #      http://www.apache.org/licenses/LICENSE-2.0
 - #
 - #  Unless required by applicable law or agreed to in writing, software
 - #  distributed under the License is distributed on an "AS IS" BASIS,
 - #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 - #  See the License for the specific language governing permissions and
 - #  limitations under the License.
 - #
 - import collections
 - import logging
 - import re
 - import logging
 - import traceback
 - from concurrent.futures import ThreadPoolExecutor
 - from dataclasses import dataclass
 - from typing import Any
 - 
 - from graphrag.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
 - from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
 - from rag.llm.chat_model import Base as CompletionLLM
 - import markdown_to_json
 - from functools import reduce
 - from rag.utils import num_tokens_from_string
 - 
 - 
 - @dataclass
 - class MindMapResult:
 -     """Unipartite Mind Graph result class definition."""
 -     output: dict
 - 
 - 
 - class MindMapExtractor:
 - 
 -     _llm: CompletionLLM
 -     _input_text_key: str
 -     _mind_map_prompt: str
 -     _on_error: ErrorHandlerFn
 - 
 -     def __init__(
 -             self,
 -             llm_invoker: CompletionLLM,
 -             prompt: str | None = None,
 -             input_text_key: str | None = None,
 -             on_error: ErrorHandlerFn | None = None,
 -     ):
 -         """Init method definition."""
 -         # TODO: streamline construction
 -         self._llm = llm_invoker
 -         self._input_text_key = input_text_key or "input_text"
 -         self._mind_map_prompt = prompt or MIND_MAP_EXTRACTION_PROMPT
 -         self._on_error = on_error or (lambda _e, _s, _d: None)
 - 
 -     def _key(self, k):
 -         return re.sub(r"\*+", "", k)
 - 
 -     def _be_children(self, obj: dict, keyset: set):
 -         if isinstance(obj, str):
 -             obj = [obj]
 -         if isinstance(obj, list):
 -             for i in obj: keyset.add(i)
 -             return [{"id": re.sub(r"\*+", "", i), "children": []} for i in obj]
 -         arr = []
 -         for k, v in obj.items():
 -             k = self._key(k)
 -             if not k or k in keyset: continue
 -             keyset.add(k)
 -             arr.append({
 -                 "id": k,
 -                 "children": self._be_children(v, keyset)
 -             })
 -         return arr
 - 
 -     def __call__(
 -             self, sections: list[str], prompt_variables: dict[str, Any] | None = None
 -     ) -> MindMapResult:
 -         """Call method definition."""
 -         if prompt_variables is None:
 -             prompt_variables = {}
 - 
 -         try:
 -             exe = ThreadPoolExecutor(max_workers=12)
 -             threads = []
 -             token_count = max(self._llm.max_length * 0.8, self._llm.max_length-512)
 -             texts = []
 -             res = []
 -             cnt = 0
 -             for i in range(len(sections)):
 -                 section_cnt = num_tokens_from_string(sections[i])
 -                 if cnt + section_cnt >= token_count and texts:
 -                     threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
 -                     texts = []
 -                     cnt = 0
 -                 texts.append(sections[i])
 -                 cnt += section_cnt
 -             if texts:
 -                 threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
 - 
 -             for i, _ in enumerate(threads):
 -                 res.append(_.result())
 - 
 -             merge_json = reduce(self._merge, res)
 -             if len(merge_json.keys()) > 1:
 -                 keyset = set(
 -                     [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict) and re.sub(r"\*+", "", k)])
 -                 merge_json = {"id": "root",
 -                           "children": [{"id": self._key(k), "children": self._be_children(v, keyset)} for k, v in
 -                                        merge_json.items() if isinstance(v, dict) and self._key(k)]}
 -             else:
 -                 k = self._key(list(merge_json.keys())[0])
 -                 merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], set([k]))}
 - 
 -         except Exception as e:
 -             logging.exception("error mind graph")
 -             self._on_error(
 -                 e,
 -                 traceback.format_exc(), None
 -             )
 -             merge_json = {"error": str(e)}
 - 
 -         return MindMapResult(output=merge_json)
 - 
 -     def _merge(self, d1, d2):
 -         for k in d1:
 -             if k in d2:
 -                 if isinstance(d1[k], dict) and isinstance(d2[k], dict):
 -                     self._merge(d1[k], d2[k])
 -                 elif isinstance(d1[k], list) and isinstance(d2[k], list):
 -                     d2[k].extend(d1[k])
 -                 else:
 -                     d2[k] = d1[k]
 -             else:
 -                 d2[k] = d1[k]
 - 
 -         return d2
 - 
 -     def _list_to_kv(self, data):
 -         for key, value in data.items():
 -             if isinstance(value, dict):
 -                 self._list_to_kv(value)
 -             elif isinstance(value, list):
 -                 new_value = {}
 -                 for i in range(len(value)):
 -                     if isinstance(value[i], list):
 -                         new_value[value[i - 1]] = value[i][0]
 -                 data[key] = new_value
 -             else:
 -                 continue
 -         return data
 - 
 -     def _todict(self, layer:collections.OrderedDict):
 -         to_ret = layer
 -         if isinstance(layer, collections.OrderedDict):
 -             to_ret = dict(layer)
 - 
 -         try:
 -             for key, value in to_ret.items():
 -                 to_ret[key] = self._todict(value)
 -         except AttributeError:
 -             pass
 - 
 -         return self._list_to_kv(to_ret)
 - 
 -     def _process_document(
 -             self, text: str, prompt_variables: dict[str, str]
 -     ) -> str:
 -         variables = {
 -             **prompt_variables,
 -             self._input_text_key: text,
 -         }
 -         text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
 -         gen_conf = {"temperature": 0.5}
 -         response = self._llm.chat(text, [], gen_conf)
 -         response = re.sub(r"```[^\n]*", "", response)
 -         print(response)
 -         print("---------------------------------------------------\n", self._todict(markdown_to_json.dictify(response)))
 -         return self._todict(markdown_to_json.dictify(response))
 
 
  |