Procházet zdrojové kódy

Revert "feat: improved MCP timeout" (#23602)

tags/1.7.2
crazywoola před 2 měsíci
rodič
revize
1c60b7f070
Žádný účet není propojen s e-mailovou adresou tvůrce revize

+ 0
- 10
api/controllers/console/workspace/tool_providers.py Zobrazit soubor

parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
parser.add_argument(
"sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300
)
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
icon_background=args["icon_background"], icon_background=args["icon_background"],
user_id=user.id, user_id=user.id,
server_identifier=args["server_identifier"], server_identifier=args["server_identifier"],
timeout=args["timeout"],
sse_read_timeout=args["sse_read_timeout"],
) )
) )


parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
parser.add_argument("timeout", type=float, required=False, nullable=True, location="json")
parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
if not is_valid_url(args["server_url"]): if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]: if "[__HIDDEN__]" in args["server_url"]:
icon_type=args["icon_type"], icon_type=args["icon_type"],
icon_background=args["icon_background"], icon_background=args["icon_background"],
server_identifier=args["server_identifier"], server_identifier=args["server_identifier"],
timeout=args.get("timeout"),
sse_read_timeout=args.get("sse_read_timeout"),
) )
return {"result": "success"} return {"result": "success"}



+ 1
- 1
api/core/mcp/client/sse_client.py Zobrazit soubor

) )
response.raise_for_status() response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code) logger.debug("Client message sent successfully: %s", response.status_code)
except Exception:
except Exception as exc:
logger.exception("Error sending message") logger.exception("Error sending message")
raise raise



+ 14
- 12
api/core/mcp/client/streamable_client.py Zobrazit soubor

class StreamableHTTPError(Exception): class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors.""" """Base exception for StreamableHTTP transport errors."""


pass



class ResumptionError(StreamableHTTPError): class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid.""" """Raised when resumption request is invalid."""


pass



@dataclass @dataclass
class RequestContext: class RequestContext:
session_message: SessionMessage session_message: SessionMessage
metadata: ClientMessageMetadata | None metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: float
sse_read_timeout: timedelta




class StreamableHTTPTransport: class StreamableHTTPTransport:
self, self,
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
) -> None: ) -> None:
"""Initialize the StreamableHTTP transport. """Initialize the StreamableHTTP transport.


""" """
self.url = url self.url = url
self.headers = headers or {} self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.session_id: str | None = None self.session_id: str | None = None
self.request_headers = { self.request_headers = {
ACCEPT: f"{JSON}, {SSE}", ACCEPT: f"{JSON}, {SSE}",
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client, client=client,
method="GET", method="GET",
) as event_source: ) as event_source:
with ssrf_proxy_sse_connect( with ssrf_proxy_sse_connect(
self.url, self.url,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client, client=ctx.client,
method="GET", method="GET",
) as event_source: ) as event_source:
def streamablehttp_client( def streamablehttp_client(
url: str, url: str,
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True, terminate_on_close: bool = True,
) -> Generator[ ) -> Generator[
tuple[ tuple[
try: try:
with create_ssrf_proxy_mcp_http_client( with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client: ) as client:
# Define callbacks that need access to thread pool # Define callbacks that need access to thread pool
def start_get_stream() -> None: def start_get_stream() -> None:

+ 10
- 18
api/core/mcp/mcp_client.py Zobrazit soubor

authed: bool = True, authed: bool = True,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
for_list: bool = False, for_list: bool = False,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
): ):
# Initialize info # Initialize info
self.provider_id = provider_id self.provider_id = provider_id
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.client_type = "streamable" self.client_type = "streamable"
self.server_url = server_url self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout


# Authentication info # Authentication info
self.authed = authed self.authed = authed
self._session: Optional[ClientSession] = None self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None self._session_context: Optional[ClientSession] = None
self._exit_stack = ExitStack()
self.exit_stack = ExitStack()


# Whether the client has been initialized # Whether the client has been initialized
self._initialized = False self._initialized = False
headers = ( headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token if self.authed and self.token
else self.headers
)
self._streams_context = client_factory(
url=self.server_url,
headers=headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
else {}
) )
self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context: if not self._streams_context:
raise MCPConnectionError("Failed to create connection context") raise MCPConnectionError("Failed to create connection context")


# Use exit_stack to manage context managers properly # Use exit_stack to manage context managers properly
if method_name == "mcp": if method_name == "mcp":
read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream) streams = (read_stream, write_stream)
else: # sse_client else: # sse_client
streams = self._exit_stack.enter_context(self._streams_context)
streams = self.exit_stack.enter_context(self._streams_context)


