### What problem does this PR solve?

Refactor graphrag to remove redis lock.

### Type of change

- [x] Refactoring
@@ -42,16 +42,22 @@ from api.db.init_data import init_web_data
 from api.versions import get_ragflow_version
 from api.utils import show_configs
 from rag.settings import print_rag_settings
+from rag.utils.redis_conn import RedisDistributedLock

 stop_event = threading.Event()

 def update_progress():
+    redis_lock = RedisDistributedLock("update_progress", timeout=60)
     while not stop_event.is_set():
         try:
+            if not redis_lock.acquire():
+                continue
             DocumentService.update_progress()
             stop_event.wait(6)
         except Exception:
             logging.exception("update_progress exception")
+        finally:
+            redis_lock.release()

 def signal_handler(sig, frame):
     logging.info("Received interrupt signal, shutting down...")
@@ -93,7 +93,7 @@ class Extractor:
         return dict(maybe_nodes), dict(maybe_edges)

     async def __call__(
-        self, chunks: list[tuple[str, str]],
+        self, doc_id: str, chunks: list[str],
         callback: Callable | None = None
     ):
@@ -101,9 +101,9 @@ class Extractor:
         start_ts = trio.current_time()
         out_results = []
         async with trio.open_nursery() as nursery:
-            for i, (cid, ck) in enumerate(chunks):
+            for i, ck in enumerate(chunks):
                 ck = truncate(ck, int(self._llm.max_length*0.8))
-                nursery.start_soon(lambda: self._process_single_content((cid, ck), i, len(chunks), out_results))
+                nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))

         maybe_nodes = defaultdict(list)
         maybe_edges = defaultdict(list)
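A side note on `nursery.start_soon(lambda: ...)`, present in both the old and new line: trio only runs the spawned tasks at the next checkpoint, and a `lambda` closes over `ck` and `i` by reference, so all tasks can end up seeing the loop's final values. `functools.partial` (or default arguments) freezes the values of the current iteration at scheduling time; a minimal illustration of the difference, with a hypothetical `process` coroutine:

```python
import trio
from functools import partial

async def process(item: str, idx: int) -> None:
    await trio.sleep(0)
    print(idx, item)

async def main() -> None:
    items = ["a", "b", "c"]
    async with trio.open_nursery() as nursery:
        for i, it in enumerate(items):
            # start_soon(lambda: process(it, i)) would make every task see
            # the final i/it; partial binds the values of this iteration.
            nursery.start_soon(partial(process, it, i))

trio.run(main)
```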
@@ -241,10 +241,13 @@ class Extractor:
     ) -> str:
         summary_max_tokens = 512
         use_description = truncate(description, summary_max_tokens)
+        description_list = use_description.split(GRAPH_FIELD_SEP)
+        if len(description_list) <= 12:
+            return use_description
         prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
         context_base = dict(
             entity_name=entity_or_relation_name,
-            description_list=use_description.split(GRAPH_FIELD_SEP),
+            description_list=description_list,
             language=self._language,
         )
         use_prompt = prompt_template.format(**context_base)
@@ -15,196 +15,353 @@
 #
 import json
 import logging
-from functools import reduce, partial
+from functools import partial

 import networkx as nx
 import trio

 from api import settings
 from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
 from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
 from graphrag.general.community_reports_extractor import CommunityReportsExtractor
 from graphrag.entity_resolution import EntityResolution
 from graphrag.general.extractor import Extractor
 from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
-from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
-    chunk_id, update_nodes_pagerank_nhop_neighbour
+from graphrag.utils import (
+    graph_merge,
+    set_entity,
+    get_relation,
+    set_relation,
+    get_entity,
+    get_graph,
+    set_graph,
+    chunk_id,
+    update_nodes_pagerank_nhop_neighbour,
+    does_graph_contains,
+    get_graph_doc_ids,
+)
 from rag.nlp import rag_tokenizer, search
-from rag.utils.redis_conn import RedisDistributedLock
+from rag.utils.redis_conn import REDIS_CONN
-class Dealer:
-    def __init__(self,
-                 extractor: Extractor,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 chunks: list[tuple[str, str]],
-                 language,
-                 entity_types=DEFAULT_ENTITY_TYPES,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.chunks = chunks
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.ext = extractor(self.llm_bdl, language=language,
-                             entity_types=entity_types,
-                             get_entity=partial(get_entity, tenant_id, kb_id),
-                             set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
-                             get_relation=partial(get_relation, tenant_id, kb_id),
-                             set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
-                             )
-        self.graph = nx.Graph()
-        self.callback = callback
-
-    async def __call__(self):
-        docids = list(set([docid for docid, _ in self.chunks]))
-        ents, rels = await self.ext(self.chunks, self.callback)
-        for en in ents:
-            self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])  # , description=en["description"]
-        for rel in rels:
-            self.graph.add_edge(
-                rel["src_id"],
-                rel["tgt_id"],
-                weight=rel["weight"],
-                # description=rel["description"]
-            )
-        with RedisDistributedLock(self.kb_id, 60*60):
-            old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
-            if old_graph is not None:
-                logging.info("Merge with an existing graph...................")
-                self.graph = reduce(graph_merge, [old_graph, self.graph])
-            update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
-            if old_doc_ids:
-                docids.extend(old_doc_ids)
-                docids = list(set(docids))
-            set_graph(self.tenant_id, self.kb_id, self.graph, docids)
+def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
+    key = f"graphrag:{tenant_id}:{kb_id}"
+    ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
+    if not ok:
+        raise Exception(f"Failed to set the {key} to {doc_id}")
+
+def graphrag_task_get(tenant_id, kb_id) -> str | None:
+    key = f"graphrag:{tenant_id}:{kb_id}"
+    doc_id = REDIS_CONN.get(key)
+    return doc_id
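In place of a kb-wide mutex, the refactor records which `doc_id` currently owns graph post-processing for a KB in a plain Redis key with a 24-hour TTL; later stages re-read the key and cancel themselves if another document has taken over (last writer wins). A self-contained sketch of the protocol, using a generic `redis` client where the PR uses `REDIS_CONN` (that substitution is an assumption):

```python
import redis

r = redis.Redis()

def task_set(tenant_id: str, kb_id: str, doc_id: str) -> None:
    # Last writer wins: a newer task for the same KB overwrites the key.
    key = f"graphrag:{tenant_id}:{kb_id}"
    if not r.set(key, doc_id, ex=3600 * 24):
        raise RuntimeError(f"Failed to set {key} to {doc_id}")

def task_owner(tenant_id: str, kb_id: str) -> str | None:
    value = r.get(f"graphrag:{tenant_id}:{kb_id}")
    return value.decode() if value else None

def still_owner(tenant_id: str, kb_id: str, doc_id: str) -> bool:
    # Long-running stages call this before committing their results.
    return task_owner(tenant_id, kb_id) == doc_id
```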
+async def run_graphrag(
+    row: dict,
+    language,
+    with_resolution: bool,
+    with_community: bool,
+    chat_model,
+    embedding_model,
+    callback,
+):
+    start = trio.current_time()
+    tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
+    chunks = []
+    for d in settings.retrievaler.chunk_list(
+        doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
+    ):
+        chunks.append(d["content_with_weight"])
+
+    graph, doc_ids = await update_graph(
+        LightKGExt
+        if row["parser_config"]["graphrag"]["method"] != "general"
+        else GeneralKGExt,
+        tenant_id,
+        kb_id,
+        doc_id,
+        chunks,
+        language,
+        row["parser_config"]["graphrag"]["entity_types"],
+        chat_model,
+        embedding_model,
+        callback,
+    )
+    if not graph:
+        return
+    if with_resolution or with_community:
+        graphrag_task_set(tenant_id, kb_id, doc_id)
+    if with_resolution:
+        await resolve_entities(
+            graph,
+            doc_ids,
+            tenant_id,
+            kb_id,
+            doc_id,
+            chat_model,
+            embedding_model,
+            callback,
+        )
+    if with_community:
+        await extract_community(
+            graph,
+            doc_ids,
+            tenant_id,
+            kb_id,
+            doc_id,
+            chat_model,
+            embedding_model,
+            callback,
+        )
+    now = trio.current_time()
+    callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
+    return
+async def update_graph(
+    extractor: Extractor,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    chunks: list[str],
+    language,
+    entity_types,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    contains = await does_graph_contains(tenant_id, kb_id, doc_id)
+    if contains:
+        callback(msg=f"Graph already contains {doc_id}, cancel myself")
+        return None, None
+    start = trio.current_time()
+    ext = extractor(
+        llm_bdl,
+        language=language,
+        entity_types=entity_types,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    ents, rels = await ext(doc_id, chunks, callback)
+    subgraph = nx.Graph()
+    for en in ents:
+        subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
+    for rel in rels:
+        subgraph.add_edge(
+            rel["src_id"],
+            rel["tgt_id"],
+            weight=rel["weight"],
+            # description=rel["description"]
+        )
+    # TODO: infinity doesn't support array search
+    chunk = {
+        "content_with_weight": json.dumps(
+            nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
+        ),
+        "knowledge_graph_kwd": "subgraph",
+        "kb_id": kb_id,
+        "source_id": [doc_id],
+        "available_int": 0,
+        "removed_kwd": "N",
+    }
+    cid = chunk_id(chunk)
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.insert(
+            [{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
+        )
+    )
+    now = trio.current_time()
+    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
+    start = now
+
+    while True:
+        new_graph = subgraph
+        now_docids = set([doc_id])
+        old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
+        if old_graph is not None:
+            logging.info("Merge with an existing graph...................")
+            new_graph = graph_merge(old_graph, subgraph)
+        await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
+        if old_doc_ids:
+            for old_doc_id in old_doc_ids:
+                now_docids.add(old_doc_id)
+        old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
+        delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
+        if delta_doc_ids:
+            callback(
+                msg="The global graph has changed during merging, try again"
+            )
+            await trio.sleep(1)
+            continue
+        break
+    await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
+    now = trio.current_time()
+    callback(
+        msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
+    )
+    return new_graph, now_docids
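The `while True` block above is the heart of the lock removal: instead of holding `RedisDistributedLock(kb_id)` around the merge, `update_graph` merges optimistically and validates afterwards by re-reading the global graph's doc-id set; any concurrent commit shows up as new doc ids and triggers a retry. Stripped to the pattern (all function parameters here are placeholders):

```python
import trio

async def merge_with_retry(read_graph, read_doc_ids, merge, commit, subgraph, doc_id):
    """Optimistic-concurrency sketch of the merge loop in update_graph."""
    while True:
        old_graph, old_doc_ids = await read_graph()
        new_graph = merge(old_graph, subgraph) if old_graph is not None else subgraph
        doc_ids = set(old_doc_ids or []) | {doc_id}
        # Validation: did another writer commit while we were merging?
        if set(await read_doc_ids()) - set(old_doc_ids or []):
            await trio.sleep(1)   # back off, then redo the merge from scratch
            continue
        await commit(new_graph, list(doc_ids))
        return new_graph, doc_ids
```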
+async def resolve_entities(
+    graph,
+    doc_ids,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    start = trio.current_time()
+    er = EntityResolution(
+        llm_bdl,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    reso = await er(graph)
+    graph = reso.graph
+    callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
+    await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
+    callback(msg="Graph resolution updated pagerank.")
+
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    await set_graph(tenant_id, kb_id, graph, doc_ids)
+
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "relation",
+                "kb_id": kb_id,
+                "from_entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "relation",
+                "kb_id": kb_id,
+                "to_entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {
+                "knowledge_graph_kwd": "entity",
+                "kb_id": kb_id,
+                "entity_kwd": reso.removed_entities,
+            },
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    now = trio.current_time()
+    callback(msg=f"Graph resolution done in {now - start:.2f}s.")
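`resolve_entities` (and `extract_community` below) brackets its expensive LLM work with two ownership checks: one before starting and one right before writing back, so a task that lost ownership mid-flight discards its result instead of clobbering a newer task's graph. In outline:

```python
async def guarded_stage(doc_id, current_owner, do_work, commit):
    """Check-work-recheck sketch; current_owner re-reads the Redis task key."""
    if current_owner() != doc_id:
        return                 # someone else took over before we started
    result = await do_work()   # potentially minutes of LLM calls
    if current_owner() != doc_id:
        return                 # ownership changed mid-flight; drop stale result
    await commit(result)
```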
-class WithResolution(Dealer):
-    def __init__(self,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.callback = callback
-
-    async def __call__(self):
-        with RedisDistributedLock(self.kb_id, 60*60):
-            self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
-            if not self.graph:
-                logging.error(f"Failed to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
-                if self.callback:
-                    self.callback(-1, msg="Failed to fetch the graph.")
-                return
-            if self.callback:
-                self.callback(msg="Fetch the existing graph.")
-
-            er = EntityResolution(self.llm_bdl,
-                                  get_entity=partial(get_entity, self.tenant_id, self.kb_id),
-                                  set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
-                                  get_relation=partial(get_relation, self.tenant_id, self.kb_id),
-                                  set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
-            reso = await er(self.graph)
-            self.graph = reso.graph
-            logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            if self.callback:
-                self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
-            await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
-            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
-
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "relation",
-                "kb_id": self.kb_id,
-                "from_entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "relation",
-                "kb_id": self.kb_id,
-                "to_entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "entity",
-                "kb_id": self.kb_id,
-                "entity_kwd": reso.removed_entities
-            }, search.index_name(self.tenant_id), self.kb_id))
-class WithCommunity(Dealer):
-    def __init__(self,
-                 tenant_id: str,
-                 kb_id: str,
-                 llm_bdl,
-                 embed_bdl=None,
-                 callback=None
-                 ):
-        self.tenant_id = tenant_id
-        self.kb_id = kb_id
-        self.community_structure = None
-        self.community_reports = None
-        self.llm_bdl = llm_bdl
-        self.embed_bdl = embed_bdl
-        self.callback = callback
-
-    async def __call__(self):
-        with RedisDistributedLock(self.kb_id, 60*60):
-            self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
-            if not self.graph:
-                logging.error(f"Failed to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
-                if self.callback:
-                    self.callback(-1, msg="Failed to fetch the graph.")
-                return
-            if self.callback:
-                self.callback(msg="Fetch the existing graph.")
-
-            cr = CommunityReportsExtractor(self.llm_bdl,
-                                           get_entity=partial(get_entity, self.tenant_id, self.kb_id),
-                                           set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
-                                           get_relation=partial(get_relation, self.tenant_id, self.kb_id),
-                                           set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
-            cr = await cr(self.graph, callback=self.callback)
-            self.community_structure = cr.structured_output
-            self.community_reports = cr.output
-            await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))
-
-            if self.callback:
-                self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
-
-            await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
-                "knowledge_graph_kwd": "community_report",
-                "kb_id": self.kb_id
-            }, search.index_name(self.tenant_id), self.kb_id))
-
-            for stru, rep in zip(self.community_structure, self.community_reports):
-                obj = {
-                    "report": rep,
-                    "evidences": "\n".join([f["explanation"] for f in stru["findings"]])
-                }
-                chunk = {
-                    "docnm_kwd": stru["title"],
-                    "title_tks": rag_tokenizer.tokenize(stru["title"]),
-                    "content_with_weight": json.dumps(obj, ensure_ascii=False),
-                    "content_ltks": rag_tokenizer.tokenize(obj["report"] + " " + obj["evidences"]),
-                    "knowledge_graph_kwd": "community_report",
-                    "weight_flt": stru["weight"],
-                    "entities_kwd": stru["entities"],
-                    "important_kwd": stru["entities"],
-                    "kb_id": self.kb_id,
-                    "source_id": doc_ids,
-                    "available_int": 0
-                }
-                chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
-                # try:
-                #     ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
-                #     chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
-                # except Exception as e:
-                #     logging.exception(f"Fail to embed entity relation: {e}")
-                await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
+async def extract_community(
+    graph,
+    doc_ids,
+    tenant_id: str,
+    kb_id: str,
+    doc_id: str,
+    llm_bdl,
+    embed_bdl,
+    callback,
+):
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    start = trio.current_time()
+    ext = CommunityReportsExtractor(
+        llm_bdl,
+        get_entity=partial(get_entity, tenant_id, kb_id),
+        set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
+        get_relation=partial(get_relation, tenant_id, kb_id),
+        set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
+    )
+    cr = await ext(graph, callback=callback)
+    community_structure = cr.structured_output
+    community_reports = cr.output
+
+    working_doc_id = graphrag_task_get(tenant_id, kb_id)
+    if doc_id != working_doc_id:
+        callback(
+            msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
+        )
+        return
+    await set_graph(tenant_id, kb_id, graph, doc_ids)
+
+    now = trio.current_time()
+    callback(
+        msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
+    )
+    start = now
+    await trio.to_thread.run_sync(
+        lambda: settings.docStoreConn.delete(
+            {"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
+            search.index_name(tenant_id),
+            kb_id,
+        )
+    )
+    for stru, rep in zip(community_structure, community_reports):
+        obj = {
+            "report": rep,
+            "evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
+        }
+        chunk = {
+            "docnm_kwd": stru["title"],
+            "title_tks": rag_tokenizer.tokenize(stru["title"]),
+            "content_with_weight": json.dumps(obj, ensure_ascii=False),
+            "content_ltks": rag_tokenizer.tokenize(
+                obj["report"] + " " + obj["evidences"]
+            ),
+            "knowledge_graph_kwd": "community_report",
+            "weight_flt": stru["weight"],
+            "entities_kwd": stru["entities"],
+            "important_kwd": stru["entities"],
+            "kb_id": kb_id,
+            "source_id": doc_ids,
+            "available_int": 0,
+        }
+        chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
+            chunk["content_ltks"]
+        )
+        # try:
+        #     ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
+        #     chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
+        # except Exception as e:
+        #     logging.exception(f"Fail to embed entity relation: {e}")
+        await trio.to_thread.run_sync(
+            lambda: settings.docStoreConn.insert(
+                [{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
+            )
+        )
+
+    now = trio.current_time()
+    callback(
+        msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
+    )
+    return community_structure, community_reports
@@ -16,7 +16,7 @@
 import argparse
 import json
+import logging

 import networkx as nx
 import trio
@@ -26,42 +26,85 @@ from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
-from graphrag.general.index import WithCommunity, Dealer, WithResolution
-from graphrag.light.graph_extractor import GraphExtractor
-from rag.utils.redis_conn import RedisDistributedLock
+from graphrag.general.graph_extractor import GraphExtractor
+from graphrag.general.index import update_graph, with_resolution, with_community

 settings.init_settings()

-if __name__ == "__main__":
+def callback(prog=None, msg="Processing..."):
+    logging.info(msg)
+
+async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument(
+        "-t",
+        "--tenant_id",
+        default=False,
+        help="Tenant ID",
+        action="store",
+        required=True,
+    )
+    parser.add_argument(
+        "-d",
+        "--doc_id",
+        default=False,
+        help="Document ID",
+        action="store",
+        required=True,
+    )
     args = parser.parse_args()
     e, doc = DocumentService.get_by_id(args.doc_id)
     if not e:
         raise LookupError("Document not found.")
     kb_id = doc.kb_id

-    chunks = [d["content_with_weight"] for d in
-              settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
-                                              fields=["content_with_weight"])]
-    chunks = [("x", c) for c in chunks]
-    RedisDistributedLock.clean_lock(kb_id)
+    chunks = [
+        d["content_with_weight"]
+        for d in settings.retrievaler.chunk_list(
+            args.doc_id,
+            args.tenant_id,
+            [kb_id],
+            max_count=6,
+            fields=["content_with_weight"],
+        )
+    ]

     _, tenant = TenantService.get_by_id(args.tenant_id)
     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

-    dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
-    trio.run(dealer())
-    print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
-
-    dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
-    trio.run(dealer())
-    dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
-    trio.run(dealer())
-
-    print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
-    print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
+    graph, doc_ids = await update_graph(
+        GraphExtractor,
+        args.tenant_id,
+        kb_id,
+        args.doc_id,
+        chunks,
+        "English",
+        llm_bdl,
+        embed_bdl,
+        callback,
+    )
+    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
+
+    await with_resolution(
+        args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
+    )
+    community_structure, community_reports = await with_community(
+        args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
+    )
+
+    print(
+        "------------------ COMMUNITY STRUCTURE--------------------\n",
+        json.dumps(community_structure, ensure_ascii=False, indent=2),
+    )
+    print(
+        "------------------ COMMUNITY REPORTS----------------------\n",
+        community_reports,
+    )
+
+if __name__ == "__main__":
+    trio.run(main)
@@ -18,22 +18,42 @@ import argparse
 import json

 from api import settings
 import networkx as nx
+import logging
 import trio

 from api.db import LLMType
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import TenantService
-from graphrag.general.index import Dealer
+from graphrag.general.index import update_graph
 from graphrag.light.graph_extractor import GraphExtractor
-from rag.utils.redis_conn import RedisDistributedLock

 settings.init_settings()

-if __name__ == "__main__":
+def callback(prog=None, msg="Processing..."):
+    logging.info(msg)
+
+async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
-    parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
+    parser.add_argument(
+        "-t",
+        "--tenant_id",
+        default=False,
+        help="Tenant ID",
+        action="store",
+        required=True,
+    )
+    parser.add_argument(
+        "-d",
+        "--doc_id",
+        default=False,
+        help="Document ID",
+        action="store",
+        required=True,
+    )
     args = parser.parse_args()
     e, doc = DocumentService.get_by_id(args.doc_id)
@@ -41,18 +61,36 @@ if __name__ == "__main__":
         raise LookupError("Document not found.")
     kb_id = doc.kb_id

-    chunks = [d["content_with_weight"] for d in
-              settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
-                                              fields=["content_with_weight"])]
-    chunks = [("x", c) for c in chunks]
-    RedisDistributedLock.clean_lock(kb_id)
+    chunks = [
+        d["content_with_weight"]
+        for d in settings.retrievaler.chunk_list(
+            args.doc_id,
+            args.tenant_id,
+            [kb_id],
+            max_count=6,
+            fields=["content_with_weight"],
+        )
+    ]

     _, tenant = TenantService.get_by_id(args.tenant_id)
     llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
     _, kb = KnowledgebaseService.get_by_id(kb_id)
     embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)

-    dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
-    trio.run(dealer())
-    print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
+    graph, doc_ids = await update_graph(
+        GraphExtractor,
+        args.tenant_id,
+        kb_id,
+        args.doc_id,
+        chunks,
+        "English",
+        llm_bdl,
+        embed_bdl,
+        callback,
+    )
+    print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
+
+if __name__ == "__main__":
+    trio.run(main)
@@ -352,25 +352,57 @@ def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
         chunk["q_%d_vec" % len(ebd)] = ebd
     settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)

+async def does_graph_contains(tenant_id, kb_id, doc_id):
+    # Get doc_ids of graph
+    fields = ["source_id"]
+    condition = {
+        "knowledge_graph_kwd": ["graph"],
+        "removed_kwd": "N",
+    }
+    res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
+    fields2 = settings.docStoreConn.getFields(res, fields)
+    graph_doc_ids = set()
+    for chunk_id in fields2.keys():
+        graph_doc_ids = set(fields2[chunk_id]["source_id"])
+    return doc_id in graph_doc_ids
+
+async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
+    conds = {
+        "fields": ["source_id"],
+        "removed_kwd": "N",
+        "size": 1,
+        "knowledge_graph_kwd": ["graph"]
+    }
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
+    doc_ids = []
+    if res.total == 0:
+        return doc_ids
+    for id in res.ids:
+        doc_ids = res.field[id]["source_id"]
+    return doc_ids
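The rest of the `graphrag/utils.py` changes are mechanical but worth spelling out: every blocking doc-store call is wrapped in `trio.to_thread.run_sync`, which runs the call on a worker thread and suspends only the calling trio task, so the event loop stays free for the other extraction coroutines. The shape of each conversion, reduced to a toy example:

```python
import time
import trio

def blocking_search(query: str) -> str:
    time.sleep(0.1)  # stands in for an Elasticsearch/Infinity round trip
    return f"results for {query}"

async def async_search(query: str) -> str:
    # The lambda defers the call (and its captured arguments) to the worker
    # thread; await resumes once the blocking call returns.
    return await trio.to_thread.run_sync(lambda: blocking_search(query))

trio.run(async_search, "graph")
```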
-def get_graph(tenant_id, kb_id):
+async def get_graph(tenant_id, kb_id):
     conds = {
         "fields": ["content_with_weight", "source_id"],
         "removed_kwd": "N",
         "size": 1,
         "knowledge_graph_kwd": ["graph"]
     }
-    res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
     if res.total == 0:
         return None, []
     for id in res.ids:
         try:
             return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
                    res.field[id]["source_id"]
         except Exception:
             continue
-    return rebuild_graph(tenant_id, kb_id)
+    result = await rebuild_graph(tenant_id, kb_id)
+    return result

-def set_graph(tenant_id, kb_id, graph, docids):
+async def set_graph(tenant_id, kb_id, graph, docids):
     chunk = {
         "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
                                           indent=2),
@@ -379,13 +411,13 @@ def set_graph(tenant_id, kb_id, graph, docids):
         "source_id": list(docids),
         "available_int": 0,
         "removed_kwd": "N"
     }
-    res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
     if res.ids:
-        settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
-                                     search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
+                                                                           search.index_name(tenant_id), kb_id))
     else:
-        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))

 def is_continuous_subsequence(subseq, seq):
@@ -430,7 +462,7 @@ def merge_tuples(list1, list2):
     return result

-def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
+async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
     def n_neighbor(id):
         nonlocal graph, n_hop
         count = 0
@@ -460,10 +492,10 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
     for n, p in pr.items():
         graph.nodes[n]["pagerank"] = p
         try:
-            settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
-                                         {"rank_flt": p,
-                                          "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
-                                         search.index_name(tenant_id), kb_id)
+            await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
+                                                                               {"rank_flt": p,
+                                                                                "n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
+                                                                               search.index_name(tenant_id), kb_id))
         except Exception as e:
             logging.exception(e)
@@ -480,21 +512,21 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
         "knowledge_graph_kwd": "ty2ents",
         "available_int": 0
     }
-    res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
-                                      search.index_name(tenant_id), [kb_id])
+    res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
+                                                                            search.index_name(tenant_id), [kb_id]))
     if res.ids:
-        settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
-                                     chunk,
-                                     search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
+                                                                           chunk,
+                                                                           search.index_name(tenant_id), kb_id))
     else:
-        settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
+        await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))

-def get_entity_type2sampels(idxnms, kb_ids: list):
-    es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
-                                          "size": 10000,
-                                          "fields": ["content_with_weight"]},
-                                         idxnms, kb_ids)
+async def get_entity_type2sampels(idxnms, kb_ids: list):
+    es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
+                                                                                "size": 10000,
+                                                                                "fields": ["content_with_weight"]},
+                                                                               idxnms, kb_ids))

     res = defaultdict(list)
     for id in es_res.ids:
@@ -522,18 +554,18 @@ def flat_uniq_list(arr, key):
     return list(set(res))

-def rebuild_graph(tenant_id, kb_id):
+async def rebuild_graph(tenant_id, kb_id):
     graph = nx.Graph()
     src_ids = []
     flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
     bs = 256
     for i in range(0, 39*bs, bs):
-        es_res = settings.docStoreConn.search(flds, [],
-                                              {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
-                                              [],
-                                              OrderByExpr(),
-                                              i, bs, search.index_name(tenant_id), [kb_id]
-                                              )
+        es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
+                                                                                    {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
+                                                                                    [],
+                                                                                    OrderByExpr(),
+                                                                                    i, bs, search.index_name(tenant_id), [kb_id]
+                                                                                    ))
         tot = settings.docStoreConn.getTotal(es_res)
         if tot == 0:
             return None, None
@@ -15,18 +15,25 @@
 #
 import logging
 import re
-from threading import Lock

 import umap
 import numpy as np
 from sklearn.mixture import GaussianMixture
 import trio

-from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
+from graphrag.utils import (
+    get_llm_cache,
+    get_embed_cache,
+    set_embed_cache,
+    set_llm_cache,
+    chat_limiter,
+)
 from rag.utils import truncate

 class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
-    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
+    def __init__(
+        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
+    ):
         self._max_cluster = max_cluster
         self._llm_model = llm_model
         self._embd_model = embd_model
@@ -34,22 +41,24 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
         self._prompt = prompt
         self._max_token = max_token

-    def _chat(self, system, history, gen_conf):
+    async def _chat(self, system, history, gen_conf):
         response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
         if response:
             return response
-        response = self._llm_model.chat(system, history, gen_conf)
+        response = await trio.to_thread.run_sync(
+            lambda: self._llm_model.chat(system, history, gen_conf)
+        )
         response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
         if response.find("**ERROR**") >= 0:
             raise Exception(response)
         set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
         return response

-    def _embedding_encode(self, txt):
+    async def _embedding_encode(self, txt):
         response = get_embed_cache(self._embd_model.llm_name, txt)
         if response is not None:
             return response
-        embds, _ = self._embd_model.encode([txt])
+        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
         if len(embds) < 1 or len(embds[0]) < 1:
             raise Exception("Embedding error: ")
         embds = embds[0]
@@ -74,36 +83,48 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             return []
         chunks = [(s, a) for s, a in chunks if s and len(a) > 0]

-        async def summarize(ck_idx, lock):
+        async def summarize(ck_idx: list[int]):
             nonlocal chunks
-            try:
-                texts = [chunks[i][0] for i in ck_idx]
-                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
-                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
-                async with chat_limiter:
-                    cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
-                                                                           [{"role": "user",
-                                                                             "content": self._prompt.format(cluster_content=cluster_content)}],
-                                                                           {"temperature": 0.3, "max_tokens": self._max_token}
-                                                                           ))
-                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
-                             cnt)
-                logging.debug(f"SUM: {cnt}")
-                embds, _ = self._embd_model.encode([cnt])
-                with lock:
-                    chunks.append((cnt, self._embedding_encode(cnt)))
-            except Exception as e:
-                logging.exception("summarize got exception")
-                return e
+            texts = [chunks[i][0] for i in ck_idx]
+            len_per_chunk = int(
+                (self._llm_model.max_length - self._max_token) / len(texts)
+            )
+            cluster_content = "\n".join(
+                [truncate(t, max(1, len_per_chunk)) for t in texts]
+            )
+            async with chat_limiter:
+                cnt = await self._chat(
+                    "You're a helpful assistant.",
+                    [
+                        {
+                            "role": "user",
+                            "content": self._prompt.format(
+                                cluster_content=cluster_content
+                            ),
+                        }
+                    ],
+                    {"temperature": 0.3, "max_tokens": self._max_token},
+                )
+            cnt = re.sub(
+                "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
+                "",
+                cnt,
+            )
+            logging.debug(f"SUM: {cnt}")
+            embds = await self._embedding_encode(cnt)
+            chunks.append((cnt, embds))

         labels = []
-        lock = Lock()
         while end - start > 1:
-            embeddings = [embd for _, embd in chunks[start: end]]
+            embeddings = [embd for _, embd in chunks[start:end]]
             if len(embeddings) == 2:
-                await summarize([start, start + 1], lock)
+                await summarize([start, start + 1])
                 if callback:
-                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                    callback(
+                        msg="Cluster one layer: {} -> {}".format(
+                            end - start, len(chunks) - end
+                        )
+                    )
                 labels.extend([0, 0])
                 layers.append((end, len(chunks)))
                 start = end
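The `threading.Lock` could be dropped because the shared `chunks` list is now only mutated on the trio event loop: just the blocking chat/encode calls are shipped to worker threads (inside `_chat` and `_embedding_encode`), and between `await` points exactly one trio task runs, so two `summarize` tasks can no longer interleave their appends. The resulting shape, as a sketch:

```python
import trio

async def summarize_one(chunks: list, build_text, encode_sync):
    text = build_text()
    # Only the blocking call leaves the event loop...
    embedding = await trio.to_thread.run_sync(lambda: encode_sync(text))
    # ...the mutation happens back on the single-threaded trio scheduler,
    # so no lock is needed around the append.
    chunks.append((text, embedding))
```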
@@ -112,7 +133,9 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             n_neighbors = int((len(embeddings) - 1) ** 0.8)
             reduced_embeddings = umap.UMAP(
-                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
+                n_neighbors=max(2, n_neighbors),
+                n_components=min(12, len(embeddings) - 2),
+                metric="cosine",
             ).fit_transform(embeddings)
             n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
             if n_clusters == 1:
@@ -127,18 +150,22 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
             async with trio.open_nursery() as nursery:
                 for c in range(n_clusters):
                     ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
-                    assert len(ck_idx) > 0
+                    if not ck_idx:
+                        continue
                     async with chat_limiter:
-                        nursery.start_soon(lambda: summarize(ck_idx, lock))
+                        nursery.start_soon(lambda: summarize(ck_idx))

-            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
+                len(chunks) - end, n_clusters
+            )
             labels.extend(lbls)
             layers.append((end, len(chunks)))
             if callback:
-                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                callback(
+                    msg="Cluster one layer: {} -> {}".format(
+                        end - start, len(chunks) - end
+                    )
+                )
             start = end
             end = len(chunks)

         return chunks
@@ -20,9 +20,7 @@ import random
 import sys
 from api.utils.log_utils import initRootLogger, get_project_base_directory

-from graphrag.general.index import WithCommunity, WithResolution, Dealer
-from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
-from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
+from graphrag.general.index import run_graphrag
 from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
 from rag.prompts import keyword_extraction, question_proposal, content_tagging
@@ -45,6 +43,7 @@ import tracemalloc
 import resource
 import signal
 import trio
+import exceptiongroup
 import numpy as np
 from peewee import DoesNotExist
@@ -453,24 +452,6 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
     return res, tk_count

-async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
-    chunks = []
-    for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
-                                             fields=["content_with_weight", "doc_id"]):
-        chunks.append((d["doc_id"], d["content_with_weight"]))
-
-    dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
-                    row["tenant_id"],
-                    str(row["kb_id"]),
-                    chat_model,
-                    chunks=chunks,
-                    language=language,
-                    entity_types=row["parser_config"]["graphrag"]["entity_types"],
-                    embed_bdl=embedding_model,
-                    callback=callback)
-    await dealer()

 async def do_handle_task(task):
     task_id = task["id"]
     task_from_page = task["from_page"]
@@ -526,24 +507,10 @@ async def do_handle_task(task):
             return
         start_ts = timer()
         chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
-        await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
-        progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("resolution", False):
-            start_ts = timer()
-            with_res = WithResolution(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_res()
-            progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
-        if graphrag_conf.get("community", False):
-            start_ts = timer()
-            with_comm = WithCommunity(
-                task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
-                progress_callback
-            )
-            await with_comm()
-            progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
+        with_resolution = graphrag_conf.get("resolution", False)
+        with_community = graphrag_conf.get("community", False)
+        await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
+        progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
         return
     else:
         # Standard chunking methods
@@ -622,7 +589,11 @@ async def handle_task():
         FAILED_TASKS += 1
         CURRENT_TASKS.pop(task["id"], None)
         try:
-            set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
+            err_msg = str(e)
+            while isinstance(e, exceptiongroup.ExceptionGroup):
+                e = e.exceptions[0]
+                err_msg += ' -- ' + str(e)
+            set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
         except Exception:
             pass
         logging.exception(f"handle_task got exception for task {json.dumps(task)}")
@@ -16,13 +16,12 @@
 import logging
 import json
-import time
 import uuid

 import valkey as redis
 from rag import settings
 from rag.utils import singleton
+from valkey.lock import Lock

 class RedisMsg:
     def __init__(self, consumer, queue_name, group_name, msg_id, message):
@@ -281,29 +280,23 @@ REDIS_CONN = RedisDB()

 class RedisDistributedLock:
-    def __init__(self, lock_key, timeout=10):
+    def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
         self.lock_key = lock_key
-        self.lock_value = str(uuid.uuid4())
+        if lock_value:
+            self.lock_value = lock_value
+        else:
+            self.lock_value = str(uuid.uuid4())
         self.timeout = timeout
+        self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)

-    @staticmethod
-    def clean_lock(lock_key):
-        REDIS_CONN.REDIS.delete(lock_key)
-
-    def acquire_lock(self):
-        end_time = time.time() + self.timeout
-        while time.time() < end_time:
-            if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
-                return True
-            time.sleep(1)
-        return False
+    def acquire(self):
+        return self.lock.acquire()

-    def release_lock(self):
-        if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
-            REDIS_CONN.REDIS.delete(self.lock_key)
+    def release(self):
+        return self.lock.release()

     def __enter__(self):
-        self.acquire_lock()
+        self.acquire()

     def __exit__(self, exception_type, exception_value, exception_traceback):
-        self.release_lock()
+        self.release()
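Net effect of the rewrite: `RedisDistributedLock` now delegates to `valkey.lock.Lock`, so `acquire()` waits at most `blocking_timeout` (1s) and returns `False` instead of polling `SETNX` for up to `timeout` seconds, and `timeout` becomes the lock's expiry, so a crashed holder can no longer wedge other executors (which is also why `clean_lock()` could be deleted). Typical use after this change, mirroring `update_progress` above:

```python
# Non-blocking style: skip the work if another process holds the lock.
lock = RedisDistributedLock("update_progress", timeout=60)
if lock.acquire():
    try:
        pass  # critical section
    finally:
        lock.release()

# Context-manager style still works, but note that __enter__ ignores the
# acquire() result, so callers who must know whether they won the lock
# should call acquire() explicitly as above.
with RedisDistributedLock("some_key", timeout=60):
    pass  # best-effort critical section
```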