#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
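"""GraphRAG indexing pipeline: build and maintain a per-KB knowledge graph.

Extracts a subgraph from each document's chunks, merges it into the knowledge
base's global graph, and optionally runs entity resolution and community
report extraction on the merged result.
"""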
import json
import logging
from functools import partial
import networkx as nx
import trio

from api import settings
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.utils import (
    graph_merge,
    set_entity,
    get_relation,
    set_relation,
    get_entity,
    get_graph,
    set_graph,
    chunk_id,
    update_nodes_pagerank_nhop_neighbour,
    does_graph_contains,
    get_graph_doc_ids,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN


def graphrag_task_set(tenant_id, kb_id, doc_id) -> None:
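    """Register doc_id as the KB's active GraphRAG task in Redis.

    The key effectively serves as a cooperative lock: resolution and
    community extraction re-read it and cancel themselves if another
    document has taken over. The 24-hour expiry keeps a crashed task
    from holding the slot indefinitely.
    """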
- key = f"graphrag:{tenant_id}:{kb_id}"
- ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
- if not ok:
- raise Exception(f"Faild to set the {key} to {doc_id}")
-
-
- def graphrag_task_get(tenant_id, kb_id) -> str | None:
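    """Return the doc_id of the KB's active GraphRAG task, or None."""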
- key = f"graphrag:{tenant_id}:{kb_id}"
- doc_id = REDIS_CONN.get(key)
- return doc_id
-
-
- async def run_graphrag(
- row: dict,
- language,
- with_resolution: bool,
- with_community: bool,
- chat_model,
- embedding_model,
- callback,
- ):
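    """Run the GraphRAG pipeline for a single document.

    Loads the document's chunks, extracts a subgraph (LightKGExt unless the
    configured method is "general"), merges it into the KB-level graph, and
    optionally performs entity resolution and community report extraction.
    Progress is reported through `callback`.
    """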
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retrievaler.chunk_list(
        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
    ):
        chunks.append(d["content_with_weight"])

    graph, doc_ids = await update_graph(
        LightKGExt
        if row["parser_config"]["graphrag"]["method"] != "general"
        else GeneralKGExt,
        tenant_id,
        kb_id,
        doc_id,
        chunks,
        language,
        row["parser_config"]["graphrag"]["entity_types"],
        chat_model,
        embedding_model,
        callback,
    )
    if not graph:
        return
    if with_resolution or with_community:
        graphrag_task_set(tenant_id, kb_id, doc_id)
    if with_resolution:
        await resolve_entities(
            graph,
            doc_ids,
            tenant_id,
            kb_id,
            doc_id,
            chat_model,
            embedding_model,
            callback,
        )
    if with_community:
        await extract_community(
            graph,
            doc_ids,
            tenant_id,
            kb_id,
            doc_id,
            chat_model,
            embedding_model,
            callback,
        )
    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return


async def update_graph(
    extractor: type[Extractor],
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
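    """Extract a subgraph from `chunks` and merge it into the KB graph.

    Returns (None, None) if the graph already contains `doc_id`; otherwise
    persists the subgraph as its own doc-store chunk, merges it into the
    stored global graph, and returns (new_graph, doc_ids).
    """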
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}; cancelling this task.")
        return None, None
    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    ents, rels = await ext(doc_id, chunks, callback)
    subgraph = nx.Graph()
    for en in ents:
        subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])

    for rel in rels:
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            weight=rel["weight"],
            # description=rel["description"]
        )
    # TODO: infinity doesn't support array search
    chunk = {
        "content_with_weight": json.dumps(
            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
        ),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.insert(
            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
        )
    )
    now = trio.current_time()
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    start = now

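    # Optimistic merge loop: if the set of doc ids behind the stored graph
    # changes while we merge (i.e. another task wrote first), re-read the
    # global graph and try again.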
    while True:
        new_graph = subgraph
        now_docids = set([doc_id])
        old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
        if old_graph is not None:
            logging.info("Merging with an existing graph...")
            new_graph = graph_merge(old_graph, subgraph)
        await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
        if old_doc_ids:
            for old_doc_id in old_doc_ids:
                now_docids.add(old_doc_id)
            old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
            delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
            if delta_doc_ids:
                callback(
                    msg="The global graph has changed during merging; retrying."
                )
                await trio.sleep(1)
                continue
        break
    await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
    now = trio.current_time()
    callback(
        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
    )
    return new_graph, now_docids


async def resolve_entities(
    graph,
    doc_ids,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
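    """Merge duplicate entities in the graph and purge the removed ones.

    Runs only while `doc_id` still holds the KB's GraphRAG task slot.
    Refreshes pagerank after resolution, then deletes the doc-store chunks
    of removed entities and of relations pointing at them.
    """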
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task (doc_id {working_doc_id}) is working on this kb; cancelling this task."
        )
        return
    start = trio.current_time()
    er = EntityResolution(
        llm_bdl,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    reso = await er(graph)
    graph = reso.graph
    callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
    await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
    callback(msg="Graph resolution updated pagerank.")

    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task (doc_id {working_doc_id}) is working on this kb; cancelling this task."
        )
        return
    await set_graph(tenant_id, kb_id, graph, doc_ids)

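    # Drop doc-store chunks that referenced the merged-away entities:
    # relations from them, relations to them, then the entities themselves.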
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "relation",
                "kb_id": kb_id,
                "from_entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "relation",
                "kb_id": kb_id,
                "to_entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "entity",
                "kb_id": kb_id,
                "entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    now = trio.current_time()
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


async def extract_community(
    graph,
    doc_ids,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
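    """Detect graph communities and index an LLM-written report for each.

    Runs only while `doc_id` still holds the KB's GraphRAG task slot.
    Replaces any previously indexed community reports for the KB and
    returns (community_structure, community_reports).
    """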
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task (doc_id {working_doc_id}) is working on this kb; cancelling this task."
        )
        return
    start = trio.current_time()
    ext = CommunityReportsExtractor(
        llm_bdl,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    cr = await ext(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another graphrag task (doc_id {working_doc_id}) is working on this kb; cancelling this task."
        )
        return
    await set_graph(tenant_id, kb_id, graph, doc_ids)

    now = trio.current_time()
    callback(
        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    start = now
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
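    # Index each community report as its own searchable chunk; "evidences"
    # concatenates the explanations behind the community's findings.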
    for stru, rep in zip(community_structure, community_reports):
        obj = {
            "report": rep,
            "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
        }
        chunk = {
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(
                obj["report"] + " " + obj["evidences"]
            ),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": doc_ids,
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
            chunk["content_ltks"]
        )
        # try:
        #     ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
        #     chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
        # except Exception as e:
        #     logging.exception(f"Fail to embed entity relation: {e}")
        await trio.to_thread.run_sync(
            lambda: settings.docStoreConn.insert(
                [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id
            )
        )

    now = trio.current_time()
    callback(
        msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
    )
    return community_structure, community_reports