### What problem does this PR solve?

Optimized graphrag again.

### Type of change

- [x] Performance Improvement
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import logging
 import itertools
 import re
 import time
@@ -67,7 +68,7 @@ class EntityResolution(Extractor):
         self._resolution_result_delimiter_key = "resolution_result_delimiter"
         self._input_text_key = "input_text"

-    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
+    async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
         """Call method definition."""
         if prompt_variables is None:
             prompt_variables = {}
@@ -93,6 +94,8 @@ class EntityResolution(Extractor):
         candidate_resolution = {entity_type: [] for entity_type in entity_types}
         for k, v in node_clusters.items():
             candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
+        num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
+        callback(msg=f"Identified {num_candidates} candidate pairs")

         resolution_result = set()
         async with trio.open_nursery() as nursery:
@@ -100,48 +103,52 @@ class EntityResolution(Extractor):
                 if not candidate_resolution_i[1]:
                     continue
                 nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
+        callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")

         connect_graph = nx.Graph()
         removed_entities = []
         connect_graph.add_edges_from(resolution_result)
         all_entities_data = []
         all_relationships_data = []
+        all_remove_nodes = []

-        for sub_connect_graph in nx.connected_components(connect_graph):
-            sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
-            remove_nodes = list(sub_connect_graph.nodes)
-            keep_node = remove_nodes.pop()
-            await self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data)
-            for remove_node in remove_nodes:
-                removed_entities.append(remove_node)
-                remove_node_neighbors = graph[remove_node]
-                remove_node_neighbors = list(remove_node_neighbors)
-                for remove_node_neighbor in remove_node_neighbors:
-                    rel = self._get_relation_(remove_node, remove_node_neighbor)
-                    if graph.has_edge(remove_node, remove_node_neighbor):
-                        graph.remove_edge(remove_node, remove_node_neighbor)
-                    if remove_node_neighbor == keep_node:
-                        if graph.has_edge(keep_node, remove_node):
-                            graph.remove_edge(keep_node, remove_node)
-                        continue
-                    if not rel:
-                        continue
-                    if graph.has_edge(keep_node, remove_node_neighbor):
-                        await self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data)
-                    else:
-                        pair = sorted([keep_node, remove_node_neighbor])
-                        graph.add_edge(pair[0], pair[1], weight=rel['weight'])
-                        self._set_relation_(pair[0], pair[1],
-                                            dict(
-                                                src_id=pair[0],
-                                                tgt_id=pair[1],
-                                                weight=rel['weight'],
-                                                description=rel['description'],
-                                                keywords=[],
-                                                source_id=rel.get("source_id", ""),
-                                                metadata={"created_at": time.time()}
-                                            ))
-                graph.remove_node(remove_node)
+        async with trio.open_nursery() as nursery:
+            for sub_connect_graph in nx.connected_components(connect_graph):
+                sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
+                remove_nodes = list(sub_connect_graph.nodes)
+                keep_node = remove_nodes.pop()
+                all_remove_nodes.append(remove_nodes)
+                nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
+                for remove_node in remove_nodes:
+                    removed_entities.append(remove_node)
+                    remove_node_neighbors = graph[remove_node]
+                    remove_node_neighbors = list(remove_node_neighbors)
+                    for remove_node_neighbor in remove_node_neighbors:
+                        rel = self._get_relation_(remove_node, remove_node_neighbor)
+                        if graph.has_edge(remove_node, remove_node_neighbor):
+                            graph.remove_edge(remove_node, remove_node_neighbor)
+                        if remove_node_neighbor == keep_node:
+                            if graph.has_edge(keep_node, remove_node):
+                                graph.remove_edge(keep_node, remove_node)
+                            continue
+                        if not rel:
+                            continue
+                        if graph.has_edge(keep_node, remove_node_neighbor):
+                            nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
+                        else:
+                            pair = sorted([keep_node, remove_node_neighbor])
+                            graph.add_edge(pair[0], pair[1], weight=rel['weight'])
+                            self._set_relation_(pair[0], pair[1],
+                                                dict(
+                                                    src_id=pair[0],
+                                                    tgt_id=pair[1],
+                                                    weight=rel['weight'],
+                                                    description=rel['description'],
+                                                    keywords=[],
+                                                    source_id=rel.get("source_id", ""),
+                                                    metadata={"created_at": time.time()}
+                                                ))
+                    graph.remove_node(remove_node)

         return EntityResolutionResult(
             graph=graph,
```
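The merge phase now fans each connected component out to a `trio` nursery instead of awaiting `_merge_nodes` sequentially. One caveat: `nursery.start_soon(lambda: ...)` inside a loop closes over the loop variables, so every spawned task can end up seeing the final iteration's `keep_node` and `remove_nodes`. Below is a minimal, self-contained sketch of the same fan-out using trio's own argument-passing convention, which binds arguments per task; `merge` is a hypothetical stand-in for `_merge_nodes`.

```python
import trio

async def merge(keep_node: str, remove_nodes: list[str]) -> None:
    # Hypothetical stand-in for self._merge_nodes(): simulate some async work.
    await trio.sleep(0.01)
    print(f"merged {remove_nodes} into {keep_node}")

async def main() -> None:
    components = [("A", ["A1", "A2"]), ("B", ["B1"])]
    async with trio.open_nursery() as nursery:
        for keep_node, remove_nodes in components:
            # Pass the coroutine function and its arguments separately;
            # a `lambda: merge(keep_node, remove_nodes)` would capture the
            # loop variables by reference and race with the next iteration.
            nursery.start_soon(merge, keep_node, remove_nodes)
    # The nursery block exits only after every spawned task has completed.

trio.run(main)
```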
```diff
@@ -164,8 +171,10 @@ class EntityResolution(Extractor):
             self._input_text_key: pair_prompt
         }
         text = perform_variable_replacements(self._resolution_prompt, variables=variables)
+        logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
         async with chat_limiter:
             response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+        logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
         result = self._process_results(len(candidate_resolution_i[1]), response,
                                        self.prompt_variables.get(self._record_delimiter_key,
                                                                  DEFAULT_RECORD_DELIMITER),
```
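`_resolve_candidate` keeps the LLM call off the trio event loop: `chat_limiter` bounds how many chats run at once, and `trio.to_thread.run_sync` pushes the blocking client into a worker thread. A self-contained sketch of that pattern, assuming a `trio.CapacityLimiter` in the role of `chat_limiter` (the stand-in `blocking_chat` replaces `self._chat`):

```python
import time
import trio

chat_limiter = trio.CapacityLimiter(10)  # assumption: at most 10 chats in flight

def blocking_chat(prompt: str) -> str:
    # Stand-in for the synchronous LLM client call (self._chat).
    time.sleep(0.1)
    return f"response to: {prompt}"

async def resolve_one(prompt: str) -> None:
    async with chat_limiter:  # wait for a free slot before chatting
        # Run the blocking call in a worker thread so other tasks keep running.
        response = await trio.to_thread.run_sync(blocking_chat, prompt)
        print(response)

async def main() -> None:
    async with trio.open_nursery() as nursery:
        for i in range(25):
            nursery.start_soon(resolve_one, f"prompt {i}")

trio.run(main)
```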
```diff
@@ -19,7 +19,6 @@ from graphrag.general.leiden import add_community_info2graph
 from rag.llm.chat_model import Base as CompletionLLM
 from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
 from rag.utils import num_tokens_from_string
-from timeit import default_timer as timer
 import trio
@@ -62,62 +61,69 @@ class CommunityReportsExtractor(Extractor):
         res_str = []
         res_dict = []
         over, token_count = 0, 0
-        st = timer()
-        for level, comm in communities.items():
-            logging.info(f"Level {level}: Community: {len(comm.keys())}")
-            for cm_id, ents in comm.items():
-                weight = ents["weight"]
-                ents = ents["nodes"]
-                ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()  # [{"entity": n, **graph.nodes[n]} for n in ents]
-                if ent_df.empty or "entity_name" not in ent_df.columns:
-                    continue
-                ent_df["entity"] = ent_df["entity_name"]
-                del ent_df["entity_name"]
-                rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
-                if rela_df.empty:
-                    continue
-                rela_df["source"] = rela_df["src_id"]
-                rela_df["target"] = rela_df["tgt_id"]
-                del rela_df["src_id"]
-                del rela_df["tgt_id"]
-                prompt_variables = {
-                    "entity_df": ent_df.to_csv(index_label="id"),
-                    "relation_df": rela_df.to_csv(index_label="id")
-                }
-                text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
-                gen_conf = {"temperature": 0.3}
-                async with chat_limiter:
-                    response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
-                token_count += num_tokens_from_string(text + response)
-                response = re.sub(r"^[^\{]*", "", response)
-                response = re.sub(r"[^\}]*$", "", response)
-                response = re.sub(r"\{\{", "{", response)
-                response = re.sub(r"\}\}", "}", response)
-                logging.debug(response)
-                try:
-                    response = json.loads(response)
-                except json.JSONDecodeError as e:
-                    logging.error(f"Failed to parse JSON response: {e}")
-                    logging.error(f"Response content: {response}")
-                    continue
-                if not dict_has_keys_with_types(response, [
-                    ("title", str),
-                    ("summary", str),
-                    ("findings", list),
-                    ("rating", float),
-                    ("rating_explanation", str),
-                ]):
-                    continue
-                response["weight"] = weight
-                response["entities"] = ents
-                add_community_info2graph(graph, ents, response["title"])
-                res_str.append(self._get_text_output(response))
-                res_dict.append(response)
-                over += 1
-                if callback:
-                    callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
+        async def extract_community_report(community):
+            nonlocal res_str, res_dict, over, token_count
+            cm_id, ents = community
+            weight = ents["weight"]
+            ents = ents["nodes"]
+            ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
+            if ent_df.empty or "entity_name" not in ent_df.columns:
+                return
+            ent_df["entity"] = ent_df["entity_name"]
+            del ent_df["entity_name"]
+            rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
+            if rela_df.empty:
+                return
+            rela_df["source"] = rela_df["src_id"]
+            rela_df["target"] = rela_df["tgt_id"]
+            del rela_df["src_id"]
+            del rela_df["tgt_id"]
+            prompt_variables = {
+                "entity_df": ent_df.to_csv(index_label="id"),
+                "relation_df": rela_df.to_csv(index_label="id")
+            }
+            text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
+            gen_conf = {"temperature": 0.3}
+            async with chat_limiter:
+                response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
+            token_count += num_tokens_from_string(text + response)
+            response = re.sub(r"^[^\{]*", "", response)
+            response = re.sub(r"[^\}]*$", "", response)
+            response = re.sub(r"\{\{", "{", response)
+            response = re.sub(r"\}\}", "}", response)
+            logging.debug(response)
+            try:
+                response = json.loads(response)
+            except json.JSONDecodeError as e:
+                logging.error(f"Failed to parse JSON response: {e}")
+                logging.error(f"Response content: {response}")
+                return
+            if not dict_has_keys_with_types(response, [
+                ("title", str),
+                ("summary", str),
+                ("findings", list),
+                ("rating", float),
+                ("rating_explanation", str),
+            ]):
+                return
+            response["weight"] = weight
+            response["entities"] = ents
+            add_community_info2graph(graph, ents, response["title"])
+            res_str.append(self._get_text_output(response))
+            res_dict.append(response)
+            over += 1
+            if callback:
+                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

+        st = trio.current_time()
+        async with trio.open_nursery() as nursery:
+            for level, comm in communities.items():
+                logging.info(f"Level {level}: Community: {len(comm.keys())}")
+                for community in comm.items():
+                    nursery.start_soon(lambda: extract_community_report(community))
+        if callback:
+            callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")

         return CommunityReportsResult(
             structured_output=res_dict,
```
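Both the sequential and the concurrent versions share the same response clean-up before `json.loads`: strip everything before the first `{` and after the last `}`, then collapse doubled braces left over from the prompt template. Because trio schedules tasks cooperatively on a single thread, the `nonlocal over, token_count` counters in `extract_community_report` can also be updated without locks; only the LLM call itself runs in a worker thread. Isolated as a helper, the recovery step looks like this (a sketch, not the repo's actual function):

```python
import json
import re

def extract_json_object(response: str):
    """Best-effort recovery of one JSON object from a noisy LLM reply."""
    response = re.sub(r"^[^\{]*", "", response)  # drop preamble before the first '{'
    response = re.sub(r"[^\}]*$", "", response)  # drop trailer after the last '}'
    response = re.sub(r"\{\{", "{", response)    # unescape doubled template braces
    response = re.sub(r"\}\}", "}", response)
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        return None  # caller skips this community, as the extractor does

print(extract_json_object('Here is the report: {"title": "T", "rating": 7.5} Hope it helps.'))
# {'title': 'T', 'rating': 7.5}
```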
```diff
@@ -228,7 +228,7 @@ async def resolve_entities(
         get_relation=partial(get_relation, tenant_id, kb_id),
         set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
     )
-    reso = await er(graph)
+    reso = await er(graph, callback=callback)
     graph = reso.graph
     callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
     await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
```
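`resolve_entities` now forwards its progress `callback` into `EntityResolution.__call__`. The diff relies on a keyword `msg=` convention; note that `__call__` types the parameter as `Callable | None` yet invokes it unconditionally, so callers should always pass one (or the call sites need an `if callback:` guard). A hypothetical sketch of a conforming callback; RAGFlow's real one also persists task progress, so this only illustrates the calling convention:

```python
import logging
from typing import Callable

def make_progress_callback(task_id: str) -> Callable:
    # Hypothetical factory: logs progress messages tagged with a task id.
    def callback(prog: float | None = None, msg: str = "") -> None:
        logging.info("[%s] progress=%s %s", task_id, prog, msg)
    return callback

callback = make_progress_callback("kb-42")
callback(msg="Identified 128 candidate pairs")
```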
```diff
@@ -489,15 +489,16 @@ async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
         return nbrs

     pr = nx.pagerank(graph)
-    for n, p in pr.items():
-        graph.nodes[n]["pagerank"] = p
-        try:
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
-                                                                               {"rank_flt": p,
-                                                                                "n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
-                                                                               search.index_name(tenant_id), kb_id))
-        except Exception as e:
-            logging.exception(e)
+    try:
+        async with trio.open_nursery() as nursery:
+            for n, p in pr.items():
+                graph.nodes[n]["pagerank"] = p
+                nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+                                                                                                        {"rank_flt": p,
+                                                                                                         "n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
+                                                                                                        search.index_name(tenant_id), kb_id)))
+    except Exception as e:
+        logging.exception(e)

     ty2ents = defaultdict(list)
     for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
```
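The pagerank write-back gets the same treatment: one task per node, each pushing the blocking `docStoreConn.update` round-trip into a worker thread. As in the merge fan-out above, the bare lambdas here capture `n` and `p` late; since a wrapper around the blocking call is unavoidable, default-argument binding is one way to pin the values per task. A sketch under that assumption, where `update_doc_store` is a hypothetical stand-in for the doc-store call:

```python
import trio

def update_doc_store(entity: str, rank: float) -> None:
    # Stand-in for the blocking settings.docStoreConn.update(...) round-trip.
    print(f"update {entity}: rank_flt={rank}")

async def main() -> None:
    pagerank = {"apple": 0.40, "banana": 0.35, "cherry": 0.25}
    async with trio.open_nursery() as nursery:
        for n, p in pagerank.items():
            # `lambda n=n, p=p:` evaluates n and p now, not when the task runs;
            # a bare `lambda:` would make every task see the last (n, p) pair.
            nursery.start_soon(trio.to_thread.run_sync, lambda n=n, p=p: update_doc_store(n, p))

trio.run(main)
```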