self._session_context = ClientSession(*streams) self._session_context = ClientSession(*streams)
self._session = self._exit_stack.enter_context(self._session_context)
self._session = self.exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session) session = cast(ClientSession, self._session)
session.initialize() session.initialize()
return return
if first_try: if first_try:
return self.connect_server(client_factory, method_name, first_try=False) return self.connect_server(client_factory, method_name, first_try=False)


except MCPConnectionError:
raise

def list_tools(self) -> list[Tool]: def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport""" """Connect to an MCP server running with SSE transport"""
# List available tools to verify connection # List available tools to verify connection
"""Clean up resources""" """Clean up resources"""
try: try:
# ExitStack will handle proper cleanup of all managed context managers # ExitStack will handle proper cleanup of all managed context managers
self._exit_stack.close()
self.exit_stack.close()
except Exception as e: except Exception as e:
logging.exception("Error during cleanup") logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}") raise ValueError(f"Error during cleanup: {e}")

+ 7
- 1
api/core/mcp/session/base_session.py Zobrazit soubor

import queue import queue
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta from datetime import timedelta
from types import TracebackType from types import TracebackType
from typing import Any, Generic, Self, TypeVar from typing import Any, Generic, Self, TypeVar
self._receive_notification_type = receive_notification_type self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack()
# Initialize executor and future to None for proper cleanup checks # Initialize executor and future to None for proper cleanup checks
self._executor: ThreadPoolExecutor | None = None self._executor: ThreadPoolExecutor | None = None
self._receiver_future: Future | None = None self._receiver_future: Future | None = None
self._handle_incoming(RuntimeError(f"Server Error: {message}")) self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty: except queue.Empty:
continue continue
except Exception:
except Exception as e:
logging.exception("Error in message processing loop") logging.exception("Error in message processing loop")
raise raise


If the request is responded to within this method, it will not be If the request is responded to within this method, it will not be
forwarded on to the message stream. forwarded on to the message stream.
""" """
pass


def _received_notification(self, notification: ReceiveNotificationT) -> None: def _received_notification(self, notification: ReceiveNotificationT) -> None:
""" """
Can be overridden by subclasses to handle a notification without needing Can be overridden by subclasses to handle a notification without needing
to listen on the message stream. to listen on the message stream.
""" """
pass


def send_progress_notification( def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None self, progress_token: str | int, progress: float, total: float | None = None
Sends a progress notification for a request that is currently being Sends a progress notification for a request that is currently being
processed. processed.
""" """
pass


def _handle_incoming( def _handle_incoming(
self, self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None: ) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses.""" """A generic handler for incoming messages. Overwritten by subclasses."""
pass

+ 2
- 3
api/core/mcp/session/client_session.py Zobrazit soubor

import queue
from datetime import timedelta from datetime import timedelta
from typing import Any, Protocol from typing import Any, Protocol


): ):
def __init__( def __init__(
self, self,
read_stream: queue.Queue,
write_stream: queue.Queue,
read_stream,
write_stream,
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None, sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None, list_roots_callback: ListRootsFnT | None = None,

+ 2
- 0
api/core/tools/__base/tool_provider.py Zobrazit soubor





class ToolProviderController(ABC): class ToolProviderController(ABC):
entity: ToolProviderEntity

def __init__(self, entity: ToolProviderEntity) -> None: def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity self.entity = entity



+ 6
- 24
api/core/tools/mcp_tool/provider.py Zobrazit soubor

import json import json
from typing import Any, Optional
from typing import Any


from core.mcp.types import Tool as RemoteMCPTool from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController




class MCPToolProviderController(ToolProviderController): class MCPToolProviderController(ToolProviderController):
def __init__(
self,
entity: ToolProviderEntityWithPlugin,
provider_id: str,
tenant_id: str,
server_url: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
provider_id: str
entity: ToolProviderEntityWithPlugin

def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
super().__init__(entity) super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
self.entity = entity
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.provider_id = provider_id self.provider_id = provider_id
self.server_url = server_url self.server_url = server_url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout


@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
provider_id=db_provider.server_identifier or "", provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "", tenant_id=db_provider.tenant_id or "",
server_url=db_provider.decrypted_server_url, server_url=db_provider.decrypted_server_url,
headers={}, # TODO: get headers from db provider
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
) )


def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
icon=self.entity.identity.icon, icon=self.entity.identity.icon,
server_url=self.server_url, server_url=self.server_url,
provider_id=self.provider_id, provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )


