#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
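"""GraphRAG indexing pipeline.

Builds a per-document knowledge subgraph from chunk text, merges it into the
knowledge base's global graph, and optionally performs entity resolution and
community-report extraction.
"""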
import json
import logging
import networkx as nx
import trio

from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.utils import (
    graph_merge,
    get_graph,
    set_graph,
    chunk_id,
    does_graph_contains,
    tidy_graph,
    GraphChange,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock


async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
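    """Run the full GraphRAG pipeline for one document.

    Extracts a subgraph from the document's chunks, merges it into the
    knowledge base's global graph under a distributed lock, and optionally
    runs entity resolution and community-report extraction.
    """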
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retrievaler.chunk_list(
        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
    ):
        chunks.append(d["content_with_weight"])

    graphrag_conf = row["kb_parser_config"].get("graphrag", {})
    extractor = GeneralKGExt if graphrag_conf.get("method") == "general" else LightKGExt
    with trio.fail_after(max(120, len(chunks) * 120)):
        subgraph = await generate_subgraph(
            extractor,
            tenant_id,
            kb_id,
            doc_id,
            chunks,
            language,
            graphrag_conf.get("entity_types", []),
            chat_model,
            embedding_model,
            callback,
        )

    if not subgraph:
        return
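
    # The KB-wide Redis lock serializes graph mutations: only one GraphRAG
    # task per knowledge base may update the global graph at a time.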
    graphrag_task_lock = RedisDistributedLock(
        f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200
    )
    await graphrag_task_lock.spin_acquire()
    callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")

    try:
        subgraph_nodes = set(subgraph.nodes())
        new_graph = await merge_subgraph(
            tenant_id,
            kb_id,
            doc_id,
            subgraph,
            embedding_model,
            callback,
        )
        assert new_graph is not None

        if not with_resolution and not with_community:
            return

        if with_resolution:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await resolve_entities(
                new_graph,
                subgraph_nodes,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
        if with_community:
            await graphrag_task_lock.spin_acquire()
            callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
            await extract_community(
                new_graph,
                tenant_id,
                kb_id,
                doc_id,
                chat_model,
                embedding_model,
                callback,
            )
    finally:
        graphrag_task_lock.release()
    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return


async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
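    """Extract entities and relations from `chunks` and build a NetworkX
    subgraph for `doc_id`, persisting it to the doc store as a "subgraph"
    chunk. Returns None if the global graph already contains the document.
    """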
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}")
        return None
    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
    )
    ents, rels = await ext(doc_id, chunks, callback)
    subgraph = nx.Graph()
    for ent in ents:
        assert "description" in ent, f"entity {ent} does not have description"
        ent["source_id"] = [doc_id]
        subgraph.add_node(ent["entity_name"], **ent)

    ignored_rels = 0
    for rel in rels:
        assert "description" in rel, f"relation {rel} does not have description"
        if not subgraph.has_node(rel["src_id"]) or not subgraph.has_node(rel["tgt_id"]):
            ignored_rels += 1
            continue
        rel["source_id"] = [doc_id]
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            **rel,
        )
    if ignored_rels:
        callback(msg=f"ignored {ignored_rels} relations due to missing entities.")
    tidy_graph(subgraph, callback, check_attribute=False)

    subgraph.graph["source_id"] = [doc_id]
    chunk = {
        "content_with_weight": json.dumps(
            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False
        ),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "subgraph", "source_id": doc_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.insert(
            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
        )
    )
    now = trio.current_time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph


@timeout(60 * 3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    subgraph: nx.Graph,
    embedding_model,
    callback,
):
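    """Merge a freshly extracted per-document subgraph into the KB-level
    graph (or adopt it as-is when no graph exists yet), recompute PageRank,
    and persist the result.
    """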
    start = trio.current_time()
    change = GraphChange()
    old_graph = await get_graph(tenant_id, kb_id, subgraph.graph["source_id"])
    if old_graph is not None:
        logging.info("Merging with an existing graph...")
        tidy_graph(old_graph, callback)
        new_graph = graph_merge(old_graph, subgraph, change)
    else:
        new_graph = subgraph
        change.added_updated_nodes = set(new_graph.nodes())
        change.added_updated_edges = set(new_graph.edges())
    pr = nx.pagerank(new_graph)
    for node_name, pagerank in pr.items():
        new_graph.nodes[node_name]["pagerank"] = pagerank

    await set_graph(tenant_id, kb_id, embedding_model, new_graph, change, callback)
    now = trio.current_time()
    callback(
        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
    )
    return new_graph


@timeout(60 * 30, 1)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
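    """Resolve duplicate entities among `subgraph_nodes`, then persist the
    updated graph.
    """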
    start = trio.current_time()
    er = EntityResolution(
        llm_bdl,
    )
    reso = await er(graph, subgraph_nodes, callback=callback)
    graph = reso.graph
    change = reso.change
    callback(
        msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges."
    )
    callback(msg="Graph resolution updated pagerank.")

    await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
    now = trio.current_time()
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60 * 30, 1)
async def extract_community(
    graph,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
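    """Detect communities in the graph, generate community reports, and index
    them as "community_report" chunks, replacing the KB's previous reports.
    """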
    start = trio.current_time()
    ext = CommunityReportsExtractor(
        llm_bdl,
    )
    cr = await ext(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    doc_ids = graph.graph["source_id"]

    now = trio.current_time()
    callback(
        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    start = now
    chunks = []
    for stru, rep in zip(community_structure, community_reports):
        obj = {
            "report": rep,
            "evidences": "\n".join([f.get("explanation", "") for f in stru["findings"]]),
        }
        chunk = {
            "id": get_uuid(),
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(
                obj["report"] + " " + obj["evidences"]
            ),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": list(doc_ids),
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
            chunk["content_ltks"]
        )
        chunks.append(chunk)

    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
    es_bulk_size = 4
    for b in range(0, len(chunks), es_bulk_size):
        doc_store_result = await trio.to_thread.run_sync(
            lambda: settings.docStoreConn.insert(
                chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id
            )
        )
        if doc_store_result:
            error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
            raise Exception(error_message)

    now = trio.current_time()
    callback(
        msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    return community_structure, community_reports