Parcourir la source

Revert "feat: improved MCP timeout" (#23602)

tags/1.7.2
crazywoola il y a 2 mois
Parent
révision
1c60b7f070
Aucun compte lié à l'adresse e-mail de l'auteur

+ 0
- 10
api/controllers/console/workspace/tool_providers.py Voir le fichier

@@ -862,10 +862,6 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
@@ -880,8 +876,6 @@ class ToolProviderMCPApi(Resource):
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
)
)

@@ -897,8 +891,6 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
@@ -914,8 +906,6 @@ class ToolProviderMCPApi(Resource):
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
)
return {"result": "success"}


+ 1
- 1
api/core/mcp/client/sse_client.py Voir le fichier

@@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message:
)
response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code)
except Exception:
except Exception as exc:
logger.exception("Error sending message")
raise


+ 14
- 12
api/core/mcp/client/streamable_client.py Voir le fichier

@@ -55,10 +55,14 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""

pass


class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""

pass


@dataclass
class RequestContext:
@@ -70,7 +74,7 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: float
sse_read_timeout: timedelta


class StreamableHTTPTransport:
@@ -80,8 +84,8 @@ class StreamableHTTPTransport:
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
) -> None:
"""Initialize the StreamableHTTP transport.

@@ -93,10 +97,8 @@ class StreamableHTTPTransport:
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
@@ -184,7 +186,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source:
@@ -213,7 +215,7 @@ class StreamableHTTPTransport:
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source:
@@ -400,8 +402,8 @@ class StreamableHTTPTransport:
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True,
) -> Generator[
tuple[
@@ -434,7 +436,7 @@ def streamablehttp_client(
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:

+ 10
- 18
api/core/mcp/mcp_client.py Voir le fichier

@@ -23,18 +23,12 @@ class MCPClient:
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout

# Authentication info
self.authed = authed
@@ -49,7 +43,7 @@ class MCPClient:
self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self._exit_stack = ExitStack()
self.exit_stack = ExitStack()

# Whether the client has been initialized
self._initialized = False
@@ -96,26 +90,21 @@ class MCPClient:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")

# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
streams = self.exit_stack.enter_context(self._streams_context)

self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session = self.exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
@@ -131,6 +120,9 @@ class MCPClient:
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)

except MCPConnectionError:
raise

def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
@@ -150,7 +142,7 @@ class MCPClient:
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self._exit_stack.close()
self.exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

+ 7
- 1
api/core/mcp/session/base_session.py Voir le fichier

@@ -2,6 +2,7 @@ import logging
import queue
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
@@ -169,6 +170,7 @@ class BaseSession(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None
@@ -375,7 +377,7 @@ class BaseSession(
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception:
except Exception as e:
logging.exception("Error in message processing loop")
raise

@@ -387,12 +389,14 @@ class BaseSession(
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
pass

def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
pass

def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
@@ -401,9 +405,11 @@ class BaseSession(
Sends a progress notification for a request that is currently being
processed.
"""
pass

def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

+ 2
- 3
api/core/mcp/session/client_session.py Voir le fichier

@@ -1,4 +1,3 @@
import queue
from datetime import timedelta
from typing import Any, Protocol

@@ -86,8 +85,8 @@ class ClientSession(
):
def __init__(
self,
read_stream: queue.Queue,
write_stream: queue.Queue,
read_stream,
write_stream,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,

+ 2
- 0
api/core/tools/__base/tool_provider.py Voir le fichier

@@ -12,6 +12,8 @@ from core.tools.errors import ToolProviderCredentialValidationError


class ToolProviderController(ABC):
entity: ToolProviderEntity

def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity


+ 6
- 24
api/core/tools/mcp_tool/provider.py Voir le fichier

@@ -1,5 +1,5 @@
import json
from typing import Any, Optional
from typing import Any

from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
@@ -19,24 +19,15 @@ from services.tools.tools_transform_service import ToolTransformService


class MCPToolProviderController(ToolProviderController):
def __init__(
self,
entity: ToolProviderEntityWithPlugin,
provider_id: str,
tenant_id: str,
server_url: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
provider_id: str
entity: ToolProviderEntityWithPlugin

def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
self.entity = entity
self.tenant_id = tenant_id
self.provider_id = provider_id
self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout

@property
def provider_type(self) -> ToolProviderType:
@@ -94,9 +85,6 @@ class MCPToolProviderController(ToolProviderController):
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
)

def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
@@ -123,9 +111,6 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)

def get_tools(self) -> list[MCPTool]: # type: ignore
@@ -140,9 +125,6 @@ class MCPToolProviderController(ToolProviderController):
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)
for tool_entity in self.entity.tools
]

+ 2
- 25
api/core/tools/mcp_tool/tool.py Voir le fichier

@@ -13,25 +13,13 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too

class MCPTool(Tool):
def __init__(
self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.server_url = server_url
self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout

def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
@@ -47,15 +35,7 @@ class MCPTool(Tool):
from core.tools.errors import ToolInvokeError

try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
@@ -92,9 +72,6 @@ class MCPTool(Tool):
icon=self.icon,
server_url=self.server_url,
provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
)

def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:

+ 3
- 0
api/core/tools/tool_manager.py Voir le fichier

@@ -789,6 +789,9 @@ class ToolManager:
"""
get api provider
"""
"""
get tool provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)

+ 0
- 47
api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py Voir le fichier

@@ -1,47 +0,0 @@
"""add timeout for tool_mcp_providers

Revision ID: fa8b0fa6f407
Revises: 532b3f888abf
Create Date: 2025-08-07 11:15:31.517985

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'fa8b0fa6f407'
down_revision = '532b3f888abf'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))

with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx'))

with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_run_created_at_idx'))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False)

with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False)

with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.drop_column('sse_read_timeout')
batch_op.drop_column('timeout')

# ### end Alembic commands ###

+ 0
- 2
api/models/tools.py Voir le fichier

@@ -278,8 +278,6 @@ class MCPToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))

def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()

+ 0
- 10
api/services/tools/mcp_tools_manage_service.py Voir le fichier

@@ -59,8 +59,6 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float,
sse_read_timeout: float,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
@@ -93,8 +91,6 @@ class MCPToolManageService:
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
)
db.session.add(mcp_tool)
db.session.commit()
@@ -170,8 +166,6 @@ class MCPToolManageService:
icon_type: str,
icon_background: str,
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)

@@ -203,10 +197,6 @@ class MCPToolManageService:
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]

if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
db.session.commit()
except IntegrityError as e:
db.session.rollback()

Chargement…
Annuler
Enregistrer