### What problem does this PR solve?

Refactoring: report progress (items processed, elapsed time, and LLM token usage) from the GraphRAG entity-extraction and community-report steps via an optional `callback`, deduplicate task progress messages, lower the extraction temperature from 0.5 to 0.3, and pass the request through to `_add_filters` on the Elasticsearch KNN fallback.

### Type of change

- [x] Refactoring
**Task progress aggregation** — identical progress messages are now reported once instead of once per task:

```diff
         if 0 <= t.progress < 1:
             finished = False
         prg += t.progress if t.progress >= 0 else 0
-        msg.append(t.progress_msg)
+        if t.progress_msg not in msg:
+            msg.append(t.progress_msg)
         if t.progress == -1:
             bad += 1
     prg /= len(tsks)
```
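For context, here is the aggregation logic this hunk touches, as a self-contained sketch (the `Task` dataclass and the function name are stand-ins for the surrounding service code, and a non-empty task list is assumed):

```python
from dataclasses import dataclass

@dataclass
class Task:
    progress: float       # -1 means failed, 1 means done, [0, 1) means running
    progress_msg: str

def aggregate_progress(tsks: list[Task]) -> tuple[float, list[str], bool, int]:
    """Combine per-task progress into one overall figure, deduplicating messages."""
    prg, bad, finished, msg = 0.0, 0, True, []
    for t in tsks:
        if 0 <= t.progress < 1:
            finished = False
        prg += t.progress if t.progress >= 0 else 0
        # The fix in this PR: identical messages are reported once, not per task.
        if t.progress_msg not in msg:
            msg.append(t.progress_msg)
        if t.progress == -1:
            bad += 1
    prg /= len(tsks)  # assumes tsks is non-empty, as in the original code
    return prg, msg, finished, bad
```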
**Community reports extractor** — accepts an optional progress `callback`, counts LLM tokens, and reports per-community progress:

```diff
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any, List, Callable
 import networkx as nx
 import pandas as pd
 from graphrag import leiden
 from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
+from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer

 log = logging.getLogger(__name__)
@@ ... @@
         self._on_error = on_error or (lambda _e, _s, _d: None)
         self._max_report_length = max_report_length or 1500

-    def __call__(self, graph: nx.Graph):
+    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         communities: dict[str, dict[str, List]] = leiden.run(graph, {})
+        total = sum([len(comm.items()) for _, comm in communities.items()])
         relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
         res_str = []
         res_dict = []
+        over, token_count = 0, 0
+        st = timer()
         for level, comm in communities.items():
             for cm_id, ents in comm.items():
                 weight = ents["weight"]
@@ ... @@
                     "relation_df": rela_df.to_csv(index_label="id")
                 }
                 text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.5}
+                gen_conf = {"temperature": 0.3}
                 try:
                     response = self._llm.chat(text, [], gen_conf)
+                    token_count += num_tokens_from_string(text + response)
                     response = re.sub(r"^[^\{]*", "", response)
                     response = re.sub(r"[^\}]*$", "", response)
                     print(response)
                     add_community_info2graph(graph, ents, response["title"])
                     res_str.append(self._get_text_output(response))
                     res_dict.append(response)
+                    over += 1
+                    if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
@@ ... @@
         return CommunityReportsResult(
             structured_output=res_dict,
```
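Note the two call styles this PR uses: the extractors invoke `callback(msg=...)` keyword-only, while the indexing code below calls `callback(0.6, "...")` positionally, so any shared callback must accept both shapes. A minimal sketch of such a callback (the factory name and print-based reporting are assumptions, not project code):

```python
from typing import Callable

def make_progress_callback(prefix: str = "[graphrag]") -> Callable:
    """Build a callback compatible with both call styles seen in this PR:
    callback(0.5, "message") from the indexing loop, and
    callback(msg="message") from the extractors."""
    def callback(prog: float | None = None, msg: str = ""):
        if prog is not None:
            print(f"{prefix} {prog:.0%} {msg}")
        else:
            print(f"{prefix} {msg}")
    return callback

if __name__ == "__main__":
    cb = make_progress_callback()
    cb(0.6, "Extracting community reports.")
    cb(msg="Communities: 3/12, elapsed: 42.1s, used tokens: 8192")
```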
**Graph extractor** — the same pattern: an optional `callback`, per-document timing, and token accounting (`_process_document` now also returns its token count):

```diff
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, Mapping
+from typing import Any, Mapping, Callable
 import tiktoken
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
 from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer

 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
@@ ... @@
         self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

     def __call__(
-        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+        self, texts: list[str],
+        prompt_variables: dict[str, Any] | None = None,
+        callback: Callable | None = None
     ) -> GraphExtractionResult:
         """Call method definition."""
         if prompt_variables is None:
@@ ... @@
             ),
         }
+        st = timer()
+        total = len(texts)
+        total_token_count = 0
         for doc_index, text in enumerate(texts):
             try:
                 # Invoke the entity extraction
-                result = self._process_document(text, prompt_variables)
+                result, token_count = self._process_document(text, prompt_variables)
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
+                total_token_count += token_count
+                if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
             except Exception as e:
                 logging.exception("error extracting graph")
                 self._on_error(
@@ ... @@
             **prompt_variables,
             self._input_text_key: text,
         }
+        token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
-        gen_conf = {"temperature": 0.5}
+        gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [], gen_conf)
+        token_count = num_tokens_from_string(text + response)
         results = response or ""
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
@@ ... @@
             if continuation != "YES":
                 break

-        return results
+        return results, token_count

     def _process_results(
         self,
```
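Both extractors bill usage with `num_tokens_from_string(text + response)` over the concatenated prompt and reply. A rough tiktoken-based equivalent, assuming a `cl100k_base`-style encoder (the actual helper in `rag.utils` may pick its encoder differently):

```python
import tiktoken

def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Approximate token count used for the progress reports.
    Sketch only: shown for illustration, not the project's implementation."""
    try:
        encoder = tiktoken.get_encoding(encoding_name)
    except Exception:
        return 0  # fall back to zero rather than failing the extraction
    return len(encoder.encode(string))

# Mirroring _process_document's accounting:
#   token_count = num_tokens_from_string(prompt + response)
```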
**GraphRAG indexing** — the callback is threaded through the extractor submissions and the community-reports call, and batch progress messages gain a counter:

```diff
         for i in range(len(chunks)):
             tkn_cnt = num_tokens_from_string(chunks[i])
             if cnt+tkn_cnt >= left_token_count and texts:
-                threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
+                threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
                 texts = []
                 cnt = 0
             texts.append(chunks[i])
@@ ... @@
     graphs = []
     for i, _ in enumerate(threads):
         graphs.append(_.result().output)
-        callback(0.5 + 0.1*i/len(threads))
+        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")

     graph = reduce(graph_merge, graphs)
     er = EntityResolution(llm_bdl)
@@ ... @@
     callback(0.6, "Extracting community reports.")
     cr = CommunityReportsExtractor(llm_bdl)
-    cr = cr(graph)
+    cr = cr(graph, callback=callback)
     for community, desc in zip(cr.structured_output, cr.output):
         chunk = {
             "title_tks": rag_tokenizer.tokenize(community["title"]),
```
**Elasticsearch search fallback** — when the first query returns nothing and KNN is in play, the retried low-match query now re-applies the request filters:

```diff
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
-            bqry = self._add_filters(bqry)
+            bqry = self._add_filters(bqry, req)
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
```