You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mcp_client.py 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import logging
  2. from collections.abc import Callable
  3. from contextlib import AbstractContextManager, ExitStack
  4. from types import TracebackType
  5. from typing import Any, Optional, cast
  6. from urllib.parse import urlparse
  7. from core.mcp.client.sse_client import sse_client
  8. from core.mcp.client.streamable_client import streamablehttp_client
  9. from core.mcp.error import MCPAuthError, MCPConnectionError
  10. from core.mcp.session.client_session import ClientSession
  11. from core.mcp.types import Tool
  12. logger = logging.getLogger(__name__)
  13. class MCPClient:
  14. def __init__(
  15. self,
  16. server_url: str,
  17. provider_id: str,
  18. tenant_id: str,
  19. authed: bool = True,
  20. authorization_code: Optional[str] = None,
  21. for_list: bool = False,
  22. ):
  23. # Initialize info
  24. self.provider_id = provider_id
  25. self.tenant_id = tenant_id
  26. self.client_type = "streamable"
  27. self.server_url = server_url
  28. # Authentication info
  29. self.authed = authed
  30. self.authorization_code = authorization_code
  31. if authed:
  32. from core.mcp.auth.auth_provider import OAuthClientProvider
  33. self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
  34. self.token = self.provider.tokens()
  35. # Initialize session and client objects
  36. self._session: Optional[ClientSession] = None
  37. self._streams_context: Optional[AbstractContextManager[Any]] = None
  38. self._session_context: Optional[ClientSession] = None
  39. self.exit_stack = ExitStack()
  40. # Whether the client has been initialized
  41. self._initialized = False
  42. def __enter__(self):
  43. self._initialize()
  44. self._initialized = True
  45. return self
  46. def __exit__(
  47. self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
  48. ):
  49. self.cleanup()
  50. def _initialize(
  51. self,
  52. ):
  53. """Initialize the client with fallback to SSE if streamable connection fails"""
  54. connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
  55. "mcp": streamablehttp_client,
  56. "sse": sse_client,
  57. }
  58. parsed_url = urlparse(self.server_url)
  59. path = parsed_url.path or ""
  60. method_name = path.rstrip("/").split("/")[-1] if path else ""
  61. if method_name in connection_methods:
  62. client_factory = connection_methods[method_name]
  63. self.connect_server(client_factory, method_name)
  64. else:
  65. try:
  66. logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
  67. self.connect_server(sse_client, "sse")
  68. except MCPConnectionError:
  69. logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
  70. self.connect_server(streamablehttp_client, "mcp")
  71. def connect_server(
  72. self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
  73. ):
  74. from core.mcp.auth.auth_flow import auth
  75. try:
  76. headers = (
  77. {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
  78. if self.authed and self.token
  79. else {}
  80. )
  81. self._streams_context = client_factory(url=self.server_url, headers=headers)
  82. if not self._streams_context:
  83. raise MCPConnectionError("Failed to create connection context")
  84. # Use exit_stack to manage context managers properly
  85. if method_name == "mcp":
  86. read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
  87. streams = (read_stream, write_stream)
  88. else: # sse_client
  89. streams = self.exit_stack.enter_context(self._streams_context)
  90. self._session_context = ClientSession(*streams)
  91. self._session = self.exit_stack.enter_context(self._session_context)
  92. session = cast(ClientSession, self._session)
  93. session.initialize()
  94. return
  95. except MCPAuthError:
  96. if not self.authed:
  97. raise
  98. try:
  99. auth(self.provider, self.server_url, self.authorization_code)
  100. except Exception as e:
  101. raise ValueError(f"Failed to authenticate: {e}")
  102. self.token = self.provider.tokens()
  103. if first_try:
  104. return self.connect_server(client_factory, method_name, first_try=False)
  105. except MCPConnectionError:
  106. raise
  107. def list_tools(self) -> list[Tool]:
  108. """Connect to an MCP server running with SSE transport"""
  109. # List available tools to verify connection
  110. if not self._initialized or not self._session:
  111. raise ValueError("Session not initialized.")
  112. response = self._session.list_tools()
  113. tools = response.tools
  114. return tools
  115. def invoke_tool(self, tool_name: str, tool_args: dict):
  116. """Call a tool"""
  117. if not self._initialized or not self._session:
  118. raise ValueError("Session not initialized.")
  119. return self._session.call_tool(tool_name, tool_args)
  120. def cleanup(self):
  121. """Clean up resources"""
  122. try:
  123. # ExitStack will handle proper cleanup of all managed context managers
  124. self.exit_stack.close()
  125. except Exception as e:
  126. logging.exception("Error during cleanup")
  127. raise ValueError(f"Error during cleanup: {e}")
  128. finally:
  129. self._session = None
  130. self._session_context = None
  131. self._streams_context = None
  132. self._initialized = False