#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
from functools import partial

import networkx as nx
import trio

from api import settings
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor
from graphrag.utils import (
    graph_merge,
    set_entity,
    get_relation,
    set_relation,
    get_entity,
    get_graph,
    set_graph,
    chunk_id,
    update_nodes_pagerank_nhop_neighbour,
    does_graph_contains,
    get_graph_doc_ids,
)
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN


def graphrag_task_set(tenant_id, kb_id, doc_id) -> None:
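    """Register doc_id in Redis as the active GraphRAG task for this KB (24-hour TTL)."""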
    key = f"graphrag:{tenant_id}:{kb_id}"
    ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
    if not ok:
        raise Exception(f"Failed to set {key} to {doc_id}")


def graphrag_task_get(tenant_id, kb_id) -> str | None:
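    """Return the doc_id of the GraphRAG task currently registered for this KB, if any."""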
    key = f"graphrag:{tenant_id}:{kb_id}"
    doc_id = REDIS_CONN.get(key)
    return doc_id


async def run_graphrag(
    row: dict,
    language,
    with_resolution: bool,
    with_community: bool,
    chat_model,
    embedding_model,
    callback,
):
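    """Build a knowledge graph from the document's chunks, then optionally run
    entity resolution and community-report extraction on the merged graph.
    """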
    start = trio.current_time()
    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
    chunks = []
    for d in settings.retrievaler.chunk_list(
        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
    ):
        chunks.append(d["content_with_weight"])

    kg_extractor = (
        GeneralKGExt
        if row["parser_config"]["graphrag"]["method"] == "general"
        else LightKGExt
    )
    graph, doc_ids = await update_graph(
        kg_extractor,
        tenant_id,
        kb_id,
        doc_id,
        chunks,
        language,
        row["parser_config"]["graphrag"]["entity_types"],
        chat_model,
        embedding_model,
        callback,
    )
    if not graph:
        return
    if with_resolution or with_community:
        graphrag_task_set(tenant_id, kb_id, doc_id)
    if with_resolution:
        await resolve_entities(
            graph,
            doc_ids,
            tenant_id,
            kb_id,
            doc_id,
            chat_model,
            embedding_model,
            callback,
        )
    if with_community:
        await extract_community(
            graph,
            doc_ids,
            tenant_id,
            kb_id,
            doc_id,
            chat_model,
            embedding_model,
            callback,
        )
    now = trio.current_time()
    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
    return


async def update_graph(
    extractor: Extractor,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    chunks: list[str],
    language,
    entity_types,
    llm_bdl,
    embed_bdl,
    callback,
):
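    """Extract entities and relations from the chunks into a subgraph, persist
    it, and merge it into the knowledge base's global graph. Returns the merged
    graph and the covered doc ids, or (None, None) if the document is already
    in the graph.
    """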
    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
    if contains:
        callback(msg=f"Graph already contains {doc_id}; cancelling this task.")
        return None, None
    start = trio.current_time()
    ext = extractor(
        llm_bdl,
        language=language,
        entity_types=entity_types,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    ents, rels = await ext(doc_id, chunks, callback)
    subgraph = nx.Graph()
    for en in ents:
        subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])

    for rel in rels:
        subgraph.add_edge(
            rel["src_id"],
            rel["tgt_id"],
            weight=rel["weight"],
            # description=rel["description"]
        )
    # TODO: infinity doesn't support array search
    chunk = {
        "content_with_weight": json.dumps(
            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
        ),
        "knowledge_graph_kwd": "subgraph",
        "kb_id": kb_id,
        "source_id": [doc_id],
        "available_int": 0,
        "removed_kwd": "N",
    }
    cid = chunk_id(chunk)
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.insert(
            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
        )
    )
    now = trio.current_time()
    callback(msg=f"Generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    start = now

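    # Optimistic concurrency: merge the subgraph into the global graph, then
    # check whether another task changed the global graph's doc-id set in the
    # meantime; if so, re-read the global graph and merge again.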
    while True:
        new_graph = subgraph
        now_docids = {doc_id}
        old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
        if old_graph is not None:
            logging.info("Merging with an existing graph...")
            new_graph = graph_merge(old_graph, subgraph)
        await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
        if old_doc_ids:
            now_docids.update(old_doc_ids)
        old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
        delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids or [])
        if delta_doc_ids:
            callback(msg="The global graph changed during merging; retrying.")
            await trio.sleep(1)
            continue
        break
    await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
    now = trio.current_time()
    callback(
        msg=f"Merged subgraph for doc {doc_id} into the global graph in {now - start:.2f} seconds."
    )
    return new_graph, now_docids


async def resolve_entities(
    graph,
    doc_ids,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
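    """Deduplicate graph entities and purge the removed entities, and the
    relations that reference them, from the document store.
    """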
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another GraphRAG task (doc_id {working_doc_id}) is working on this KB; cancelling this task."
        )
        return
    start = trio.current_time()
    er = EntityResolution(
        llm_bdl,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    reso = await er(graph)
    graph = reso.graph
    callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
    await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
    callback(msg="Graph resolution updated pagerank.")

    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another GraphRAG task (doc_id {working_doc_id}) is working on this KB; cancelling this task."
        )
        return
    await set_graph(tenant_id, kb_id, graph, doc_ids)

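    # Purge relations that point from or to the removed entities, then the
    # entity records themselves, from the document store.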
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "relation",
                "kb_id": kb_id,
                "from_entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "relation",
                "kb_id": kb_id,
                "to_entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {
                "knowledge_graph_kwd": "entity",
                "kb_id": kb_id,
                "entity_kwd": reso.removed_entities,
            },
            search.index_name(tenant_id),
            kb_id,
        )
    )
    now = trio.current_time()
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


async def extract_community(
    graph,
    doc_ids,
    tenant_id: str,
    kb_id: str,
    doc_id: str,
    llm_bdl,
    embed_bdl,
    callback,
):
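    """Detect communities in the graph and index an LLM-generated report for
    each community as a searchable chunk. Returns the community structures and
    reports.
    """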
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another GraphRAG task (doc_id {working_doc_id}) is working on this KB; cancelling this task."
        )
        return
    start = trio.current_time()
    ext = CommunityReportsExtractor(
        llm_bdl,
        get_entity=partial(get_entity, tenant_id, kb_id),
        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
        get_relation=partial(get_relation, tenant_id, kb_id),
        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
    )
    cr = await ext(graph, callback=callback)
    community_structure = cr.structured_output
    community_reports = cr.output
    working_doc_id = graphrag_task_get(tenant_id, kb_id)
    if doc_id != working_doc_id:
        callback(
            msg=f"Another GraphRAG task (doc_id {working_doc_id}) is working on this KB; cancelling this task."
        )
        return
    await set_graph(tenant_id, kb_id, graph, doc_ids)

    now = trio.current_time()
    callback(
        msg=f"Graph extracted {len(community_structure)} communities in {now - start:.2f}s."
    )
    start = now
    await trio.to_thread.run_sync(
        lambda: settings.docStoreConn.delete(
            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
            search.index_name(tenant_id),
            kb_id,
        )
    )
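    # Index each community report as a chunk; the report text plus the
    # findings' explanations form the searchable content.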
    for stru, rep in zip(community_structure, community_reports):
        obj = {
            "report": rep,
            "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
        }
        chunk = {
            "docnm_kwd": stru["title"],
            "title_tks": rag_tokenizer.tokenize(stru["title"]),
            "content_with_weight": json.dumps(obj, ensure_ascii=False),
            "content_ltks": rag_tokenizer.tokenize(
                obj["report"] + " " + obj["evidences"]
            ),
            "knowledge_graph_kwd": "community_report",
            "weight_flt": stru["weight"],
            "entities_kwd": stru["entities"],
            "important_kwd": stru["entities"],
            "kb_id": kb_id,
            "source_id": doc_ids,
            "available_int": 0,
        }
        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
            chunk["content_ltks"]
        )
        # try:
        #    ebd, _ = embed_bdl.encode([", ".join(stru["entities"])])
        #    chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
        # except Exception as e:
        #    logging.exception(f"Fail to embed entity relation: {e}")
        await trio.to_thread.run_sync(
            lambda: settings.docStoreConn.insert(
                [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id
            )
        )

    now = trio.current_time()
    callback(
        msg=f"Graph indexed {len(community_structure)} communities in {now - start:.2f}s."
    )
    return community_structure, community_reports
 
 