### What problem does this PR solve?

Refactor graphrag to remove redis lock.

### Type of change

- [x] Refactoring
 from api.versions import get_ragflow_version
 from api.utils import show_configs
 from rag.settings import print_rag_settings
-from rag.utils.redis_conn import RedisDistributedLock

 stop_event = threading.Event()

 def update_progress():
-    redis_lock = RedisDistributedLock("update_progress", timeout=60)
     while not stop_event.is_set():
         try:
-            if not redis_lock.acquire():
-                continue
             DocumentService.update_progress()
             stop_event.wait(6)
         except Exception:
             logging.exception("update_progress exception")
-        finally:
-            redis_lock.release()

 def signal_handler(sig, frame):
     logging.info("Received interrupt signal, shutting down...")
         return dict(maybe_nodes), dict(maybe_edges)

     async def __call__(
-        self, chunks: list[tuple[str, str]],
+        self, doc_id: str, chunks: list[str],
         callback: Callable | None = None
     ):
         start_ts = trio.current_time()
         out_results = []
         async with trio.open_nursery() as nursery:
-            for i, (cid, ck) in enumerate(chunks):
+            for i, ck in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                nursery.start_soon(lambda: self._process_single_content((cid, ck), i, len(chunks), out_results))
+                nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
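One thing to watch in this fan-out: a bare `lambda:` handed to `nursery.start_soon` closes over the loop variables, so every task can end up seeing the last iteration's values. A minimal sketch of the pattern with the values frozen via default arguments (only trio is assumed):

import trio

async def process_one(item: str, idx: int, total: int, out: list):
    await trio.sleep(0)  # stand-in for the real extraction work
    out.append((idx, item.upper()))

async def fan_out(items: list[str]) -> list:
    out: list = []
    async with trio.open_nursery() as nursery:
        for i, item in enumerate(items):
            # Default args bind the *current* item and i; a plain
            # `lambda: process_one(item, i, ...)` would race the loop.
            nursery.start_soon(lambda item=item, i=i: process_one(item, i, len(items), out))
    return sorted(out)  # the nursery only exits once all children are done

print(trio.run(fan_out, ["a", "b", "c"]))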
     ) -> str:
         summary_max_tokens = 512
         use_description = truncate(description, summary_max_tokens)
+        description_list = use_description.split(GRAPH_FIELD_SEP)
+        if len(description_list) <= 12:
+            return use_description
         prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
         context_base = dict(
             entity_name=entity_or_relation_name,
-            description_list=use_description.split(GRAPH_FIELD_SEP),
+            description_list=description_list,
             language=self._language,
         )
         use_prompt = prompt_template.format(**context_base)
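The early return above is the substantive change: merged descriptions are only summarized by the LLM once they accumulate more than a dozen separator-joined fragments. A sketch of that gate, with SEP and the threshold mirroring the diff rather than quoting project constants:

SEP = "<SEP>"       # stand-in for GRAPH_FIELD_SEP
MAX_FRAGMENTS = 12  # threshold used in the diff

def maybe_summarize(description: str, summarize) -> str:
    fragments = description.split(SEP)
    if len(fragments) <= MAX_FRAGMENTS:
        return description       # common cheap path: no LLM call
    return summarize(fragments)  # rare expensive path

# usage: maybe_summarize(desc, lambda frags: llm_summary("; ".join(frags)))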
 #
 import json
 import logging
-from functools import reduce, partial
+from functools import partial

 import networkx as nx
 import trio

 from api import settings
+from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
+from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor
 from graphrag.entity_resolution import EntityResolution
 from graphrag.general.extractor import Extractor
-from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
-from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
-    chunk_id, update_nodes_pagerank_nhop_neighbour
+from graphrag.utils import (
+    graph_merge,
+    set_entity,
+    get_relation,
+    set_relation,
+    get_entity,
+    get_graph,
+    set_graph,
+    chunk_id,
+    update_nodes_pagerank_nhop_neighbour,
+    does_graph_contains,
+    get_graph_doc_ids,
+)
 from rag.nlp import rag_tokenizer, search
-from rag.utils.redis_conn import RedisDistributedLock
-class Dealer:
-    def __init__(self,
-                 extractor: Extractor,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 chunks: list[tuple[str, str]],
-                 language,
-                 entity_types=DEFAULT_ENTITY_TYPES,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.chunks = chunks
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.ext = extractor(self.llm_bdl, language=language,
-                             entity_types=entity_types,
-                             get_entity=partial(get_entity, tenant_id, kb_id),
-                             set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
-                             get_relation=partial(get_relation, tenant_id, kb_id),
-                             set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
-                             )
-        self.graph = nx.Graph()
-        self.callback = callback
-
-    async def __call__(self):
-        docids = list(set([docid for docid, _ in self.chunks]))
-        ents, rels = await self.ext(self.chunks, self.callback)
-        for en in ents:
-            self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])  # , description=en["description"]
-        for rel in rels:
-            self.graph.add_edge(
-                rel["src_id"],
-                rel["tgt_id"],
-                weight=rel["weight"],
-                # description=rel["description"]
+from rag.utils.redis_conn import REDIS_CONN

+def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
+    key = f"graphrag:{tenant_id}:{kb_id}"
+    ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
+    if not ok:
+        raise Exception(f"Failed to set the {key} to {doc_id}")

+def graphrag_task_get(tenant_id, kb_id) -> str | None:
+    key = f"graphrag:{tenant_id}:{kb_id}"
+    doc_id = REDIS_CONN.get(key)
+    return doc_id
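Instead of a distributed lock, the refactor records which doc currently "owns" graphrag work on a kb in a plain Redis key with a 24-hour expiry; the long-running phases re-read the key and cancel themselves when another doc has taken over. A minimal sketch of the same marker protocol against a vanilla redis client (the client setup and function names are assumptions for illustration, not the project's REDIS_CONN wrapper):

import redis

r = redis.Redis()  # assumed local instance, for illustration only

def task_set(tenant_id: str, kb_id: str, doc_id: str) -> None:
    # Last writer wins; the TTL guarantees stale markers eventually vanish.
    r.set(f"graphrag:{tenant_id}:{kb_id}", doc_id, ex=3600 * 24)

def task_get(tenant_id: str, kb_id: str) -> str | None:
    v = r.get(f"graphrag:{tenant_id}:{kb_id}")
    return v.decode() if v is not None else None

def still_mine(tenant_id: str, kb_id: str, doc_id: str) -> bool:
    # Cooperative cancellation: workers check this between expensive phases.
    return task_get(tenant_id, kb_id) == doc_id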
+async def run_graphrag(
+    row: dict,
+    language,
+    with_resolution: bool,
+    with_community: bool,
+    chat_model,
+    embedding_model,
+    callback,
+):
+    start = trio.current_time()
+    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
+    chunks = []
+    for d in settings.retrievaler.chunk_list(
+        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
+    ):
+        chunks.append(d["content_with_weight"])
+
+    graph, doc_ids = await update_graph(
+        LightKGExt
+        if row["parser_config"]["graphrag"]["method"] != "general"
+        else GeneralKGExt,
+        tenant_id,
+        kb_id,
+        doc_id,
+        chunks,
+        language,
+        row["parser_config"]["graphrag"]["entity_types"],
+        chat_model,
+        embedding_model,
+        callback,
+    )
+    if not graph:
+        return
+    if with_resolution or with_community:
+        graphrag_task_set(tenant_id, kb_id, doc_id)
+    if with_resolution:
+        await resolve_entities(
+            graph,
+            doc_ids,
+            tenant_id,
+            kb_id,
+            doc_id,
+            chat_model,
+            embedding_model,
+            callback,
+        )
+    if with_community:
+        await extract_community(
+            graph,
+            doc_ids,
+            tenant_id,
+            kb_id,
+            doc_id,
+            chat_model,
+            embedding_model,
+            callback,
+        )
+    now = trio.current_time()
+    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
+    return
+async def update_graph(
+    extractor: Extractor,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    chunks: list[str],
+    language,
+    entity_types,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
+    if contains:
+        callback(msg=f"Graph already contains {doc_id}, cancel myself")
+        return None, None
+    start = trio.current_time()
+    ext = extractor(
+        llm_bdl,
+        language=language,
+        entity_types=entity_types,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    ents, rels = await ext(doc_id, chunks, callback)
+    subgraph = nx.Graph()
+    for en in ents:
+        subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
+    for rel in rels:
+        subgraph.add_edge(
+            rel["src_id"],
+            rel["tgt_id"],
+            weight=rel["weight"],
+            # description=rel["description"]
+        )
+    # TODO: infinity doesn't support array search
+    chunk = {
+        "content_with_weight": json.dumps(
+            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
+        ),
+        "knowledge_graph_kwd": "subgraph",
+        "kb_id": kb_id,
+        "source_id": [doc_id],
+        "available_int": 0,
+        "removed_kwd": "N",
+    }
+    cid = chunk_id(chunk)
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.insert(
+            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
+        )
+    )
+    now = trio.current_time()
+    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
+    start = now
+
+    while True:
+        new_graph = subgraph
+        now_docids = set([doc_id])
+        old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
+        if old_graph is not None:
+            logging.info("Merge with an existing graph...................")
+            new_graph = graph_merge(old_graph, subgraph)
+        await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
+        if old_doc_ids:
+            for old_doc_id in old_doc_ids:
+                now_docids.add(old_doc_id)
+        old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
+        delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
+        if delta_doc_ids:
+            callback(
+                msg="The global graph has changed during merging, try again"
+            )
+            await trio.sleep(1)
+            continue
+        break
+    await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
+    now = trio.current_time()
+    callback(
+        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
+    )
+    return new_graph, now_docids
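This merge loop is the heart of the lock removal: instead of holding a kb-wide Redis lock, the worker reads the global graph, merges locally, then re-reads the committed doc-id set and retries when someone else got there first. A toy sketch of the read-merge-validate-retry shape, with a dict standing in for the persisted graph row:

import time

store = {"graph": {}, "doc_ids": set()}  # toy stand-in for the doc store

def commit_merge(doc_id: str, subgraph: dict) -> dict:
    while True:
        old_graph, old_ids = dict(store["graph"]), set(store["doc_ids"])
        merged = {**old_graph, **subgraph}   # merge without holding any lock
        if set(store["doc_ids"]) - old_ids:  # did another doc commit meanwhile?
            time.sleep(1)                    # back off, then redo the merge
            continue
        store["graph"] = merged              # publish the merged graph
        store["doc_ids"] = old_ids | {doc_id}
        return merged

As in the diff, the validate-then-write step is not itself atomic; the retry narrows the race window rather than closing it.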
+async def resolve_entities(
+    graph,
+    doc_ids,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    start = trio.current_time()
+    er = EntityResolution(
+        llm_bdl,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    reso = await er(graph)
+    graph = reso.graph
+    callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
+    await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
+    callback(msg="Graph resolution updated pagerank.")
+
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    await set_graph(tenant_id, kb_id, graph, doc_ids)
-        with RedisDistributedLock(self.kb_id, 60*60):
-            old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
-            if old_graph is not None:
-                logging.info("Merge with an existing graph...................")
-                self.graph = reduce(graph_merge, [old_graph, self.graph])
-            update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
-            if old_doc_ids:
-                docids.extend(old_doc_ids)
-                docids = list(set(docids))
-            set_graph(self.tenant_id, self.kb_id, self.graph, docids)
-class WithResolution(Dealer):
-    def __init__(self,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.callback = callback
-
-    async def __call__(self):
-        with RedisDistributedLock(self.kb_id, 60*60):
-            self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
-            if not self.graph:
-                logging.error(f"Failed to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
-                if self.callback:
-                    self.callback(-1, msg="Failed to fetch the graph.")
-                return
-            if self.callback:
-                self.callback(msg="Fetch the existing graph.")
-            er = EntityResolution(self.llm_bdl,
-                                  get_entity=partial(get_entity, self.tenant_id, self.kb_id),
-                                  set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
-                                  get_relation=partial(get_relation, self.tenant_id, self.kb_id),
-                                  set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
-            reso = await er(self.graph)
-            self.graph = reso.graph
-            logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            if self.callback:
-                self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
-            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "relation",
-                "kb_id": self.kb_id,
-                "from_entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "relation",
-                "kb_id": self.kb_id,
-                "to_entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "entity",
-                "kb_id": self.kb_id,
-                "entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-class WithCommunity(Dealer):
-    def __init__(self,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.community_structure = None
-        self.community_reports = None
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.callback = callback
-
-    async def __call__(self):
-        with RedisDistributedLock(self.kb_id, 60*60):
-            self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
-            if not self.graph:
-                logging.error(f"Failed to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
-                if self.callback:
-                    self.callback(-1, msg="Failed to fetch the graph.")
-                return
-            if self.callback:
-                self.callback(msg="Fetch the existing graph.")
-            cr = CommunityReportsExtractor(self.llm_bdl,
-                                           get_entity=partial(get_entity, self.tenant_id, self.kb_id),
-                                           set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
-                                           get_relation=partial(get_relation, self.tenant_id, self.kb_id),
-                                           set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
-            cr = await cr(self.graph, callback=self.callback)
-            self.community_structure = cr.structured_output
-            self.community_reports = cr.output
-            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
-            if self.callback:
-                self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "relation",
+                "kb_id": kb_id,
+                "from_entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "relation",
+                "kb_id": kb_id,
+                "to_entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "entity",
+                "kb_id": kb_id,
+                "entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    now = trio.current_time()
+    callback(msg=f"Graph resolution done in {now - start:.2f}s.")
+async def extract_community(
+    graph,
+    doc_ids,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    start = trio.current_time()
+    ext = CommunityReportsExtractor(
+        llm_bdl,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    cr = await ext(graph, callback=callback)
+    community_structure = cr.structured_output
+    community_reports = cr.output
+
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    await set_graph(tenant_id, kb_id, graph, doc_ids)
+    now = trio.current_time()
+    callback(
+        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
+    )
+    start = now
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    for stru, rep in zip(community_structure, community_reports):
+        obj = {
+            "report": rep,
+            "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
+        }
+        chunk = {
+            "docnm_kwd": stru["title"],
+            "title_tks": rag_tokenizer.tokenize(stru["title"]),
+            "content_with_weight": json.dumps(obj, ensure_ascii=False),
+            "content_ltks": rag_tokenizer.tokenize(
+                obj["report"] + " " + obj["evidences"]
+            ),
| "knowledge_graph_kwd": "community_report", | "knowledge_graph_kwd": "community_report", | ||||
| "kb_id": self.kb_id | |||||
| }, search.index_name(self.tenant_id), self.kb_id)) | |||||
| for stru, rep in zip(self.community_structure, self.community_reports): | |||||
| obj = { | |||||
| "report": rep, | |||||
| "evidences": "\n".join([f["explanation"] for f in stru["findings"]]) | |||||
| } | |||||
| chunk = { | |||||
| "docnm_kwd": stru["title"], | |||||
| "title_tks": rag_tokenizer.tokenize(stru["title"]), | |||||
| "content_with_weight": json.dumps(obj, ensure_ascii=False), | |||||
| "content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]), | |||||
| "knowledge_graph_kwd": "community_report", | |||||
| "weight_flt": stru["weight"], | |||||
| "entities_kwd": stru["entities"], | |||||
| "important_kwd": stru["entities"], | |||||
| "kb_id": self.kb_id, | |||||
| "source_id": doc_ids, | |||||
| "available_int": 0 | |||||
| } | |||||
| chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) | |||||
| #try: | |||||
| # ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])]) | |||||
| # chunk["q_%d_vec" % len(ebd[0])] = ebd[0] | |||||
| #except Exception as e: | |||||
| # logging.exception(f"Fail to embed entity relation: {e}") | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id))) | |||||
| "weight_flt": stru["weight"], | |||||
| "entities_kwd": stru["entities"], | |||||
| "important_kwd": stru["entities"], | |||||
| "kb_id": kb_id, | |||||
| "source_id": doc_ids, | |||||
| "available_int": 0, | |||||
| } | |||||
| chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize( | |||||
| chunk["content_ltks"] | |||||
| ) | |||||
| # try: | |||||
| # ebd, _ = embed_bdl.encode([", ".join(community["entities"])]) | |||||
| # chunk["q_%d_vec" % len(ebd[0])] = ebd[0] | |||||
| # except Exception as e: | |||||
| # logging.exception(f"Fail to embed entity relation: {e}") | |||||
| await trio.to_thread.run_sync( | |||||
| lambda: settings.docStoreConn.insert( | |||||
| [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id) | |||||
| ) | |||||
| ) | |||||
| now = trio.current_time() | |||||
| callback( | |||||
| msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s." | |||||
| ) | |||||
| return community_structure, community_reports | 
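Each community report becomes one searchable doc-store chunk carrying both coarse (content_ltks) and fine-grained (content_sm_ltks) token fields. A toy builder showing the same shape, with tokenize as a stand-in for rag_tokenizer (field names follow the diff, values are illustrative):

import json

def tokenize(text: str) -> str:
    return " ".join(text.lower().split())  # stand-in for rag_tokenizer.tokenize

def build_report_chunk(title: str, report: str, evidences: str, kb_id: str) -> dict:
    obj = {"report": report, "evidences": evidences}
    return {
        "docnm_kwd": title,
        "title_tks": tokenize(title),
        "content_with_weight": json.dumps(obj, ensure_ascii=False),
        "content_ltks": tokenize(report + " " + evidences),
        "knowledge_graph_kwd": "community_report",
        "kb_id": kb_id,
        "available_int": 0,  # keeps the report out of ordinary chunk listings
    }

print(build_report_chunk("c0", "summary", "evidence a\nevidence b", "kb1")["title_tks"])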
 import argparse
 import json
+import logging

 import networkx as nx
 import trio

 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
-from graphrag.general.index import WithCommunity, Dealer, WithResolution
-from graphrag.light.graph_extractor import GraphExtractor
-from rag.utils.redis_conn import RedisDistributedLock
+from graphrag.general.graph_extractor import GraphExtractor
+from graphrag.general.index import update_graph, with_resolution, with_community

 settings.init_settings()

-if __name__ == "__main__":
+def callback(prog=None, msg="Processing..."):
+    logging.info(msg)
+
+async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument(
+        "-t",
+        "--tenant_id",
+        default=False,
+        help="Tenant ID",
+        action="store",
+        required=True,
+    )
+    parser.add_argument(
+        "-d",
+        "--doc_id",
+        default=False,
+        help="Document ID",
+        action="store",
+        required=True,
+    )
     args = parser.parse_args()
     e, doc = DocumentService.get_by_id(args.doc_id)
     if not e:
         raise LookupError("Document not found.")
     kb_id = doc.kb_id

-    chunks = [d["content_with_weight"] for d in
-              settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
-                                              fields=["content_with_weight"])]
-    chunks = [("x", c) for c in chunks]
-    RedisDistributedLock.clean_lock(kb_id)
+    chunks = [
+        d["content_with_weight"]
+        for d in settings.retrievaler.chunk_list(
+            args.doc_id,
+            args.tenant_id,
+            [kb_id],
+            max_count=6,
+            fields=["content_with_weight"],
+        )
+    ]

     _, tenant = TenantService.get_by_id(args.tenant_id)
     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

-    dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
-    trio.run(dealer())
-    print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
+    graph, doc_ids = await update_graph(
+        GraphExtractor,
+        args.tenant_id,
+        kb_id,
+        args.doc_id,
+        chunks,
+        "English",
+        llm_bdl,
+        embed_bdl,
+        callback,
+    )
+    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
+    await with_resolution(
+        args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
+    )
+    community_structure, community_reports = await with_community(
+        args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
+    )

-    dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
-    trio.run(dealer())
-    dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
-    trio.run(dealer())
+    print(
+        "------------------ COMMUNITY STRUCTURE--------------------\n",
+        json.dumps(community_structure, ensure_ascii=False, indent=2),
+    )
+    print(
+        "------------------ COMMUNITY REPORTS----------------------\n",
+        community_reports,
+    )
-    print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
-    print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))

+if __name__ == "__main__":
+    trio.run(main)
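Usage stays a one-shot CLI, e.g. `python -m graphrag.general.smoke -t <tenant_id> -d <doc_id>` (module path inferred from the imports); the difference is that the whole pipeline now runs inside a single `trio.run(main)` instead of three separate `trio.run(dealer())` calls.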
 import json

 from api import settings
 import networkx as nx
+import logging
+import trio

 from api.db import LLMType
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
-from graphrag.general.index import Dealer
+from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
-from rag.utils.redis_conn import RedisDistributedLock

 settings.init_settings()

-if __name__ == "__main__":
+def callback(prog=None, msg="Processing..."):
+    logging.info(msg)
+
+async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument(
+        "-t",
+        "--tenant_id",
+        default=False,
+        help="Tenant ID",
+        action="store",
+        required=True,
+    )
+    parser.add_argument(
+        "-d",
+        "--doc_id",
+        default=False,
+        help="Document ID",
+        action="store",
+        required=True,
+    )
     args = parser.parse_args()
     e, doc = DocumentService.get_by_id(args.doc_id)
     if not e:
         raise LookupError("Document not found.")
     kb_id = doc.kb_id

-    chunks = [d["content_with_weight"] for d in
-              settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
-                                              fields=["content_with_weight"])]
-    chunks = [("x", c) for c in chunks]
-    RedisDistributedLock.clean_lock(kb_id)
+    chunks = [
+        d["content_with_weight"]
+        for d in settings.retrievaler.chunk_list(
+            args.doc_id,
+            args.tenant_id,
+            [kb_id],
+            max_count=6,
+            fields=["content_with_weight"],
+        )
+    ]

     _, tenant = TenantService.get_by_id(args.tenant_id)
     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

-    dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
+    graph, doc_ids = await update_graph(
+        GraphExtractor,
+        args.tenant_id,
+        kb_id,
+        args.doc_id,
+        chunks,
+        "English",
+        llm_bdl,
+        embed_bdl,
+        callback,
+    )
+    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
-    print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))

+if __name__ == "__main__":
+    trio.run(main)
| chunk["q_%d_vec" % len(ebd)] = ebd | chunk["q_%d_vec" % len(ebd)] = ebd | ||||
| settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) | settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) | ||||
| async def does_graph_contains(tenant_id, kb_id, doc_id): | |||||
| # Get doc_ids of graph | |||||
| fields = ["source_id"] | |||||
| condition = { | |||||
| "knowledge_graph_kwd": ["graph"], | |||||
| "removed_kwd": "N", | |||||
| } | |||||
| res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id])) | |||||
| fields2 = settings.docStoreConn.getFields(res, fields) | |||||
| graph_doc_ids = set() | |||||
| for chunk_id in fields2.keys(): | |||||
| graph_doc_ids = set(fields2[chunk_id]["source_id"]) | |||||
| return doc_id in graph_doc_ids | |||||
| async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]: | |||||
| conds = { | |||||
| "fields": ["source_id"], | |||||
| "removed_kwd": "N", | |||||
| "size": 1, | |||||
| "knowledge_graph_kwd": ["graph"] | |||||
| } | |||||
| res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])) | |||||
| doc_ids = [] | |||||
| if res.total == 0: | |||||
| return doc_ids | |||||
| for id in res.ids: | |||||
| doc_ids = res.field[id]["source_id"] | |||||
| return doc_ids | |||||
-def get_graph(tenant_id, kb_id):
+async def get_graph(tenant_id, kb_id):
     conds = {
         "fields": ["content_with_weight", "source_id"],
         "removed_kwd": "N",
         "size": 1,
         "knowledge_graph_kwd": ["graph"]
     }
-    res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
+    if res.total == 0:
+        return None, []
     for id in res.ids:
         try:
             return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
                    res.field[id]["source_id"]
         except Exception:
             continue
-    return rebuild_graph(tenant_id, kb_id)
+    result = await rebuild_graph(tenant_id, kb_id)
+    return result
-def set_graph(tenant_id, kb_id, graph, docids):
+async def set_graph(tenant_id, kb_id, graph, docids):
     chunk = {
         "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                           indent=2),
         "source_id": list(docids),
         "available_int": 0,
         "removed_kwd": "N"
     }
-    res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
     if res.ids:
-        settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
-                                     search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
+                                                                           search.index_name(tenant_id), kb_id))
     else:
-        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))

 def is_continuous_subsequence(subseq, seq):

     return result
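set_graph persists the whole graph as one well-known doc-store row via search-then-update-or-insert, now pushed to a worker thread because the client blocks. The upsert decision itself is tiny; a dict-backed toy version of it:

store: dict[str, dict] = {}  # row id -> row; toy stand-in for the doc store

def upsert_graph_row(kb_id: str, payload: dict) -> None:
    key = f"graph:{kb_id}"   # exactly one graph row per knowledge base
    if key in store:         # mirrors the `if res.ids:` branch in the diff
        store[key].update(payload)
    else:
        store[key] = dict(payload)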
-def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
+async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
     def n_neighbor(id):
         nonlocal graph, n_hop
         count = 0

     for n, p in pr.items():
         graph.nodes[n]["pagerank"] = p
         try:
-            settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
-                                         {"rank_flt": p,
-                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
-                                         search.index_name(tenant_id), kb_id)
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+                                                                               {"rank_flt": p,
+                                                                                "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
+                                                                               search.index_name(tenant_id), kb_id))
         except Exception as e:
             logging.exception(e)
| "knowledge_graph_kwd": "ty2ents", | "knowledge_graph_kwd": "ty2ents", | ||||
| "available_int": 0 | "available_int": 0 | ||||
| } | } | ||||
| res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []}, | |||||
| search.index_name(tenant_id), [kb_id]) | |||||
| res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []}, | |||||
| search.index_name(tenant_id), [kb_id])) | |||||
| if res.ids: | if res.ids: | ||||
| settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"}, | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"}, | |||||
| chunk, | chunk, | ||||
| search.index_name(tenant_id), kb_id) | |||||
| search.index_name(tenant_id), kb_id)) | |||||
| else: | else: | ||||
| settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) | |||||
| await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)) | |||||
-def get_entity_type2sampels(idxnms, kb_ids: list):
-    es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
-                                          "size": 10000,
-                                          "fields": ["content_with_weight"]},
-                                         idxnms, kb_ids)
+async def get_entity_type2sampels(idxnms, kb_ids: list):
+    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
+                                                                                "size": 10000,
+                                                                                "fields": ["content_with_weight"]},
+                                                                               idxnms, kb_ids))
     res = defaultdict(list)
     for id in es_res.ids:

     return list(set(res))
-def rebuild_graph(tenant_id, kb_id):
+async def rebuild_graph(tenant_id, kb_id):
     graph = nx.Graph()
     src_ids = []
     flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
     bs = 256
     for i in range(0, 39*bs, bs):
-        es_res = settings.docStoreConn.search(flds, [],
-                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
-                                              [],
-                                              OrderByExpr(),
-                                              i, bs, search.index_name(tenant_id), [kb_id]
-                                              )
+        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
+                                                                                    {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
+                                                                                    [],
+                                                                                    OrderByExpr(),
+                                                                                    i, bs, search.index_name(tenant_id), [kb_id]
+                                                                                    ))
         tot = settings.docStoreConn.getTotal(es_res)
         if tot == 0:
             return None, None
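rebuild_graph pages through entity and relation rows in batches of 256, capped at 39 pages, and reassembles a networkx graph from them. A minimal sketch of the same paged-scan-into-graph shape, with fetch_page as an assumed stand-in for the doc-store search:

import networkx as nx

def fetch_page(offset: int, size: int) -> list[dict]:
    # Stand-in for settings.docStoreConn.search(...); returns [] when exhausted.
    return []

def rebuild(bs: int = 256, max_pages: int = 39) -> nx.Graph:
    g = nx.Graph()
    for i in range(0, max_pages * bs, bs):
        rows = fetch_page(i, bs)
        if not rows:
            break
        for row in rows:
            if row["knowledge_graph_kwd"] == "entity":
                g.add_node(row["entity_kwd"], entity_type=row.get("entity_type_kwd"))
            else:
                g.add_edge(row["from_entity_kwd"], row["to_entity_kwd"],
                           weight=row.get("weight_int", 0))
    return g

The hard-coded 39-page cap mirrors the diff; past roughly 10k rows the rebuild silently truncates, which is worth knowing when debugging missing edges.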
 #
 import logging
 import re
-from threading import Lock

 import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
 import trio

-from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
+from graphrag.utils import (
+    get_llm_cache,
+    get_embed_cache,
+    set_embed_cache,
+    set_llm_cache,
+    chat_limiter,
+)
 from rag.utils import truncate

 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
-    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
+    def __init__(
+        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+    ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
         self._embd_model = embd_model
         self._prompt = prompt
         self._max_token = max_token

-    def _chat(self, system, history, gen_conf):
+    async def _chat(self, system, history, gen_conf):
         response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
         if response:
             return response
-        response = self._llm_model.chat(system, history, gen_conf)
+        response = await trio.to_thread.run_sync(
+            lambda: self._llm_model.chat(system, history, gen_conf)
+        )
         response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             raise Exception(response)
         set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
         return response

-    def _embedding_encode(self, txt):
+    async def _embedding_encode(self, txt):
         response = get_embed_cache(self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = self._embd_model.encode([txt])
+        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
             return []
         chunks = [(s, a) for s, a in chunks if s and len(a) > 0]

-        async def summarize(ck_idx, lock):
+        async def summarize(ck_idx: list[int]):
             nonlocal chunks
-            try:
-                texts = [chunks[i][0] for i in ck_idx]
-                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
-                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                async with chat_limiter:
-                    cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
-                                                                           [{"role": "user",
-                                                                             "content": self._prompt.format(cluster_content=cluster_content)}],
-                                                                           {"temperature": 0.3, "max_tokens": self._max_token}
-                                                                           ))
-                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
-                             cnt)
-                logging.debug(f"SUM: {cnt}")
-                embds, _ = self._embd_model.encode([cnt])
-                with lock:
-                    chunks.append((cnt, self._embedding_encode(cnt)))
-            except Exception as e:
-                logging.exception("summarize got exception")
-                return e
+            texts = [chunks[i][0] for i in ck_idx]
+            len_per_chunk = int(
+                (self._llm_model.max_length - self._max_token) / len(texts)
+            )
+            cluster_content = "\n".join(
+                [truncate(t, max(1, len_per_chunk)) for t in texts]
+            )
+            async with chat_limiter:
+                cnt = await self._chat(
+                    "You're a helpful assistant.",
+                    [
+                        {
+                            "role": "user",
+                            "content": self._prompt.format(
+                                cluster_content=cluster_content
+                            ),
+                        }
+                    ],
+                    {"temperature": 0.3, "max_tokens": self._max_token},
+                )
+            cnt = re.sub(
+                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                "",
+                cnt,
+            )
+            logging.debug(f"SUM: {cnt}")
+            embds = await self._embedding_encode(cnt)
+            chunks.append((cnt, embds))
         labels = []
-        lock = Lock()

         while end - start > 1:
-            embeddings = [embd for _, embd in chunks[start: end]]
+            embeddings = [embd for _, embd in chunks[start:end]]
             if len(embeddings) == 2:
-                await summarize([start, start + 1], lock)
+                await summarize([start, start + 1])
                 if callback:
-                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                    callback(
+                        msg="Cluster one layer: {} -> {}".format(
+                            end - start, len(chunks) - end
+                        )
+                    )
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end

             n_neighbors = int((len(embeddings) - 1) ** 0.8)
             reduced_embeddings = umap.UMAP(
-                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
+                n_neighbors=max(2, n_neighbors),
+                n_components=min(12, len(embeddings) - 2),
+                metric="cosine",
             ).fit_transform(embeddings)
             n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
             if n_clusters == 1:

             async with trio.open_nursery() as nursery:
                 for c in range(n_clusters):
                     ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
+                    if not ck_idx:
+                        continue
-                    assert len(ck_idx) > 0
                     async with chat_limiter:
-                        nursery.start_soon(lambda: summarize(ck_idx, lock))
+                        nursery.start_soon(lambda: summarize(ck_idx))

-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
+                len(chunks) - end, n_clusters
+            )
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                callback(
+                    msg="Cluster one layer: {} -> {}".format(
+                        end - start, len(chunks) - end
+                    )
+                )
             start = end
             end = len(chunks)

         return chunks
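Two quieter changes here deserve a note: the Lock is gone because trio tasks are cooperatively scheduled on a single thread, so appending to chunks between await points needs no mutex, and summarize no longer swallows exceptions, so a failing child now cancels its siblings and surfaces from the nursery as an exception group. A sketch of that failure path (built-in groups need Python 3.11+; older interpreters use the exceptiongroup backport the diff imports):

import trio

async def worker(i: int):
    if i == 2:
        raise ValueError(f"task {i} failed")
    await trio.sleep(0.01)

async def main():
    try:
        async with trio.open_nursery() as nursery:
            for i in range(4):
                nursery.start_soon(worker, i)
    except BaseExceptionGroup as eg:
        # Same unwrap loop the task executor uses to log a readable root cause.
        e: BaseException = eg
        while isinstance(e, BaseExceptionGroup):
            e = e.exceptions[0]
        print("root cause:", e)

trio.run(main)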
 import sys

 from api.utils.log_utils import initRootLogger, get_project_base_directory
-from graphrag.general.index import WithCommunity, WithResolution, Dealer
-from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
-from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
+from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.prompts import keyword_extraction, question_proposal, content_tagging

 import resource
 import signal
 import trio
+import exceptiongroup
 import numpy as np
 from peewee import DoesNotExist
     return res, tk_count

-async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
-    chunks = []
-    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
-                                             fields=["content_with_weight", "doc_id"]):
-        chunks.append((d["doc_id"], d["content_with_weight"]))
-
-    dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
-                    row["tenant_id"],
-                    str(row["kb_id"]),
-                    chat_model,
-                    chunks=chunks,
-                    language=language,
-                    entity_types=row["parser_config"]["graphrag"]["entity_types"],
-                    embed_bdl=embedding_model,
-                    callback=callback)
-    await dealer()
 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]

         return

         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-        progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("resolution", False):
-            start_ts = timer()
-            with_res = WithResolution(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_res()
-            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("community", False):
-            start_ts = timer()
-            with_comm = WithCommunity(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_comm()
-            progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
+        with_resolution = graphrag_conf.get("resolution", False)
+        with_community = graphrag_conf.get("community", False)
+        await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
+        progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
         FAILED_TASKS += 1
         CURRENT_TASKS.pop(task["id"], None)
         try:
-            set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
+            err_msg = str(e)
+            while isinstance(e, exceptiongroup.ExceptionGroup):
+                e = e.exceptions[0]
+                err_msg += ' -- ' + str(e)
+            set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
         except Exception:
             pass
         logging.exception(f"handle_task got exception for task {json.dumps(task)}")
 import logging
 import json
-import time
 import uuid

 import valkey as redis
 from rag import settings
 from rag.utils import singleton
+from valkey.lock import Lock

 class RedisMsg:
     def __init__(self, consumer, queue_name, group_name, msg_id, message):

 class RedisDistributedLock:
-    def __init__(self, lock_key, timeout=10):
+    def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
         self.lock_key = lock_key
-        self.lock_value = str(uuid.uuid4())
+        if lock_value:
+            self.lock_value = lock_value
+        else:
+            self.lock_value = str(uuid.uuid4())
         self.timeout = timeout
+        self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)

-    @staticmethod
-    def clean_lock(lock_key):
-        REDIS_CONN.REDIS.delete(lock_key)
-
-    def acquire_lock(self):
-        end_time = time.time() + self.timeout
-        while time.time() < end_time:
-            if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
-                return True
-            time.sleep(1)
-        return False
+    def acquire(self):
+        return self.lock.acquire()

-    def release_lock(self):
-        if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
-            REDIS_CONN.REDIS.delete(self.lock_key)
+    def release(self):
+        return self.lock.release()

     def __enter__(self):
-        self.acquire_lock()
+        self.acquire()

     def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.release_lock()
+        self.release()
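The class now delegates to valkey's built-in Lock, which handles token-based ownership and atomic release server-side; blocking_timeout=1 means acquire() gives up after about a second instead of spinning on setnx. A usage sketch against a local valkey instance (connection setup is an assumption for illustration):

import valkey
from valkey.lock import Lock

client = valkey.Valkey()  # assumed local instance

# timeout: when the lock auto-expires; blocking_timeout: max wait to acquire.
lock = Lock(client, "graphrag:kb42", timeout=60 * 60, blocking_timeout=1)

if lock.acquire():
    try:
        pass  # critical section
    finally:
        lock.release()  # only the owning token can release
else:
    print("someone else holds the lock; skip this round")

Note that `__enter__` as written discards acquire()'s return value, so the with-statement form proceeds even if acquisition timed out; call acquire() explicitly when failure must short-circuit.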