- #
 - #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 - #
 - #  Licensed under the Apache License, Version 2.0 (the "License");
 - #  you may not use this file except in compliance with the License.
 - #  You may obtain a copy of the License at
 - #
 - #      http://www.apache.org/licenses/LICENSE-2.0
 - #
 - #  Unless required by applicable law or agreed to in writing, software
 - #  distributed under the License is distributed on an "AS IS" BASIS,
 - #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 - #  See the License for the specific language governing permissions and
 - #  limitations under the License.
 - #
 - 
 - import argparse
 - import json
 - import logging
 - import networkx as nx
 - import trio
 - 
 - from api import settings
 - from api.db import LLMType
 - from api.db.services.document_service import DocumentService
 - from api.db.services.knowledgebase_service import KnowledgebaseService
 - from api.db.services.llm_service import LLMBundle
 - from api.db.services.user_service import TenantService
 - from graphrag.general.graph_extractor import GraphExtractor
 - from graphrag.general.index import update_graph, with_resolution, with_community
 - 
 - settings.init_settings()
 - 
 - 
 - def callback(prog=None, msg="Processing..."):
 -     logging.info(msg)
 - 
 - 
 - async def main():
 -     parser = argparse.ArgumentParser()
 -     parser.add_argument(
 -         "-t",
 -         "--tenant_id",
 -         default=False,
 -         help="Tenant ID",
 -         action="store",
 -         required=True,
 -     )
 -     parser.add_argument(
 -         "-d",
 -         "--doc_id",
 -         default=False,
 -         help="Document ID",
 -         action="store",
 -         required=True,
 -     )
 -     args = parser.parse_args()
 -     e, doc = DocumentService.get_by_id(args.doc_id)
 -     if not e:
 -         raise LookupError("Document not found.")
 -     kb_id = doc.kb_id
 - 
 -     chunks = [
 -         d["content_with_weight"]
 -         for d in settings.retrievaler.chunk_list(
 -             args.doc_id,
 -             args.tenant_id,
 -             [kb_id],
 -             max_count=6,
 -             fields=["content_with_weight"],
 -         )
 -     ]
 - 
 -     _, tenant = TenantService.get_by_id(args.tenant_id)
 -     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
 -     _, kb = KnowledgebaseService.get_by_id(kb_id)
 -     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
 - 
 -     graph, doc_ids = await update_graph(
 -         GraphExtractor,
 -         args.tenant_id,
 -         kb_id,
 -         args.doc_id,
 -         chunks,
 -         "English",
 -         llm_bdl,
 -         embed_bdl,
 -         callback,
 -     )
 -     print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
 - 
 -     await with_resolution(
 -         args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
 -     )
 -     community_structure, community_reports = await with_community(
 -         args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
 -     )
 - 
 -     print(
 -         "------------------ COMMUNITY STRUCTURE--------------------\n",
 -         json.dumps(community_structure, ensure_ascii=False, indent=2),
 -     )
 -     print(
 -         "------------------ COMMUNITY REPORTS----------------------\n",
 -         community_reports,
 -     )
 - 
 - 
 - if __name__ == "__main__":
 -     trio.run(main)
 
 
  |