
Perf: Enhance timeout handling. (#8826)

### What problem does this PR solve?

Several long-running GraphRAG, RAPTOR, and task-executor steps could block indefinitely, and transient Redis queue failures were not retried. This PR adds a reusable `timeout` decorator to `api.utils.api_utils` (thread-based for sync functions, `trio.fail_after` for async ones) and applies it to those steps, retries Redis queue operations up to three times with a reconnect between attempts, and folds `get_mcp_tools` from `api/utils/mcp_server.py` into `api_utils`.
### Type of change

- [x] Performance Improvement
Kevin Hu · 3 months ago · commit c642dbefca · tags/v0.20.0

api/apps/mcp_server_app.py (+18 / -2)

@@ -1,3 +1,18 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response, request
from flask_login import current_user, login_required

@@ -6,9 +21,10 @@ from api.db.db_models import MCPServer
from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService
from api.settings import RetCode

from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from api.utils.mcp_server import get_mcp_tools
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


api/utils/api_utils.py (+108 / -0)

@@ -13,19 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Optional, Union, Callable, Coroutine, Type
from urllib.parse import quote, urlencode
from uuid import uuid1

import trio

from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


import requests
from flask import (
Response,
@@ -558,3 +568,101 @@ def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
        transformed_data[mapped_key] = value

    return transformed_data


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id

            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})

                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)


TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(
    seconds: float | int = None,
    *,
    exception: Optional[TimeoutException] = None,
    on_timeout: Optional[OnTimeoutCallback] = None
):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result_queue = queue.Queue(maxsize=1)
            def target():
                try:
                    result = func(*args, **kwargs)
                    result_queue.put(result)
                except Exception as e:
                    result_queue.put(e)

            thread = threading.Thread(target=target)
            thread.daemon = True
            thread.start()

            try:
                result = result_queue.get(timeout=seconds)
                if isinstance(result, Exception):
                    raise result
                return result
            except queue.Empty:
                raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            if seconds is None:
                return await func(*args, **kwargs)

            try:
                with trio.fail_after(seconds):
                    return await func(*args, **kwargs)
            except trio.TooSlowError:
                if on_timeout is not None:
                    if callable(on_timeout):
                        result = on_timeout()
                        if isinstance(result, Coroutine):
                            return await result
                        return result
                    return on_timeout

                if exception is None:
                    raise TimeoutError(f"Operation timed out after {seconds} seconds")

                if isinstance(exception, BaseException):
                    raise exception

                if isinstance(exception, type) and issubclass(exception, BaseException):
                    raise exception(f"Operation timed out after {seconds} seconds")

                raise RuntimeError("Invalid exception type provided")

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper
    return decorator
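
For context, a minimal usage sketch of the decorator added above (the function names and durations here are illustrative, not part of this commit): sync callables run in a daemon thread and raise `TimeoutError` on overrun, while async callables are guarded by `trio.fail_after` and can fall back to `on_timeout`.

```python
import time

import trio

from api.utils.api_utils import timeout


@timeout(2)
def blocking_work():
    # Runs in a daemon thread; result_queue.get(timeout=2) raises TimeoutError if this overruns.
    time.sleep(5)
    return "done"


@timeout(2, on_timeout=lambda: "fallback")
async def async_work():
    # Guarded by trio.fail_after(2); the on_timeout result is returned instead of raising.
    await trio.sleep(5)
    return "done"


if __name__ == "__main__":
    try:
        blocking_work()
    except TimeoutError as exc:
        print(exc)               # Function 'blocking_work' timed out after 2 seconds
    print(trio.run(async_work))  # fallback
```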


api/utils/mcp_server.py (+0 / -34, file deleted)

@@ -1,34 +0,0 @@
from api.db.db_models import MCPServer
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions


def get_mcp_tools(mcp_servers: list[MCPServer], timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id

            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})

                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)

graphrag/general/community_reports_extractor.py (+4 / -1)

@@ -12,6 +12,8 @@ from typing import Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd

from api.utils.api_utils import timeout
from graphrag.general import leiden
from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from graphrag.general.extractor import Extractor
@@ -57,6 +59,7 @@ class CommunityReportsExtractor(Extractor):
res_str = []
res_dict = []
over, token_count = 0, 0
@timeout(120)
async def extract_community_report(community):
nonlocal res_str, res_dict, over, token_count
cm_id, cm = community
@@ -90,7 +93,7 @@ class CommunityReportsExtractor(Extractor):
gen_conf = {"temperature": 0.3}
async with chat_limiter:
try:
with trio.move_on_after(120) as cancel_scope:
with trio.move_on_after(80) as cancel_scope:
response = await trio.to_thread.run_sync( self._chat, text, [{"role": "user", "content": "Output:"}], gen_conf)
if cancel_scope.cancelled_caught:
logging.warning("extract_community_report._chat timeout, skipping...")

graphrag/general/extractor.py (+2 / -0)

@@ -21,6 +21,7 @@ from typing import Callable
import trio
import networkx as nx

from api.utils.api_utils import timeout
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter, get_from_to, GraphChange
@@ -46,6 +47,7 @@ class Extractor:
self._language = language
self._entity_types = entity_types or DEFAULT_ENTITY_TYPES

@timeout(60)
def _chat(self, system, history, gen_conf):
hist = deepcopy(history)
conf = deepcopy(gen_conf)

graphrag/general/index.py (+6 / -0)

@@ -20,6 +20,7 @@ import trio

from api import settings
from api.utils import get_uuid
from api.utils.api_utils import timeout
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
@@ -123,6 +124,7 @@ async def run_graphrag(
return


@timeout(60*60*2)
async def generate_subgraph(
extractor: Extractor,
tenant_id: str,
@@ -194,6 +196,8 @@ async def generate_subgraph(
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
return subgraph


@timeout(60*3)
async def merge_subgraph(
tenant_id: str,
kb_id: str,
@@ -225,6 +229,7 @@ async def merge_subgraph(
return new_graph


@timeout(60*60)
async def resolve_entities(
graph,
subgraph_nodes: set[str],
@@ -250,6 +255,7 @@ async def resolve_entities(
callback(msg=f"Graph resolution done in {now - start:.2f}s.")


@timeout(60*30)
async def extract_community(
graph,
tenant_id: str,

graphrag/utils.py (+6 / -0)

@@ -157,6 +157,7 @@ def set_tags_to_cache(kb_ids, tags):
k = hasher.hexdigest()
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
"""
Ensure all nodes and edges in the graph have some essential attribute.
@@ -190,12 +191,14 @@ def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
if purged_edges and callback:
callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
if node1 < node2:
return (node1, node2)
else:
return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
"""Merge graph g2 into g1 in place."""
for node_name, attr in g2.nodes(data=True):
@@ -228,6 +231,7 @@ def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
g1.graph["source_id"] += g2.graph.get("source_id", [])
return g1


def compute_args_hash(*args):
return md5(str(args).encode()).hexdigest()

@@ -378,6 +382,7 @@ async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta,
chunk["q_%d_vec" % len(ebd)] = ebd
chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
fields = ["source_id"]
@@ -392,6 +397,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
graph_doc_ids = set(fields2[chunk_id]["source_id"])
return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {
"fields": ["source_id"],

rag/raptor.py (+3 / -0)

@@ -20,6 +20,7 @@ import numpy as np
from sklearn.mixture import GaussianMixture
import trio

from api.utils.api_utils import timeout
from graphrag.utils import (
get_llm_cache,
get_embed_cache,
@@ -54,6 +55,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
return response

@timeout(2)
async def _embedding_encode(self, txt):
response = get_embed_cache(self._embd_model.llm_name, txt)
if response is not None:
@@ -83,6 +85,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
layers = [(0, len(chunks))]
start, end = 0, len(chunks)

@timeout(60)
async def summarize(ck_idx: list[int]):
nonlocal chunks
texts = [chunks[i][0] for i in ck_idx]

rag/svr/task_executor.py (+5 / -0)

@@ -21,6 +21,7 @@ import sys
import threading
import time

from api.utils.api_utils import timeout
from api.utils.log_utils import init_root_logger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
@@ -275,6 +276,7 @@ async def build_chunks(task, progress_callback):
doc[PAGERANK_FLD] = int(task["pagerank"])
st = timer()

@timeout(60)
async def upload_to_minio(document, chunk):
try:
d = copy.deepcopy(document)
@@ -415,6 +417,7 @@ def init_kb(row, vector_size: int):
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)


@timeout(60*20)
async def embedding(docs, mdl, parser_config=None, callback=None):
if parser_config is None:
parser_config = {}
@@ -461,6 +464,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
return tk_count, vector_size


@timeout(3600)
async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
chunks = []
vctr_nm = "q_%d_vec"%vector_size
@@ -502,6 +506,7 @@ async def run_raptor(row, chat_mdl, embd_mdl, vector_size, callback=None):
return res, tk_count


@timeout(60*60*1.5)
async def do_handle_task(task):
task_id = task["id"]
task_from_page = task["from_page"]

rag/utils/redis_conn.py (+55 / -48)

@@ -220,40 +220,43 @@ class RedisDB:
logging.exception(
"RedisDB.queue_product " + str(queue) + " got exception: " + str(e)
)
self.__open__()
return False

def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> RedisMsg:
"""https://redis.io/docs/latest/commands/xreadgroup/"""
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(gi["name"] == group_name for gi in group_info):
self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
args = {
"groupname": group_name,
"consumername": consumer_name,
"count": 1,
"block": 5,
"streams": {queue_name: msg_id},
}
messages = self.REDIS.xreadgroup(**args)
if not messages:
return None
stream, element_list = messages[0]
if not element_list:
return None
msg_id, payload = element_list[0]
res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
return res
except Exception as e:
if str(e) == 'no such key':
pass
else:
logging.exception(
"RedisDB.queue_consumer "
+ str(queue_name)
+ " got exception: "
+ str(e)
)
for _ in range(3):
try:
group_info = self.REDIS.xinfo_groups(queue_name)
if not any(gi["name"] == group_name for gi in group_info):
self.REDIS.xgroup_create(queue_name, group_name, id="0", mkstream=True)
args = {
"groupname": group_name,
"consumername": consumer_name,
"count": 1,
"block": 5,
"streams": {queue_name: msg_id},
}
messages = self.REDIS.xreadgroup(**args)
if not messages:
return None
stream, element_list = messages[0]
if not element_list:
return None
msg_id, payload = element_list[0]
res = RedisMsg(self.REDIS, queue_name, group_name, msg_id, payload)
return res
except Exception as e:
if str(e) == 'no such key':
pass
else:
logging.exception(
"RedisDB.queue_consumer "
+ str(queue_name)
+ " got exception: "
+ str(e)
)
self.__open__()
return None

def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
@@ -294,26 +297,30 @@ class RedisDB:
return []

def requeue_msg(self, queue: str, group_name: str, msg_id: str):
try:
messages = self.REDIS.xrange(queue, msg_id, msg_id)
if messages:
self.REDIS.xadd(queue, messages[0][1])
self.REDIS.xack(queue, group_name, msg_id)
except Exception as e:
logging.warning(
"RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
)
for _ in range(3):
try:
messages = self.REDIS.xrange(queue, msg_id, msg_id)
if messages:
self.REDIS.xadd(queue, messages[0][1])
self.REDIS.xack(queue, group_name, msg_id)
except Exception as e:
logging.warning(
"RedisDB.get_pending_msg " + str(queue) + " got exception: " + str(e)
)
self.__open__()

def queue_info(self, queue, group_name) -> dict | None:
try:
groups = self.REDIS.xinfo_groups(queue)
for group in groups:
if group["name"] == group_name:
return group
except Exception as e:
logging.warning(
"RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
)
for _ in range(3):
try:
groups = self.REDIS.xinfo_groups(queue)
for group in groups:
if group["name"] == group_name:
return group
except Exception as e:
logging.warning(
"RedisDB.queue_info " + str(queue) + " got exception: " + str(e)
)
self.__open__()
return None

def delete_if_equal(self, key: str, expected_value: str) -> bool:
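
The three queue helpers above (`queue_consumer`, `requeue_msg`, `queue_info`) now share the same shape: up to three attempts, with `self.__open__()` re-establishing the connection after a failed try. A minimal standalone sketch of that retry-and-reconnect pattern (illustrative only; the helper name and the fake `flaky` operation are assumptions, not part of this commit):

```python
import logging


def with_reconnect(op, reconnect, attempts=3, default=None):
    """Run op(); on failure, log, call reconnect(), and retry up to `attempts` times.

    Mirrors the shape of the queue_consumer/requeue_msg/queue_info changes above,
    where reconnect() corresponds to RedisDB.__open__().
    """
    for _ in range(attempts):
        try:
            return op()
        except Exception as e:
            logging.warning("operation failed, reconnecting and retrying: %s", e)
            reconnect()
    return default


if __name__ == "__main__":
    calls = {"n": 0}

    def flaky():
        calls["n"] += 1
        if calls["n"] < 3:
            raise ConnectionError("transient failure")
        return "ok"

    print(with_reconnect(flaky, reconnect=lambda: None))  # ok (succeeds on the third attempt)
```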
