Co-authored-by: crazywoola <427733928@qq.com>tags/1.8.0
| parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") | parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") | ||||
| parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") | parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") | ||||
| parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") | parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") | ||||
| parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) | |||||
| parser.add_argument( | |||||
| "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 | |||||
| ) | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| user = current_user | user = current_user | ||||
| if not is_valid_url(args["server_url"]): | if not is_valid_url(args["server_url"]): | ||||
| icon_background=args["icon_background"], | icon_background=args["icon_background"], | ||||
| user_id=user.id, | user_id=user.id, | ||||
| server_identifier=args["server_identifier"], | server_identifier=args["server_identifier"], | ||||
| timeout=args["timeout"], | |||||
| sse_read_timeout=args["sse_read_timeout"], | |||||
| ) | ) | ||||
| ) | ) | ||||
| parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") | parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") | ||||
| parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") | parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") | ||||
| parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") | parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") | ||||
| parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") | |||||
| parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not is_valid_url(args["server_url"]): | if not is_valid_url(args["server_url"]): | ||||
| if "[__HIDDEN__]" in args["server_url"]: | if "[__HIDDEN__]" in args["server_url"]: | ||||
| icon_type=args["icon_type"], | icon_type=args["icon_type"], | ||||
| icon_background=args["icon_background"], | icon_background=args["icon_background"], | ||||
| server_identifier=args["server_identifier"], | server_identifier=args["server_identifier"], | ||||
| timeout=args.get("timeout"), | |||||
| sse_read_timeout=args.get("sse_read_timeout"), | |||||
| ) | ) | ||||
| return {"result": "success"} | return {"result": "success"} | ||||
| ) | ) | ||||
| response.raise_for_status() | response.raise_for_status() | ||||
| logger.debug("Client message sent successfully: %s", response.status_code) | logger.debug("Client message sent successfully: %s", response.status_code) | ||||
| except Exception as exc: | |||||
| except Exception: | |||||
| logger.exception("Error sending message") | logger.exception("Error sending message") | ||||
| raise | raise | ||||
| class StreamableHTTPError(Exception): | class StreamableHTTPError(Exception): | ||||
| """Base exception for StreamableHTTP transport errors.""" | """Base exception for StreamableHTTP transport errors.""" | ||||
| pass | |||||
| class ResumptionError(StreamableHTTPError): | class ResumptionError(StreamableHTTPError): | ||||
| """Raised when resumption request is invalid.""" | """Raised when resumption request is invalid.""" | ||||
| pass | |||||
| @dataclass | @dataclass | ||||
| class RequestContext: | class RequestContext: | ||||
| session_message: SessionMessage | session_message: SessionMessage | ||||
| metadata: ClientMessageMetadata | None | metadata: ClientMessageMetadata | None | ||||
| server_to_client_queue: ServerToClientQueue # Renamed for clarity | server_to_client_queue: ServerToClientQueue # Renamed for clarity | ||||
| sse_read_timeout: timedelta | |||||
| sse_read_timeout: float | |||||
| class StreamableHTTPTransport: | class StreamableHTTPTransport: | ||||
| self, | self, | ||||
| url: str, | url: str, | ||||
| headers: dict[str, Any] | None = None, | headers: dict[str, Any] | None = None, | ||||
| timeout: timedelta = timedelta(seconds=30), | |||||
| sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | |||||
| timeout: float | timedelta = 30, | |||||
| sse_read_timeout: float | timedelta = 60 * 5, | |||||
| ) -> None: | ) -> None: | ||||
| """Initialize the StreamableHTTP transport. | """Initialize the StreamableHTTP transport. | ||||
| """ | """ | ||||
| self.url = url | self.url = url | ||||
| self.headers = headers or {} | self.headers = headers or {} | ||||
| self.timeout = timeout | |||||
| self.sse_read_timeout = sse_read_timeout | |||||
| self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout | |||||
| self.sse_read_timeout = ( | |||||
| sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout | |||||
| ) | |||||
| self.session_id: str | None = None | self.session_id: str | None = None | ||||
| self.request_headers = { | self.request_headers = { | ||||
| ACCEPT: f"{JSON}, {SSE}", | ACCEPT: f"{JSON}, {SSE}", | ||||
| with ssrf_proxy_sse_connect( | with ssrf_proxy_sse_connect( | ||||
| self.url, | self.url, | ||||
| headers=headers, | headers=headers, | ||||
| timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), | |||||
| timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), | |||||
| client=client, | client=client, | ||||
| method="GET", | method="GET", | ||||
| ) as event_source: | ) as event_source: | ||||
| with ssrf_proxy_sse_connect( | with ssrf_proxy_sse_connect( | ||||
| self.url, | self.url, | ||||
| headers=headers, | headers=headers, | ||||
| timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), | |||||
| timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), | |||||
| client=ctx.client, | client=ctx.client, | ||||
| method="GET", | method="GET", | ||||
| ) as event_source: | ) as event_source: | ||||
| def streamablehttp_client( | def streamablehttp_client( | ||||
| url: str, | url: str, | ||||
| headers: dict[str, Any] | None = None, | headers: dict[str, Any] | None = None, | ||||
| timeout: timedelta = timedelta(seconds=30), | |||||
| sse_read_timeout: timedelta = timedelta(seconds=60 * 5), | |||||
| timeout: float | timedelta = 30, | |||||
| sse_read_timeout: float | timedelta = 60 * 5, | |||||
| terminate_on_close: bool = True, | terminate_on_close: bool = True, | ||||
| ) -> Generator[ | ) -> Generator[ | ||||
| tuple[ | tuple[ | ||||
| try: | try: | ||||
| with create_ssrf_proxy_mcp_http_client( | with create_ssrf_proxy_mcp_http_client( | ||||
| headers=transport.request_headers, | headers=transport.request_headers, | ||||
| timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), | |||||
| timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), | |||||
| ) as client: | ) as client: | ||||
| # Define callbacks that need access to thread pool | # Define callbacks that need access to thread pool | ||||
| def start_get_stream() -> None: | def start_get_stream() -> None: |
| authed: bool = True, | authed: bool = True, | ||||
| authorization_code: Optional[str] = None, | authorization_code: Optional[str] = None, | ||||
| for_list: bool = False, | for_list: bool = False, | ||||
| headers: Optional[dict[str, str]] = None, | |||||
| timeout: Optional[float] = None, | |||||
| sse_read_timeout: Optional[float] = None, | |||||
| ): | ): | ||||
| # Initialize info | # Initialize info | ||||
| self.provider_id = provider_id | self.provider_id = provider_id | ||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.client_type = "streamable" | self.client_type = "streamable" | ||||
| self.server_url = server_url | self.server_url = server_url | ||||
| self.headers = headers or {} | |||||
| self.timeout = timeout | |||||
| self.sse_read_timeout = sse_read_timeout | |||||
| # Authentication info | # Authentication info | ||||
| self.authed = authed | self.authed = authed | ||||
| self._session: Optional[ClientSession] = None | self._session: Optional[ClientSession] = None | ||||
| self._streams_context: Optional[AbstractContextManager[Any]] = None | self._streams_context: Optional[AbstractContextManager[Any]] = None | ||||
| self._session_context: Optional[ClientSession] = None | self._session_context: Optional[ClientSession] = None | ||||
| self.exit_stack = ExitStack() | |||||
| self._exit_stack = ExitStack() | |||||
| # Whether the client has been initialized | # Whether the client has been initialized | ||||
| self._initialized = False | self._initialized = False | ||||
| headers = ( | headers = ( | ||||
| {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} | {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} | ||||
| if self.authed and self.token | if self.authed and self.token | ||||
| else {} | |||||
| else self.headers | |||||
| ) | |||||
| self._streams_context = client_factory( | |||||
| url=self.server_url, | |||||
| headers=headers, | |||||
| timeout=self.timeout, | |||||
| sse_read_timeout=self.sse_read_timeout, | |||||
| ) | ) | ||||
| self._streams_context = client_factory(url=self.server_url, headers=headers) | |||||
| if not self._streams_context: | if not self._streams_context: | ||||
| raise MCPConnectionError("Failed to create connection context") | raise MCPConnectionError("Failed to create connection context") | ||||
| # Use exit_stack to manage context managers properly | # Use exit_stack to manage context managers properly | ||||
| if method_name == "mcp": | if method_name == "mcp": | ||||
| read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) | |||||
| read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context) | |||||
| streams = (read_stream, write_stream) | streams = (read_stream, write_stream) | ||||
| else: # sse_client | else: # sse_client | ||||
| streams = self.exit_stack.enter_context(self._streams_context) | |||||
| streams = self._exit_stack.enter_context(self._streams_context) | |||||
| self._session_context = ClientSession(*streams) | self._session_context = ClientSession(*streams) | ||||
| self._session = self.exit_stack.enter_context(self._session_context) | |||||
| self._session = self._exit_stack.enter_context(self._session_context) | |||||
| session = cast(ClientSession, self._session) | session = cast(ClientSession, self._session) | ||||
| session.initialize() | session.initialize() | ||||
| return | return | ||||
| if first_try: | if first_try: | ||||
| return self.connect_server(client_factory, method_name, first_try=False) | return self.connect_server(client_factory, method_name, first_try=False) | ||||
| except MCPConnectionError: | |||||
| raise | |||||
| def list_tools(self) -> list[Tool]: | def list_tools(self) -> list[Tool]: | ||||
| """Connect to an MCP server running with SSE transport""" | """Connect to an MCP server running with SSE transport""" | ||||
| # List available tools to verify connection | # List available tools to verify connection | ||||
| """Clean up resources""" | """Clean up resources""" | ||||
| try: | try: | ||||
| # ExitStack will handle proper cleanup of all managed context managers | # ExitStack will handle proper cleanup of all managed context managers | ||||
| self.exit_stack.close() | |||||
| self._exit_stack.close() | |||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception("Error during cleanup") | logging.exception("Error during cleanup") | ||||
| raise ValueError(f"Error during cleanup: {e}") | raise ValueError(f"Error during cleanup: {e}") |
| import queue | import queue | ||||
| from collections.abc import Callable | from collections.abc import Callable | ||||
| from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError | from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError | ||||
| from contextlib import ExitStack | |||||
| from datetime import timedelta | from datetime import timedelta | ||||
| from types import TracebackType | from types import TracebackType | ||||
| from typing import Any, Generic, Self, TypeVar | from typing import Any, Generic, Self, TypeVar | ||||
| self._receive_notification_type = receive_notification_type | self._receive_notification_type = receive_notification_type | ||||
| self._session_read_timeout_seconds = read_timeout_seconds | self._session_read_timeout_seconds = read_timeout_seconds | ||||
| self._in_flight = {} | self._in_flight = {} | ||||
| self._exit_stack = ExitStack() | |||||
| # Initialize executor and future to None for proper cleanup checks | # Initialize executor and future to None for proper cleanup checks | ||||
| self._executor: ThreadPoolExecutor | None = None | self._executor: ThreadPoolExecutor | None = None | ||||
| self._receiver_future: Future | None = None | self._receiver_future: Future | None = None | ||||
| self._handle_incoming(RuntimeError(f"Server Error: {message}")) | self._handle_incoming(RuntimeError(f"Server Error: {message}")) | ||||
| except queue.Empty: | except queue.Empty: | ||||
| continue | continue | ||||
| except Exception as e: | |||||
| except Exception: | |||||
| logging.exception("Error in message processing loop") | logging.exception("Error in message processing loop") | ||||
| raise | raise | ||||
| If the request is responded to within this method, it will not be | If the request is responded to within this method, it will not be | ||||
| forwarded on to the message stream. | forwarded on to the message stream. | ||||
| """ | """ | ||||
| pass | |||||
| def _received_notification(self, notification: ReceiveNotificationT) -> None: | def _received_notification(self, notification: ReceiveNotificationT) -> None: | ||||
| """ | """ | ||||
| Can be overridden by subclasses to handle a notification without needing | Can be overridden by subclasses to handle a notification without needing | ||||
| to listen on the message stream. | to listen on the message stream. | ||||
| """ | """ | ||||
| pass | |||||
| def send_progress_notification( | def send_progress_notification( | ||||
| self, progress_token: str | int, progress: float, total: float | None = None | self, progress_token: str | int, progress: float, total: float | None = None | ||||
| Sends a progress notification for a request that is currently being | Sends a progress notification for a request that is currently being | ||||
| processed. | processed. | ||||
| """ | """ | ||||
| pass | |||||
| def _handle_incoming( | def _handle_incoming( | ||||
| self, | self, | ||||
| req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, | ||||
| ) -> None: | ) -> None: | ||||
| """A generic handler for incoming messages. Overwritten by subclasses.""" | """A generic handler for incoming messages. Overwritten by subclasses.""" | ||||
| pass |
| import queue | |||||
| from datetime import timedelta | from datetime import timedelta | ||||
| from typing import Any, Protocol | from typing import Any, Protocol | ||||
| ): | ): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| read_stream, | |||||
| write_stream, | |||||
| read_stream: queue.Queue, | |||||
| write_stream: queue.Queue, | |||||
| read_timeout_seconds: timedelta | None = None, | read_timeout_seconds: timedelta | None = None, | ||||
| sampling_callback: SamplingFnT | None = None, | sampling_callback: SamplingFnT | None = None, | ||||
| list_roots_callback: ListRootsFnT | None = None, | list_roots_callback: ListRootsFnT | None = None, |
| class ToolProviderController(ABC): | class ToolProviderController(ABC): | ||||
| entity: ToolProviderEntity | |||||
| def __init__(self, entity: ToolProviderEntity) -> None: | def __init__(self, entity: ToolProviderEntity) -> None: | ||||
| self.entity = entity | self.entity = entity | ||||
| import json | import json | ||||
| from typing import Any | |||||
| from typing import Any, Optional | |||||
| from core.mcp.types import Tool as RemoteMCPTool | from core.mcp.types import Tool as RemoteMCPTool | ||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| class MCPToolProviderController(ToolProviderController): | class MCPToolProviderController(ToolProviderController): | ||||
| provider_id: str | |||||
| entity: ToolProviderEntityWithPlugin | |||||
| def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None: | |||||
| def __init__( | |||||
| self, | |||||
| entity: ToolProviderEntityWithPlugin, | |||||
| provider_id: str, | |||||
| tenant_id: str, | |||||
| server_url: str, | |||||
| headers: Optional[dict[str, str]] = None, | |||||
| timeout: Optional[float] = None, | |||||
| sse_read_timeout: Optional[float] = None, | |||||
| ) -> None: | |||||
| super().__init__(entity) | super().__init__(entity) | ||||
| self.entity = entity | |||||
| self.entity: ToolProviderEntityWithPlugin = entity | |||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.provider_id = provider_id | self.provider_id = provider_id | ||||
| self.server_url = server_url | self.server_url = server_url | ||||
| self.headers = headers or {} | |||||
| self.timeout = timeout | |||||
| self.sse_read_timeout = sse_read_timeout | |||||
| @property | @property | ||||
| def provider_type(self) -> ToolProviderType: | def provider_type(self) -> ToolProviderType: | ||||
| provider_id=db_provider.server_identifier or "", | provider_id=db_provider.server_identifier or "", | ||||
| tenant_id=db_provider.tenant_id or "", | tenant_id=db_provider.tenant_id or "", | ||||
| server_url=db_provider.decrypted_server_url, | server_url=db_provider.decrypted_server_url, | ||||
| headers={}, # TODO: get headers from db provider | |||||
| timeout=db_provider.timeout, | |||||
| sse_read_timeout=db_provider.sse_read_timeout, | |||||
| ) | ) | ||||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | ||||
| icon=self.entity.identity.icon, | icon=self.entity.identity.icon, | ||||
| server_url=self.server_url, | server_url=self.server_url, | ||||
| provider_id=self.provider_id, | provider_id=self.provider_id, | ||||
| headers=self.headers, | |||||
| timeout=self.timeout, | |||||
| sse_read_timeout=self.sse_read_timeout, | |||||
| ) | ) | ||||
| def get_tools(self) -> list[MCPTool]: # type: ignore | def get_tools(self) -> list[MCPTool]: # type: ignore | ||||
| icon=self.entity.identity.icon, | icon=self.entity.identity.icon, | ||||
| server_url=self.server_url, | server_url=self.server_url, | ||||
| provider_id=self.provider_id, | provider_id=self.provider_id, | ||||
| headers=self.headers, | |||||
| timeout=self.timeout, | |||||
| sse_read_timeout=self.sse_read_timeout, | |||||
| ) | ) | ||||
| for tool_entity in self.entity.tools | for tool_entity in self.entity.tools | ||||
| ] | ] |
| class MCPTool(Tool): | class MCPTool(Tool): | ||||
| def __init__( | def __init__( | ||||
| self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str | |||||
| self, | |||||
| entity: ToolEntity, | |||||
| runtime: ToolRuntime, | |||||
| tenant_id: str, | |||||
| icon: str, | |||||
| server_url: str, | |||||
| provider_id: str, | |||||
| headers: Optional[dict[str, str]] = None, | |||||
| timeout: Optional[float] = None, | |||||
| sse_read_timeout: Optional[float] = None, | |||||
| ) -> None: | ) -> None: | ||||
| super().__init__(entity, runtime) | super().__init__(entity, runtime) | ||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| self.icon = icon | self.icon = icon | ||||
| self.server_url = server_url | self.server_url = server_url | ||||
| self.provider_id = provider_id | self.provider_id = provider_id | ||||
| self.headers = headers or {} | |||||
| self.timeout = timeout | |||||
| self.sse_read_timeout = sse_read_timeout | |||||
| def tool_provider_type(self) -> ToolProviderType: | def tool_provider_type(self) -> ToolProviderType: | ||||
| return ToolProviderType.MCP | return ToolProviderType.MCP | ||||
| from core.tools.errors import ToolInvokeError | from core.tools.errors import ToolInvokeError | ||||
| try: | try: | ||||
| with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: | |||||
| with MCPClient( | |||||
| self.server_url, | |||||
| self.provider_id, | |||||
| self.tenant_id, | |||||
| authed=True, | |||||
| headers=self.headers, | |||||
| timeout=self.timeout, | |||||
| sse_read_timeout=self.sse_read_timeout, | |||||
| ) as mcp_client: | |||||
| tool_parameters = self._handle_none_parameter(tool_parameters) | tool_parameters = self._handle_none_parameter(tool_parameters) | ||||
| result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) | result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) | ||||
| except MCPAuthError as e: | except MCPAuthError as e: | ||||
| icon=self.icon, | icon=self.icon, | ||||
| server_url=self.server_url, | server_url=self.server_url, | ||||
| provider_id=self.provider_id, | provider_id=self.provider_id, | ||||
| headers=self.headers, | |||||
| timeout=self.timeout, | |||||
| sse_read_timeout=self.sse_read_timeout, | |||||
| ) | ) | ||||
| def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: | def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: |
| """ | """ | ||||
| get api provider | get api provider | ||||
| """ | """ | ||||
| """ | |||||
| get tool provider | |||||
| """ | |||||
| provider_name = provider | provider_name = provider | ||||
| provider_obj: ApiToolProvider | None = ( | provider_obj: ApiToolProvider | None = ( | ||||
| db.session.query(ApiToolProvider) | db.session.query(ApiToolProvider) |
| """add timeout for tool_mcp_providers | |||||
| Revision ID: fa8b0fa6f407 | |||||
| Revises: 532b3f888abf | |||||
| Create Date: 2025-08-07 11:15:31.517985 | |||||
| """ | |||||
| from alembic import op | |||||
| import sqlalchemy as sa | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'fa8b0fa6f407' | |||||
| down_revision = '532b3f888abf' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False)) | |||||
| batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False)) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: | |||||
| batch_op.drop_column('sse_read_timeout') | |||||
| batch_op.drop_column('timeout') | |||||
| # ### end Alembic commands ### |
| updated_at: Mapped[datetime] = mapped_column( | updated_at: Mapped[datetime] = mapped_column( | ||||
| sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") | sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") | ||||
| ) | ) | ||||
| timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30")) | |||||
| sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300")) | |||||
| def load_user(self) -> Account | None: | def load_user(self) -> Account | None: | ||||
| return db.session.query(Account).where(Account.id == self.user_id).first() | return db.session.query(Account).where(Account.id == self.user_id).first() |
| icon_type: str, | icon_type: str, | ||||
| icon_background: str, | icon_background: str, | ||||
| server_identifier: str, | server_identifier: str, | ||||
| timeout: float, | |||||
| sse_read_timeout: float, | |||||
| ) -> ToolProviderApiEntity: | ) -> ToolProviderApiEntity: | ||||
| server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() | server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() | ||||
| existing_provider = ( | existing_provider = ( | ||||
| tools="[]", | tools="[]", | ||||
| icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, | icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, | ||||
| server_identifier=server_identifier, | server_identifier=server_identifier, | ||||
| timeout=timeout, | |||||
| sse_read_timeout=sse_read_timeout, | |||||
| ) | ) | ||||
| db.session.add(mcp_tool) | db.session.add(mcp_tool) | ||||
| db.session.commit() | db.session.commit() | ||||
| icon_type: str, | icon_type: str, | ||||
| icon_background: str, | icon_background: str, | ||||
| server_identifier: str, | server_identifier: str, | ||||
| timeout: float | None = None, | |||||
| sse_read_timeout: float | None = None, | |||||
| ): | ): | ||||
| mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) | mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) | ||||
| mcp_provider.tools = reconnect_result["tools"] | mcp_provider.tools = reconnect_result["tools"] | ||||
| mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] | mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] | ||||
| if timeout is not None: | |||||
| mcp_provider.timeout = timeout | |||||
| if sse_read_timeout is not None: | |||||
| mcp_provider.sse_read_timeout = sse_read_timeout | |||||
| db.session.commit() | db.session.commit() | ||||
| except IntegrityError as e: | except IntegrityError as e: | ||||
| db.session.rollback() | db.session.rollback() |
| icon: string | icon: string | ||||
| icon_background?: string | null | icon_background?: string | null | ||||
| server_identifier: string | server_identifier: string | ||||
| timeout: number | |||||
| sse_read_timeout: number | |||||
| }) => void | }) => void | ||||
| onHide: () => void | onHide: () => void | ||||
| } | } | ||||
| const [appIcon, setAppIcon] = useState<AppIconSelection>(getIcon(data)) | const [appIcon, setAppIcon] = useState<AppIconSelection>(getIcon(data)) | ||||
| const [showAppIconPicker, setShowAppIconPicker] = useState(false) | const [showAppIconPicker, setShowAppIconPicker] = useState(false) | ||||
| const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '') | const [serverIdentifier, setServerIdentifier] = React.useState(data?.server_identifier || '') | ||||
| const [timeout, setMcpTimeout] = React.useState(30) | |||||
| const [sseReadTimeout, setSseReadTimeout] = React.useState(300) | |||||
| const [isFetchingIcon, setIsFetchingIcon] = useState(false) | const [isFetchingIcon, setIsFetchingIcon] = useState(false) | ||||
| const appIconRef = useRef<HTMLDivElement>(null) | const appIconRef = useRef<HTMLDivElement>(null) | ||||
| const isHovering = useHover(appIconRef) | const isHovering = useHover(appIconRef) | ||||
| const urlPattern = /^(https?:\/\/)((([a-z\d]([a-z\d-]*[a-z\d])*)\.)+[a-z]{2,}|((\d{1,3}\.){3}\d{1,3})|localhost)(\:\d+)?(\/[-a-z\d%_.~+]*)*(\?[;&a-z\d%_.~+=-]*)?/i | const urlPattern = /^(https?:\/\/)((([a-z\d]([a-z\d-]*[a-z\d])*)\.)+[a-z]{2,}|((\d{1,3}\.){3}\d{1,3})|localhost)(\:\d+)?(\/[-a-z\d%_.~+]*)*(\?[;&a-z\d%_.~+=-]*)?/i | ||||
| return urlPattern.test(string) | return urlPattern.test(string) | ||||
| } | } | ||||
| catch (e) { | |||||
| catch { | |||||
| return false | return false | ||||
| } | } | ||||
| } | } | ||||
| icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, | icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, | ||||
| icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, | icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, | ||||
| server_identifier: serverIdentifier.trim(), | server_identifier: serverIdentifier.trim(), | ||||
| timeout: timeout || 30, | |||||
| sse_read_timeout: sseReadTimeout || 300, | |||||
| }) | }) | ||||
| if(isCreate) | if(isCreate) | ||||
| onHide() | onHide() | ||||
| </div> | </div> | ||||
| )} | )} | ||||
| </div> | </div> | ||||
| <div> | |||||
| <div className='mb-1 flex h-6 items-center'> | |||||
| <span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.timeout')}</span> | |||||
| </div> | |||||
| <Input | |||||
| type='number' | |||||
| value={timeout} | |||||
| onChange={e => setMcpTimeout(Number(e.target.value))} | |||||
| onBlur={e => handleBlur(e.target.value.trim())} | |||||
| placeholder={t('tools.mcp.modal.timeoutPlaceholder')} | |||||
| /> | |||||
| </div> | |||||
| <div> | |||||
| <div className='mb-1 flex h-6 items-center'> | |||||
| <span className='system-sm-medium text-text-secondary'>{t('tools.mcp.modal.sseReadTimeout')}</span> | |||||
| </div> | |||||
| <Input | |||||
| type='number' | |||||
| value={sseReadTimeout} | |||||
| onChange={e => setSseReadTimeout(Number(e.target.value))} | |||||
| onBlur={e => handleBlur(e.target.value.trim())} | |||||
| placeholder={t('tools.mcp.modal.timeoutPlaceholder')} | |||||
| /> | |||||
| </div> | |||||
| </div> | </div> | ||||
| <div className='flex flex-row-reverse pt-5'> | <div className='flex flex-row-reverse pt-5'> | ||||
| <Button disabled={!name || !url || !serverIdentifier || isFetchingIcon} className='ml-2' variant='primary' onClick={submit}>{data ? t('tools.mcp.modal.save') : t('tools.mcp.modal.confirm')}</Button> | <Button disabled={!name || !url || !serverIdentifier || isFetchingIcon} className='ml-2' variant='primary' onClick={submit}>{data ? t('tools.mcp.modal.save') : t('tools.mcp.modal.confirm')}</Button> |
| server_url?: string | server_url?: string | ||||
| updated_at?: number | updated_at?: number | ||||
| server_identifier?: string | server_identifier?: string | ||||
| timeout?: number | |||||
| sse_read_timeout?: number | |||||
| } | } | ||||
| export type ToolParameter = { | export type ToolParameter = { |
| cancel: 'Cancel', | cancel: 'Cancel', | ||||
| save: 'Save', | save: 'Save', | ||||
| confirm: 'Add & Authorize', | confirm: 'Add & Authorize', | ||||
| timeout: 'Timeout', | |||||
| sseReadTimeout: 'SSE Read Timeout', | |||||
| }, | }, | ||||
| delete: 'Remove MCP Server', | delete: 'Remove MCP Server', | ||||
| deleteConfirmTitle: 'Would you like to remove {{mcp}}?', | deleteConfirmTitle: 'Would you like to remove {{mcp}}?', |
| cancel: '取消', | cancel: '取消', | ||||
| save: '保存', | save: '保存', | ||||
| confirm: '添加并授权', | confirm: '添加并授权', | ||||
| timeout: '超时时间', | |||||
| sseReadTimeout: 'SSE 读取超时时间', | |||||
| }, | }, | ||||
| delete: '删除 MCP 服务', | delete: '删除 MCP 服务', | ||||
| deleteConfirmTitle: '你想要删除 {{mcp}} 吗?', | deleteConfirmTitle: '你想要删除 {{mcp}} 吗?', |
| icon_type: AppIconType | icon_type: AppIconType | ||||
| icon: string | icon: string | ||||
| icon_background?: string | null | icon_background?: string | null | ||||
| timeout?: number | |||||
| sse_read_timeout?: number | |||||
| }) => { | }) => { | ||||
| return post<ToolWithProvider>('workspaces/current/tool-provider/mcp', { | return post<ToolWithProvider>('workspaces/current/tool-provider/mcp', { | ||||
| body: { | body: { | ||||
| icon: string | icon: string | ||||
| icon_background?: string | null | icon_background?: string | null | ||||
| provider_id: string | provider_id: string | ||||
| timeout?: number | |||||
| sse_read_timeout?: number | |||||
| }) => { | }) => { | ||||
| return put('workspaces/current/tool-provider/mcp', { | return put('workspaces/current/tool-provider/mcp', { | ||||
| body: { | body: { |