### What problem does this PR solve?

Adds progress tracking to the graphrag pipeline: the entity and community-report extractors now accept a `callback` and report per-step progress and LLM token usage, and duplicate task progress messages are de-duplicated. (Tag: v0.10.0)

### Type of change

- [x] Refactoring
```diff
@@ -317,7 +317,8 @@ class DocumentService(CommonService):
             if 0 <= t.progress < 1:
                 finished = False
             prg += t.progress if t.progress >= 0 else 0
-            msg.append(t.progress_msg)
+            if t.progress_msg not in msg:
+                msg.append(t.progress_msg)
             if t.progress == -1:
                 bad += 1
         prg /= len(tsks)
```
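For context, a minimal sketch of what the de-duplication changes (the `Task` stand-in below is illustrative, not the real ragflow model):

```python
from dataclasses import dataclass

@dataclass
class Task:
    progress: float
    progress_msg: str

# Several tasks often repeat the same message while a document is parsed.
tasks = [Task(0.5, "Parsing."), Task(0.7, "Parsing."), Task(-1, "Failed!")]

msg = []
for t in tasks:
    if t.progress_msg not in msg:  # new behavior: skip duplicates
        msg.append(t.progress_msg)

print(msg)  # ['Parsing.', 'Failed!'] -- "Parsing." no longer appears twice
```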
```diff
@@ -23,16 +23,16 @@ import logging
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any, List, Callable
 import networkx as nx
 import pandas as pd
 from graphrag import leiden
 from graphrag.community_report_prompt import COMMUNITY_REPORT_PROMPT
 from graphrag.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
+from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer

 log = logging.getLogger(__name__)
```
```diff
@@ -67,11 +67,14 @@ class CommunityReportsExtractor:
         self._on_error = on_error or (lambda _e, _s, _d: None)
         self._max_report_length = max_report_length or 1500

-    def __call__(self, graph: nx.Graph):
+    def __call__(self, graph: nx.Graph, callback: Callable | None = None):
         communities: dict[str, dict[str, List]] = leiden.run(graph, {})
+        total = sum([len(comm.items()) for _, comm in communities.items()])
         relations_df = pd.DataFrame([{"source":s, "target": t, **attr} for s, t, attr in graph.edges(data=True)])
         res_str = []
         res_dict = []
+        over, token_count = 0, 0
+        st = timer()
         for level, comm in communities.items():
             for cm_id, ents in comm.items():
                 weight = ents["weight"]
```
```diff
@@ -84,9 +87,10 @@ class CommunityReportsExtractor:
                     "relation_df": rela_df.to_csv(index_label="id")
                 }
                 text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.5}
+                gen_conf = {"temperature": 0.3}
                 try:
                     response = self._llm.chat(text, [], gen_conf)
+                    token_count += num_tokens_from_string(text + response)
                     response = re.sub(r"^[^\{]*", "", response)
                     response = re.sub(r"[^\}]*$", "", response)
                     print(response)
```
```diff
@@ -108,6 +112,8 @@ class CommunityReportsExtractor:
                 add_community_info2graph(graph, ents, response["title"])
                 res_str.append(self._get_text_output(response))
                 res_dict.append(response)
+                over += 1
+                if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

         return CommunityReportsResult(
             structured_output=res_dict,
```
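The extractors only ever invoke `callback` with a keyword `msg`, while `build_knowlege_graph_chunks` (further down) also calls it positionally with a progress fraction, so any callback passed in should accept both styles. A sketch of a compatible callback, with the signature inferred from the call sites in this diff:

```python
def progress_callback(prog: float | None = None, msg: str = ""):
    # Accepts both callback(0.6, "...") and callback(msg="...") call styles.
    if prog is not None:
        print(f"[{prog:.0%}] {msg}")
    else:
        print(msg)

# Illustrative usage; llm_bdl and graph come from the surrounding pipeline:
# cr = CommunityReportsExtractor(llm_bdl)
# result = cr(graph, callback=progress_callback)
```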
```diff
@@ -21,13 +21,14 @@ import numbers
 import re
 import traceback
 from dataclasses import dataclass
-from typing import Any, Mapping
+from typing import Any, Mapping, Callable
 import tiktoken
 from graphrag.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
 from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, clean_str
 from rag.llm.chat_model import Base as CompletionLLM
 import networkx as nx
+from rag.utils import num_tokens_from_string
+from timeit import default_timer as timer

 DEFAULT_TUPLE_DELIMITER = "<|>"
 DEFAULT_RECORD_DELIMITER = "##"
```
```diff
@@ -103,7 +104,9 @@ class GraphExtractor:
         self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}

     def __call__(
-        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+        self, texts: list[str],
+        prompt_variables: dict[str, Any] | None = None,
+        callback: Callable | None = None
     ) -> GraphExtractionResult:
         """Call method definition."""
         if prompt_variables is None:
```
```diff
@@ -127,12 +130,17 @@ class GraphExtractor:
             ),
         }

+        st = timer()
+        total = len(texts)
+        total_token_count = 0
        for doc_index, text in enumerate(texts):
             try:
                 # Invoke the entity extraction
-                result = self._process_document(text, prompt_variables)
+                result, token_count = self._process_document(text, prompt_variables)
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
+                total_token_count += token_count
+                if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
             except Exception as e:
                 logging.exception("error extracting graph")
                 self._on_error(
```
```diff
@@ -162,9 +170,11 @@ class GraphExtractor:
             **prompt_variables,
             self._input_text_key: text,
         }
+        token_count = 0
         text = perform_variable_replacements(self._extraction_prompt, variables=variables)
-        gen_conf = {"temperature": 0.5}
+        gen_conf = {"temperature": 0.3}
         response = self._llm.chat(text, [], gen_conf)
+        token_count = num_tokens_from_string(text + response)

         results = response or ""
         history = [{"role": "system", "content": text}, {"role": "assistant", "content": response}]
```
```diff
@@ -185,7 +195,7 @@ class GraphExtractor:
             if continuation != "YES":
                 break

-        return results
+        return results, token_count

     def _process_results(
         self,
```
```diff
@@ -86,7 +86,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     for i in range(len(chunks)):
         tkn_cnt = num_tokens_from_string(chunks[i])
         if cnt+tkn_cnt >= left_token_count and texts:
-            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}))
+            threads.append(exe.submit(ext, texts, {"entity_types": entity_types}, callback))
             texts = []
             cnt = 0
         texts.append(chunks[i])
```
```diff
@@ -98,7 +98,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     graphs = []
     for i, _ in enumerate(threads):
         graphs.append(_.result().output)
-        callback(0.5 + 0.1*i/len(threads))
+        callback(0.5 + 0.1*i/len(threads), f"Entities extraction progress ... {i+1}/{len(threads)}")

     graph = reduce(graph_merge, graphs)
     er = EntityResolution(llm_bdl)
```
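The merge phase still maps thread completion onto the fixed 0.5–0.6 progress band; only the human-readable message is new. For example, with four extraction threads the reported fractions are:

```python
threads = 4
for i in range(threads):
    prog = 0.5 + 0.1 * i / threads
    print(prog, f"Entities extraction progress ... {i+1}/{threads}")
# 0.5, 0.525, 0.55, 0.575 -- stays below the 0.6 "community reports" step
```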
```diff
@@ -125,7 +125,7 @@ def build_knowlege_graph_chunks(tenant_id: str, chunks: List[str], callback, ent
     callback(0.6, "Extracting community reports.")
     cr = CommunityReportsExtractor(llm_bdl)
-    cr = cr(graph)
+    cr = cr(graph, callback=callback)
     for community, desc in zip(cr.structured_output, cr.output):
         chunk = {
             "title_tks": rag_tokenizer.tokenize(community["title"]),
```
```diff
@@ -138,7 +138,7 @@ class Dealer:
         es_logger.info("TOTAL: {}".format(self.es.getTotal(res)))
         if self.es.getTotal(res) == 0 and "knn" in s:
             bqry, _ = self.qryr.question(qst, min_match="10%")
-            bqry = self._add_filters(bqry)
+            bqry = self._add_filters(bqry, req)
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.17
```