def get_tools(self) -> list[MCPTool]: # type: ignore def get_tools(self) -> list[MCPTool]: # type: ignore
icon=self.entity.identity.icon, icon=self.entity.identity.icon,
server_url=self.server_url, server_url=self.server_url,
provider_id=self.provider_id, provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )
for tool_entity in self.entity.tools for tool_entity in self.entity.tools
] ]

+ 2
- 25
api/core/tools/mcp_tool/tool.py Zobrazit soubor



class MCPTool(Tool): class MCPTool(Tool):
def __init__( def __init__(
self,
entity: ToolEntity,
runtime: ToolRuntime,
tenant_id: str,
icon: str,
server_url: str,
provider_id: str,
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
) -> None: ) -> None:
super().__init__(entity, runtime) super().__init__(entity, runtime)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.icon = icon self.icon = icon
self.server_url = server_url self.server_url = server_url
self.provider_id = provider_id self.provider_id = provider_id
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout


def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP return ToolProviderType.MCP
from core.tools.errors import ToolInvokeError from core.tools.errors import ToolInvokeError


try: try:
with MCPClient(
self.server_url,
self.provider_id,
self.tenant_id,
authed=True,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) as mcp_client:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = self._handle_none_parameter(tool_parameters) tool_parameters = self._handle_none_parameter(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e: except MCPAuthError as e:
icon=self.icon, icon=self.icon,
server_url=self.server_url, server_url=self.server_url,
provider_id=self.provider_id, provider_id=self.provider_id,
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
) )


def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:

+ 3
- 0
api/core/tools/tool_manager.py Zobrazit soubor

""" """
get api provider get api provider
""" """
"""
get tool provider
"""
provider_name = provider provider_name = provider
provider_obj: ApiToolProvider | None = ( provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)

+ 0
- 47
api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py Zobrazit soubor

"""add timeout for tool_mcp_providers

Revision ID: fa8b0fa6f407
Revises: 532b3f888abf
Create Date: 2025-08-07 11:15:31.517985

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'fa8b0fa6f407'
down_revision = '532b3f888abf'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False))
batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False))

with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_node_execution_created_at_idx'))

with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_run_created_at_idx'))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_run_created_at_idx'), ['created_at'], unique=False)

with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_node_execution_created_at_idx'), ['created_at'], unique=False)

with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
batch_op.drop_column('sse_read_timeout')
batch_op.drop_column('timeout')

# ### end Alembic commands ###

+ 0
- 2
api/models/tools.py Zobrazit soubor

updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
) )
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))


def load_user(self) -> Account | None: def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first() return db.session.query(Account).where(Account.id == self.user_id).first()

+ 0
- 10
api/services/tools/mcp_tools_manage_service.py Zobrazit soubor

icon_type: str, icon_type: str,
icon_background: str, icon_background: str,
server_identifier: str, server_identifier: str,
timeout: float,
sse_read_timeout: float,
) -> ToolProviderApiEntity: ) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = ( existing_provider = (
tools="[]", tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
server_identifier=server_identifier, server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) )
db.session.add(mcp_tool) db.session.add(mcp_tool)
db.session.commit() db.session.commit()
icon_type: str, icon_type: str,
icon_background: str, icon_background: str,
server_identifier: str, server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
): ):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)


mcp_provider.tools = reconnect_result["tools"] mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]


if timeout is not None:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
db.session.commit() db.session.commit()
except IntegrityError as e: except IntegrityError as e:
db.session.rollback() db.session.rollback()

Načítá se…
Zrušit
Uložit