| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 | import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class MCPClient:
    """Synchronous MCP client that talks to a server over streamable HTTP or SSE.

    Intended to be used as a context manager: ``__enter__`` connects and
    initializes a ``ClientSession``; ``__exit__`` closes everything that was
    opened for the successful connection.
    """

    def __init__(
        self,
        server_url: str,
        provider_id: str,
        tenant_id: str,
        authed: bool = True,
        authorization_code: Optional[str] = None,
        for_list: bool = False,
    ):
        """Store connection settings and, when ``authed`` is true, load OAuth tokens.

        Args:
            server_url: MCP server URL. Its last path segment ("mcp" or "sse")
                selects the transport; any other value makes the client try SSE
                first and fall back to streamable HTTP.
            provider_id: Passed to ``OAuthClientProvider`` to look up credentials.
            tenant_id: Passed to ``OAuthClientProvider`` to look up credentials.
            authed: When true, requests carry an ``Authorization`` header built
                from the provider's tokens, and an ``MCPAuthError`` triggers one
                re-authentication attempt.
            authorization_code: Forwarded to ``auth()`` when re-authenticating.
            for_list: Forwarded to ``OAuthClientProvider``.
        """
        # Initialize info
        self.provider_id = provider_id
        self.tenant_id = tenant_id
        self.client_type = "streamable"
        self.server_url = server_url

        # Authentication info
        self.authed = authed
        self.authorization_code = authorization_code
        if authed:
            # Imported lazily (presumably to avoid an import cycle / unneeded import when auth is off).
            from core.mcp.auth.auth_provider import OAuthClientProvider

            self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
            self.token = self.provider.tokens()

        # Initialize session and client objects
        self._session: Optional[ClientSession] = None
        self._streams_context: Optional[AbstractContextManager[Any]] = None
        self._session_context: Optional[ClientSession] = None
        # Owns the transport and session contexts of the successful connection only;
        # failed attempts are torn down in _open_session before they get here.
        self.exit_stack = ExitStack()

        # Whether the client has been initialized
        self._initialized = False

    def __enter__(self):
        """Connect to the server and return this client."""
        self._initialize()
        self._initialized = True
        return self

    def __exit__(
        self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
    ):
        """Close the connection; exceptions from the ``with`` body are not suppressed."""
        self.cleanup()

    def _initialize(
        self,
    ):
        """Connect using the transport named by the URL's last path segment.

        ``.../mcp`` uses streamable HTTP and ``.../sse`` uses SSE. For any other
        path, SSE is tried first and streamable HTTP is used if SSE raises
        ``MCPConnectionError``.
        """
        connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
            "mcp": streamablehttp_client,
            "sse": sse_client,
        }

        # e.g. "https://host/v1/mcp/" -> "mcp"; an empty path yields "".
        method_name = urlparse(self.server_url).path.rstrip("/").split("/")[-1]

        client_factory = connection_methods.get(method_name)
        if client_factory is not None:
            self.connect_server(client_factory, method_name)
            return

        logger.debug("Not supported method %s found in URL path, trying default 'sse' method.", method_name)
        try:
            self.connect_server(sse_client, "sse")
        except MCPConnectionError:
            logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
            self.connect_server(streamablehttp_client, "mcp")

    def connect_server(
        self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
    ):
        """Open an initialized session via ``client_factory``, re-authenticating once on auth failure.

        Args:
            client_factory: Transport context-manager factory, called as
                ``client_factory(url=..., headers=...)``.
            method_name: ``"mcp"`` when the factory's context yields
                ``(read, write, _)``; any other value means it yields the
                streams to pass to ``ClientSession`` directly.
            first_try: ``False`` only on the single retry after re-authentication.

        Raises:
            MCPAuthError: The server rejected the credentials and either the
                client is not ``authed`` or the retry after re-authenticating
                failed as well.
            MCPConnectionError: Propagated from connection setup (e.g. the
                factory returned no context).
            ValueError: Re-authentication itself failed.
        """
        from core.mcp.auth.auth_flow import auth

        try:
            self._open_session(client_factory, method_name)
        except MCPAuthError:
            # Re-authenticate at most once. If the retry is rejected too, surface the
            # auth error rather than returning as if connected without a usable session.
            if not self.authed or not first_try:
                raise
            try:
                auth(self.provider, self.server_url, self.authorization_code)
            except Exception as e:
                raise ValueError(f"Failed to authenticate: {e}") from e
            self.token = self.provider.tokens()
            self.connect_server(client_factory, method_name, first_try=False)

    def _open_session(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
        """Make one connection attempt; on success hand its contexts to ``self.exit_stack``.

        The attempt's contexts are entered on a private ``ExitStack`` so that a
        failed attempt is closed immediately instead of lingering (and possibly
        failing again) until ``cleanup()``.
        """
        headers = (
            {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
            if self.authed and self.token
            else {}
        )
        self._streams_context = client_factory(url=self.server_url, headers=headers)
        if not self._streams_context:
            raise MCPConnectionError("Failed to create connection context")

        attempt_stack = ExitStack()
        try:
            if method_name == "mcp":
                # The streamable HTTP context yields a 3-tuple; only the two streams are used.
                read_stream, write_stream, _ = attempt_stack.enter_context(self._streams_context)
                streams = (read_stream, write_stream)
            else:  # sse_client
                streams = attempt_stack.enter_context(self._streams_context)

            self._session_context = ClientSession(*streams)
            session = cast(ClientSession, attempt_stack.enter_context(self._session_context))
            session.initialize()
        except BaseException:
            # A teardown error here must not replace the original exception: callers
            # rely on its type (MCPAuthError / MCPConnectionError) for retry and fallback.
            try:
                attempt_stack.close()
            except Exception:
                logger.warning("Error while closing a failed MCP connection attempt", exc_info=True)
            raise

        # Success: keep the transport and session open until cleanup() closes exit_stack.
        self.exit_stack.enter_context(attempt_stack.pop_all())
        self._session = session

    def list_tools(self) -> list[Tool]:
        """Return the tools advertised by the connected MCP server.

        Raises:
            ValueError: The client has not been entered / initialized.
        """
        if not self._initialized or not self._session:
            raise ValueError("Session not initialized.")
        response = self._session.list_tools()
        return response.tools

    def invoke_tool(self, tool_name: str, tool_args: dict):
        """Call ``tool_name`` with ``tool_args`` and return the session's result.

        Raises:
            ValueError: The client has not been entered / initialized.
        """
        if not self._initialized or not self._session:
            raise ValueError("Session not initialized.")
        return self._session.call_tool(tool_name, tool_args)

    def cleanup(self):
        """Close every context owned by ``exit_stack`` and reset session state.

        Raises:
            ValueError: Closing a context failed (the error is logged first).
                Session state is reset either way.
        """
        try:
            # ExitStack unwinds the managed contexts in reverse order of entry.
            self.exit_stack.close()
        except Exception as e:
            logger.exception("Error during cleanup")
            raise ValueError(f"Error during cleanup: {e}") from e
        finally:
            self._session = None
            self._session_context = None
            self._streams_context = None
            self._initialized = False
 |