### What problem does this PR solve?

### Type of change

- [x] Performance Improvement
@@ -1,3 +1,18 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required
@@ -6,9 +21,10 @@ from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.mcp_server import get_mcp_tools
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
    get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
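`get_mcp_tools` now lives in `api.utils.api_utils` next to the other API helpers; the standalone `api.utils.mcp_server` module is deleted further down. A minimal usage sketch of the helper, assuming the caller already has a list of `MCPServer` rows (the `MCPServerService.query` lookup is illustrative):

```python
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import get_mcp_tools

# Illustrative lookup; any list of MCPServer rows works.
servers = MCPServerService.query(tenant_id="some-tenant-id")

results, error = get_mcp_tools(servers, timeout=10)
if error:
    # get_mcp_tools never raises: a top-level failure comes back as an
    # error string, and a per-server fetch failure degrades to an empty list.
    print(error)
```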
@@ -13,19 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Optional, Union, Callable, Coroutine, Type
from urllib.parse import quote, urlencode
from uuid import uuid1

import trio

from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
import requests
from flask import (
    Response,
@@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
        transformed_data[mapped_key] = value
    return transformed_data


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id
            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})
                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)


TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]


def timeout(
    seconds: float | int | None = None,
    *,
    exception: Optional[TimeoutException] = None,
    on_timeout: Optional[OnTimeoutCallback] = None
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result_queue = queue.Queue(maxsize=1)

            def target():
                try:
                    result = func(*args, **kwargs)
                    result_queue.put(result)
                except Exception as e:
                    result_queue.put(e)

            thread = threading.Thread(target=target)
            thread.daemon = True
            thread.start()

            try:
                result = result_queue.get(timeout=seconds)
                if isinstance(result, Exception):
                    raise result
                return result
            except queue.Empty:
                raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            if seconds is None:
                return await func(*args, **kwargs)

            try:
                with trio.fail_after(seconds):
                    return await func(*args, **kwargs)
            except trio.TooSlowError:
                if on_timeout is not None:
                    if callable(on_timeout):
                        result = on_timeout()
                        if isinstance(result, Coroutine):
                            return await result
                        return result
                    return on_timeout

                if exception is None:
                    raise TimeoutError(f"Operation timed out after {seconds} seconds")
                if isinstance(exception, BaseException):
                    raise exception
                if isinstance(exception, type) and issubclass(exception, BaseException):
                    raise exception(f"Operation timed out after {seconds} seconds")
                raise RuntimeError("Invalid exception type provided")

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper

    return decorator
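The sync and async paths behave differently, which matters for callers: the thread-based path cannot cancel the underlying call (the daemon thread keeps running after `TimeoutError` is raised), and it ignores `exception` and `on_timeout`, which only take effect on the trio path. A minimal usage sketch, assuming a trio run loop for the async case:

```python
import time

import trio

from api.utils.api_utils import timeout


@timeout(2)
def blocking_call() -> str:
    time.sleep(5)  # keeps running in its daemon thread even after the timeout fires
    return "done"


@timeout(2, on_timeout=lambda: "fallback")
async def async_call() -> str:
    await trio.sleep(5)  # cancelled by trio.fail_after at the 2s mark
    return "done"


async def main():
    print(await async_call())  # -> "fallback" instead of raising


if __name__ == "__main__":
    trio.run(main)
    try:
        blocking_call()
    except TimeoutError as e:
        print(e)  # Function 'blocking_call' timed out after 2 seconds
```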
@@ -1,34 +0,0 @@
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id
            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})
                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)
@@ -12,6 +12,8 @@ from typing import Callable
from dataclasses import dataclass

import networkx as nx
import pandas as pd

from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
@@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
        res_str = []
        res_dict = []
        over, token_count = 0, 0
        @timeout(120)
        async def extract_community_report(community):
            nonlocal res_str, res_dict, over, token_count
            cm_id, cm = community
@@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
            gen_conf = {"temperature": 0.3}
            async with chat_limiter:
                try:
                    with trio.move_on_after(120) as cancel_scope:
                    with trio.move_on_after(80) as cancel_scope:
                        response = await trio.to_thread.run_sync(self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
                    if cancel_scope.cancelled_caught:
                        logging.warning("extract_community_report._chat timeout, skipping...")
@@ -21,6 +21,7 @@ from typing import Callable
import trio
import networkx as nx

from api.utils.api_utils import timeout
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
    handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
@@ -46,6 +47,7 @@ class Extractor:
        self._language = language
        self._entity_types = entity_types or DEFAULT_ENTITY_TYPES

    @timeout(60)
    def _chat(self, system, history, gen_conf):
        hist = deepcopy(history)
        conf = deepcopy(gen_conf)
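Because `_chat` is a regular synchronous method, the decorator's thread-based path applies here: each call runs in a daemon thread and the caller waits at most 60 seconds. A small self-contained sketch of that behaviour (the `FakeExtractor` class is illustrative, not part of the PR):

```python
import time

from api.utils.api_utils import timeout


class FakeExtractor:
    """Illustrative stand-in for Extractor; only the timeout behaviour matters."""

    @timeout(60)
    def _chat(self, system, history, gen_conf):
        time.sleep(90)  # simulate a hung LLM request
        return "response"


try:
    FakeExtractor()._chat("system", [], {})
except TimeoutError:
    # Raised after 60s. The sleeping daemon thread is abandoned, not killed:
    # the hung request keeps running in the background until it finishes.
    pass
```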
@@ -20,6 +20,7 @@ import trio
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -123,6 +124,7 @@ async def run_graphrag(
    return


@timeout(60*60*2)
async def generate_subgraph(
    extractor: Extractor,
    tenant_id: str,
@@ -194,6 +196,8 @@ async def generate_subgraph(
    callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
    return subgraph


@timeout(60*3)
async def merge_subgraph(
    tenant_id: str,
    kb_id: str,
@@ -225,6 +229,7 @@ async def merge_subgraph(
    return new_graph


@timeout(60*60)
async def resolve_entities(
    graph,
    subgraph_nodes: set[str],
@@ -250,6 +255,7 @@ async def resolve_entities(
    callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60*30)
async def extract_community(
    graph,
    tenant_id: str,
@@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    """
    Ensure all nodes and edges in the graph have some essential attribute.
@@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place."""
    for node_name, attr in g2.nodes(data=True):
@@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
        g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    return md5(str(args).encode()).hexdigest()
@@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
        chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
    # Get doc_ids of graph
    fields = ["source_id"]
@@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
    graph_doc_ids = set(fields2[chunk_id]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    conds = {
        "fields": ["source_id"],
@@ -20,6 +20,7 @@ import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from api.utils.api_utils import timeout
from graphrag.utils import (
    get_llm_cache,
    get_embed_cache,
@@ -54,6 +55,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        return response

    @timeout(2)
    async def _embedding_encode(self, txt):
        response = get_embed_cache(self._embd_model.llm_name, txt)
        if response is not None:
@@ -83,6 +85,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        @timeout(60)
        async def summarize(ck_idx: list[int]):
            nonlocal chunks
            texts = [chunks[i][0] for i in ck_idx]
@@ -21,6 +21,7 @@ import sys
import threading
import time

from api.utils.api_utils import timeout
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -275,6 +276,7 @@ async def build_chunks(task, progress_callback):
        doc[PAGERANK_FLD] = int(task["pagerank"])
    st = timer()

    @timeout(60)
    async def upload_to_minio(document, chunk):
        try:
            d = copy.deepcopy(document)
@@ -415,6 +417,7 @@ def init_kb(row, vector_size: int):
    return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


@timeout(60*20)
async def embedding(docs, mdl, parser_config=None, callback=None):
    if parser_config is None:
        parser_config = {}
@@ -461,6 +464,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    return tk_count, vector_size


@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
    chunks = []
    vctr_nm = "q_%d_vec" % vector_size
@@ -502,6 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
    return res, tk_count


@timeout(60*60*1.5)
async def do_handle_task(task):
    task_id = task["id"]
    task_from_page = task["from_page"]
@@ -220,40 +220,43 @@ class RedisDB:
            logging.exception(
                "RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
            )
            self.__open__()
        return False

    def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
        """https://redis.io/docs/latest/commands/xreadgroup/"""
        try:
            group_info = self.REDIS.xinfo_groups(queue_name)
            if not any(gi["name"] == group_name for gi in group_info):
                self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
            args = {
                "groupname": group_name,
                "consumername": consumer_name,
                "count": 1,
                "block": 5,
                "streams": {queue_name: msg_id},
            }
            messages = self.REDIS.xreadgroup(**args)
            if not messages:
                return None
            stream, element_list = messages[0]
            if not element_list:
                return None
            msg_id, payload = element_list[0]
            res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
            return res
        except Exception as e:
            if str(e) == 'no such key':
                pass
            else:
                logging.exception(
                    "RedisDB.queue_consumer "
                    + str(queue_name)
                    + " got exception: "
                    + str(e)
                )
        for _ in range(3):
            try:
                group_info = self.REDIS.xinfo_groups(queue_name)
                if not any(gi["name"] == group_name for gi in group_info):
                    self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
                args = {
                    "groupname": group_name,
                    "consumername": consumer_name,
                    "count": 1,
                    "block": 5,
                    "streams": {queue_name: msg_id},
                }
                messages = self.REDIS.xreadgroup(**args)
                if not messages:
                    return None
                stream, element_list = messages[0]
                if not element_list:
                    return None
                msg_id, payload = element_list[0]
                res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
                return res
            except Exception as e:
                if str(e) == 'no such key':
                    pass
                else:
                    logging.exception(
                        "RedisDB.queue_consumer "
                        + str(queue_name)
                        + " got exception: "
                        + str(e)
                    )
                self.__open__()
        return None

    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
@@ -294,26 +297,30 @@ class RedisDB:
        return []

    def requeue_msg(self, queue: str, group_name: str, msg_id: str):
        try:
            messages = self.REDIS.xrange(queue, msg_id, msg_id)
            if messages:
                self.REDIS.xadd(queue, messages[0][1])
                self.REDIS.xack(queue, group_name, msg_id)
        except Exception as e:
            logging.warning(
                "RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
            )
        for _ in range(3):
            try:
                messages = self.REDIS.xrange(queue, msg_id, msg_id)
                if messages:
                    self.REDIS.xadd(queue, messages[0][1])
                    self.REDIS.xack(queue, group_name, msg_id)
                return  # success; a retry here would re-add the message again
            except Exception as e:
                logging.warning(
                    "RedisDB.requeue_msg " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()

    def queue_info(self, queue, group_name) -> dict | None:
        try:
            groups = self.REDIS.xinfo_groups(queue)
            for group in groups:
                if group["name"] == group_name:
                    return group
        except Exception as e:
            logging.warning(
                "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
            )
        for _ in range(3):
            try:
                groups = self.REDIS.xinfo_groups(queue)
                for group in groups:
                    if group["name"] == group_name:
                        return group
                return None  # group genuinely absent; retrying will not help
            except Exception as e:
                logging.warning(
                    "RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
                )
                self.__open__()
        return None
    def delete_if_equal(self, key: str, expected_value: str) -> bool:
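All three queue methods now share the same shape: up to three attempts, reopening the Redis connection after each failure. A hypothetical decorator (not part of this PR) that captures the pattern, shown here to make the control flow explicit:

```python
import logging
from functools import wraps


def retry_with_reconnect(attempts: int = 3):
    """Retry a RedisDB method, reconnecting after each failure.

    Hypothetical refactor of the inlined loops above; assumes the class
    exposes __open__() to re-establish the connection.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            for _ in range(attempts):
                try:
                    return func(self, *args, **kwargs)
                except Exception as e:
                    logging.warning("RedisDB.%s got exception: %s", func.__name__, e)
                    self.__open__()  # reconnect, then retry
            return None
        return wrapper
    return decorator
```

One subtlety the inlined versions still handle case by case: returning early on success matters, since a retry after a successful `xadd`/`xack` would duplicate the requeued message.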