| @@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource): | |||
| parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") | |||
| parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) | |||
| parser.add_argument( | |||
| "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 | |||
| ) | |||
| args = parser.parse_args() | |||
| user = current_user | |||
| if not is_valid_url(args["server_url"]): | |||
| @@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource): | |||
| icon_background=args["icon_background"], | |||
| user_id=user.id, | |||
| server_identifier=args["server_identifier"], | |||
| timeout=args["timeout"], | |||
| sse_read_timeout=args["sse_read_timeout"], | |||
| ) | |||
| ) | |||
| @@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource): | |||
| parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") | |||
| parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| if not is_valid_url(args["server_url"]): | |||
| if "[__HIDDEN__]" in args["server_url"]: | |||
| @@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource): | |||
| icon_type=args["icon_type"], | |||
| icon_background=args["icon_background"], | |||
| server_identifier=args["server_identifier"], | |||
| timeout=args.get("timeout"), | |||
| sse_read_timeout=args.get("sse_read_timeout"), | |||
| ) | |||
| return {"result": "success"} | |||
| @@ -327,7 +327,7 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message: | |||
| ) | |||
| response.raise_for_status() | |||
| logger.debug("Client message sent successfully: %s", response.status_code) | |||
| except Exception as exc: | |||
| except Exception: | |||
| logger.exception("Error sending message") | |||
| raise | |||
| @@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 | |||
| class StreamableHTTPError(Exception): | |||
| """Base exception for StreamableHTTP transport errors.""" | |||
| pass | |||
| class ResumptionError(StreamableHTTPError): | |||
| """Raised when resumption request is invalid.""" | |||
| pass | |||
| @dataclass | |||
| class RequestContext: | |||
| @@ -74,7 +70,7 @@ class RequestContext: | |||
| session_message: SessionMessage | |||
| metadata: ClientMessageMetadata | None | |||
| server_to_client_queue: ServerToClientQueue # Renamed for clarity | |||
| sse_read_timeout: timedelta | |||
| sse_read_timeout: float | |||
| class StreamableHTTPTransport: | |||
| @@ -84,8 +80,8 @@ class StreamableHTTPTransport: | |||
| self, | |||
| url: str, | |||
| headers: dict[str, Any] | None = None, | |||
| timeout: timedelta = timedelta(seconds=30), | |||
| sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | |||
| timeout: float | timedelta = 30, | |||
| sse_read_timeout: float | timedelta = 60 * 5, | |||
| ) -> None: | |||
| """Initialize the StreamableHTTP transport. | |||
| @@ -97,8 +93,10 @@ class StreamableHTTPTransport: | |||
| """ | |||
| self.url = url | |||
| self.headers = headers or {} | |||
| self.timeout = timeout | |||
| self.sse_read_timeout = sse_read_timeout | |||
| self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout | |||
| self.sse_read_timeout = ( | |||
| sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout | |||
| ) | |||
| self.session_id: str | None = None | |||
| self.request_headers = { | |||
| ACCEPT: f"{JSON}, {SSE}", | |||
| @@ -186,7 +184,7 @@ class StreamableHTTPTransport: | |||
| with ssrf_proxy_sse_connect( | |||
| self.url, | |||
| headers=headers, | |||
| timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), | |||
| timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), | |||
| client=client, | |||
| method="GET", | |||
| ) as event_source: | |||
| @@ -215,7 +213,7 @@ class StreamableHTTPTransport: | |||
| with ssrf_proxy_sse_connect( | |||
| self.url, | |||
| headers=headers, | |||
| timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), | |||
| timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), | |||
| client=ctx.client, | |||
| method="GET", | |||
| ) as event_source: | |||
| @@ -402,8 +400,8 @@ class StreamableHTTPTransport: | |||
| def streamablehttp_client( | |||
| url: str, | |||
| headers: dict[str, Any] | None = None, | |||
| timeout: timedelta = timedelta(seconds=30), | |||
| sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | |||
| timeout: float | timedelta = 30, | |||
| sse_read_timeout: float | timedelta = 60 * 5, | |||
| terminate_on_close: bool = True, | |||
| ) -> Generator[ | |||
| tuple[ | |||
| @@ -436,7 +434,7 @@ def streamablehttp_client( | |||
| try: | |||
| with create_ssrf_proxy_mcp_http_client( | |||
| headers=transport.request_headers, | |||
| timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), | |||
| timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | |||
| ) as client: | |||
| # Define callbacks that need access to thread pool | |||
| def start_get_stream() -> None: | |||
| @@ -23,12 +23,18 @@ class MCPClient: | |||
| authed: bool = True, | |||
| authorization_code: Optional[str] = None, | |||
| for_list: bool = False, | |||
| headers: Optional[dict[str, str]] = None, | |||
| timeout: Optional[float] = None, | |||
| sse_read_timeout: Optional[float] = None, | |||
| ): | |||
| # Initialize info | |||
| self.provider_id = provider_id | |||
| self.tenant_id = tenant_id | |||
| self.client_type = "streamable" | |||
| self.server_url = server_url | |||
| self.headers = headers or {} | |||
| self.timeout = timeout | |||
| self.sse_read_timeout = sse_read_timeout | |||
| # Authentication info | |||
| self.authed = authed | |||
| @@ -43,7 +49,7 @@ class MCPClient: | |||
| self._session: Optional[ClientSession] = None | |||
| self._streams_context: Optional[AbstractContextManager[Any]] = None | |||
| self._session_context: Optional[ClientSession] = None | |||
| self.exit_stack = ExitStack() | |||
| self._exit_stack = ExitStack() | |||
| # Whether the client has been initialized | |||
| self._initialized = False | |||
| @@ -90,21 +96,26 @@ class MCPClient: | |||
| headers = ( | |||
| {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} | |||
| if self.authed and self.token | |||
| else {} | |||
| else self.headers | |||
| ) | |||
| self._streams_context = client_factory( | |||
| url=self.server_url, | |||
| headers=headers, | |||
| timeout=self.timeout, | |||
| sse_read_timeout=self.sse_read_timeout, | |||
| ) | |||
| self._streams_context = client_factory(url=self.server_url, headers=headers) | |||
| if not self._streams_context: | |||
| raise MCPConnectionError("Failed to create connection context") | |||
| # Use exit_stack to manage context managers properly | |||
| if method_name == "mcp": | |||
| read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) | |||
| read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context) | |||
| streams = (read_stream, write_stream) | |||
| else: # sse_client | |||
| streams = self.exit_stack.enter_context(self._streams_context) | |||
| streams = self._exit_stack.enter_context(self._streams_context) | |||
| self._session_context = ClientSession(*streams) | |||
| self._session = self.exit_stack.enter_context(self._session_context) | |||
| self._session = self._exit_stack.enter_context(self._session_context) | |||
| session = cast(ClientSession, self._session) | |||
| session.initialize() | |||
| return | |||
| @@ -120,9 +131,6 @@ class MCPClient: | |||
| if first_try: | |||
| return self.connect_server(client_factory, method_name, first_try=False) | |||
| except MCPConnectionError: | |||
| raise | |||
| def list_tools(self) -> list[Tool]: | |||
| """Connect to an MCP server running with SSE transport""" | |||
| # List available tools to verify connection | |||
| @@ -142,7 +150,7 @@ class MCPClient: | |||
| """Clean up resources""" | |||
| try: | |||
| # ExitStack will handle proper cleanup of all managed context managers | |||
| self.exit_stack.close() | |||
| self._exit_stack.close() | |||
| except Exception as e: | |||
| logging.exception("Error during cleanup") | |||
| raise ValueError(f"Error during cleanup: {e}") | |||
| @@ -2,7 +2,6 @@ import logging | |||
| import queue | |||
| from collections.abc import Callable | |||
| from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError | |||
| from contextlib import ExitStack | |||
| from datetime import timedelta | |||
| from types import TracebackType | |||
| from typing import Any, Generic, Self, TypeVar | |||
| @@ -170,7 +169,6 @@ class BaseSession( | |||
| self._receive_notification_type = receive_notification_type | |||
| self._session_read_timeout_seconds = read_timeout_seconds | |||
| self._in_flight = {} | |||
| self._exit_stack = ExitStack() | |||
| # Initialize executor and future to None for proper cleanup checks | |||
| self._executor: ThreadPoolExecutor | None = None | |||
| self._receiver_future: Future | None = None | |||
| @@ -377,7 +375,7 @@ class BaseSession( | |||
| self._handle_incoming(RuntimeError(f"Server Error: {message}")) | |||
| except queue.Empty: | |||
| continue | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("Error in message processing loop") | |||
| raise | |||
| @@ -389,14 +387,12 @@ class BaseSession( | |||
| If the request is responded to within this method, it will not be | |||
| forwarded on to the message stream. | |||
| """ | |||
| pass | |||
| def _received_notification(self, notification: ReceiveNotificationT) -> None: | |||
| """ | |||
| Can be overridden by subclasses to handle a notification without needing | |||
| to listen on the message stream. | |||
| """ | |||
| pass | |||
| def send_progress_notification( | |||
| self, progress_token: str | int, progress: float, total: float | None = None | |||
| @@ -405,11 +401,9 @@ class BaseSession( | |||
| Sends a progress notification for a request that is currently being | |||
| processed. | |||
| """ | |||
| pass | |||
| def _handle_incoming( | |||
| self, | |||
| req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | |||
| ) -> None: | |||
| """A generic handler for incoming messages. Overwritten by subclasses.""" | |||
| pass | |||
| @@ -1,3 +1,4 @@ | |||
| import queue | |||
| from datetime import timedelta | |||
| from typing import Any, Protocol | |||
| @@ -85,8 +86,8 @@ class ClientSession( | |||
| ): | |||
| def __init__( | |||
| self, | |||
| read_stream, | |||
| write_stream, | |||
| read_stream: queue.Queue, | |||
| write_stream: queue.Queue, | |||
| read_timeout_seconds: timedelta | None = None, | |||
| sampling_callback: SamplingFnT | None = None, | |||
| list_roots_callback: ListRootsFnT | None = None, | |||
| @@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError | |||
| class ToolProviderController(ABC): | |||
| entity: ToolProviderEntity | |||
| def __init__(self, entity: ToolProviderEntity) -> None: | |||
| self.entity = entity | |||
| @@ -1,5 +1,5 @@ | |||
| import json | |||
| from typing import Any | |||
| from typing import Any, Optional | |||
| from core.mcp.types import Tool as RemoteMCPTool | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| @@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService | |||
| class MCPToolProviderController(ToolProviderController): | |||
| provider_id: str | |||
| entity: ToolProviderEntityWithPlugin | |||
| def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None: | |||
| def __init__( | |||
| self, | |||
| entity: ToolProviderEntityWithPlugin, | |||
| provider_id: str, | |||
| tenant_id: str, | |||
| server_url: str, | |||
| headers: Optional[dict[str, str]] = None, | |||
| timeout: Optional[float] = None, | |||
| sse_read_timeout: Optional[float] = None, | |||
| ) -> None: | |||
| super().__init__(entity) | |||
| self.entity = entity | |||
| self.entity: ToolProviderEntityWithPlugin = entity | |||
| self.tenant_id = tenant_id | |||
| self.provider_id = provider_id | |||
| self.server_url = server_url | |||
| self.headers = headers or {} | |||
| self.timeout = timeout | |||
| self.sse_read_timeout = sse_read_timeout | |||
| @property | |||
| def provider_type(self) -> ToolProviderType: | |||
| @@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController): | |||
| provider_id=db_provider.server_identifier or "", | |||
| tenant_id=db_provider.tenant_id or "", | |||
| server_url=db_provider.decrypted_server_url, | |||
| headers={}, # TODO: get headers from db provider | |||
| timeout=db_provider.timeout, | |||
| sse_read_timeout=db_provider.sse_read_timeout, | |||
| ) | |||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | |||
| @@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController): | |||
| icon=self.entity.identity.icon, | |||
| server_url=self.server_url, | |||
| provider_id=self.provider_id, | |||
| headers=self.headers, | |||
| timeout=self.timeout, | |||
| sse_read_timeout=self.sse_read_timeout, | |||
| ) | |||
| def get_tools(self) -> list[MCPTool]: # type: ignore | |||
| @@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController): | |||
| icon=self.entity.identity.icon, | |||
| server_url=self.server_url, | |||
| provider_id=self.provider_id, | |||
| headers=self.headers, | |||
| timeout=self.timeout, | |||
| sse_read_timeout=self.sse_read_timeout, | |||
| ) | |||
| for tool_entity in self.entity.tools | |||
| ] | |||
| @@ -13,13 +13,25 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too | |||
| class MCPTool(Tool): | |||
| def __init__( | |||
| self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str | |||
| self, | |||
| entity: ToolEntity, | |||
| runtime: ToolRuntime, | |||
| tenant_id: str, | |||
| icon: str, | |||
| server_url: str, | |||
| provider_id: str, | |||
| headers: Optional[dict[str, str]] = None, | |||
| timeout: Optional[float] = None, | |||
| sse_read_timeout: Optional[float] = None, | |||
| ) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.tenant_id = tenant_id | |||
| self.icon = icon | |||
| self.server_url = server_url | |||
| self.provider_id = provider_id | |||
| self.headers = headers or {} | |||
| self.timeout = timeout | |||
| self.sse_read_timeout = sse_read_timeout | |||
| def tool_provider_type(self) -> ToolProviderType: | |||
| return ToolProviderType.MCP | |||
| @@ -35,7 +47,15 @@ class MCPTool(Tool): | |||
| from core.tools.errors import ToolInvokeError | |||
| try: | |||
| with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: | |||
| with MCPClient( | |||
| self.server_url, | |||
| self.provider_id, | |||
| self.tenant_id, | |||
| authed=True, | |||
| headers=self.headers, | |||
| timeout=self.timeout, | |||
| sse_read_timeout=self.sse_read_timeout, | |||
| ) as mcp_client: | |||
| tool_parameters = self._handle_none_parameter(tool_parameters) | |||
| result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) | |||
| except MCPAuthError as e: | |||
| @@ -72,6 +92,9 @@ class MCPTool(Tool): | |||
| icon=self.icon, | |||
| server_url=self.server_url, | |||
| provider_id=self.provider_id, | |||
| headers=self.headers, | |||
| timeout=self.timeout, | |||
| sse_read_timeout=self.sse_read_timeout, | |||
| ) | |||
| def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: | |||
| @@ -789,9 +789,6 @@ class ToolManager: | |||
| """ | |||
| get api provider | |||
| """ | |||
| """ | |||
| get tool provider | |||
| """ | |||
| provider_name = provider | |||
| provider_obj: ApiToolProvider | None = ( | |||
| db.session.query(ApiToolProvider) | |||
| @@ -0,0 +1,47 @@ | |||
| """add timeout for tool_mcp_providers | |||
| Revision ID: fa8b0fa6f407 | |||
| Revises: 532b3f888abf | |||
| Create Date: 2025-08-07 11:15:31.517985 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'fa8b0fa6f407' | |||
| down_revision = '532b3f888abf' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False)) | |||
| batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False)) | |||
| with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: | |||
| batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx')) | |||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||
| batch_op.drop_index(batch_op.f('workflow_run_created_at_idx')) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||
| batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False) | |||
| with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: | |||
| batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False) | |||
| with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: | |||
| batch_op.drop_column('sse_read_timeout') | |||
| batch_op.drop_column('timeout') | |||
| # ### end Alembic commands ### | |||
| @@ -278,6 +278,8 @@ class MCPToolProvider(Base): | |||
| updated_at: Mapped[datetime] = mapped_column( | |||
| sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") | |||
| ) | |||
| timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30")) | |||
| sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300")) | |||
| def load_user(self) -> Account | None: | |||
| return db.session.query(Account).where(Account.id == self.user_id).first() | |||
| @@ -59,6 +59,8 @@ class MCPToolManageService: | |||
| icon_type: str, | |||
| icon_background: str, | |||
| server_identifier: str, | |||
| timeout: float, | |||
| sse_read_timeout: float, | |||
| ) -> ToolProviderApiEntity: | |||
| server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() | |||
| existing_provider = ( | |||
| @@ -91,6 +93,8 @@ class MCPToolManageService: | |||
| tools="[]", | |||
| icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, | |||
| server_identifier=server_identifier, | |||
| timeout=timeout, | |||
| sse_read_timeout=sse_read_timeout, | |||
| ) | |||
| db.session.add(mcp_tool) | |||
| db.session.commit() | |||
| @@ -166,6 +170,8 @@ class MCPToolManageService: | |||
| icon_type: str, | |||
| icon_background: str, | |||
| server_identifier: str, | |||
| timeout: float | None = None, | |||
| sse_read_timeout: float | None = None, | |||
| ): | |||
| mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) | |||
| @@ -197,6 +203,10 @@ class MCPToolManageService: | |||
| mcp_provider.tools = reconnect_result["tools"] | |||
| mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] | |||
| if timeout is not None: | |||
| mcp_provider.timeout = timeout | |||
| if sse_read_timeout is not None: | |||
| mcp_provider.sse_read_timeout = sse_read_timeout | |||
| db.session.commit() | |||
| except IntegrityError as e: | |||
| db.session.rollback() | |||