You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

mcp_client.py 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import logging
  2. from collections.abc import Callable
  3. from contextlib import AbstractContextManager, ExitStack
  4. from types import TracebackType
  5. from typing import Any
  6. from urllib.parse import urlparse
  7. from core.mcp.client.sse_client import sse_client
  8. from core.mcp.client.streamable_client import streamablehttp_client
  9. from core.mcp.error import MCPAuthError, MCPConnectionError
  10. from core.mcp.session.client_session import ClientSession
  11. from core.mcp.types import Tool
  12. logger = logging.getLogger(__name__)
  13. class MCPClient:
  14. def __init__(
  15. self,
  16. server_url: str,
  17. provider_id: str,
  18. tenant_id: str,
  19. authed: bool = True,
  20. authorization_code: str | None = None,
  21. for_list: bool = False,
  22. headers: dict[str, str] | None = None,
  23. timeout: float | None = None,
  24. sse_read_timeout: float | None = None,
  25. ):
  26. # Initialize info
  27. self.provider_id = provider_id
  28. self.tenant_id = tenant_id
  29. self.client_type = "streamable"
  30. self.server_url = server_url
  31. self.headers = headers or {}
  32. self.timeout = timeout
  33. self.sse_read_timeout = sse_read_timeout
  34. # Authentication info
  35. self.authed = authed
  36. self.authorization_code = authorization_code
  37. if authed:
  38. from core.mcp.auth.auth_provider import OAuthClientProvider
  39. self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
  40. self.token = self.provider.tokens()
  41. # Initialize session and client objects
  42. self._session: ClientSession | None = None
  43. self._streams_context: AbstractContextManager[Any] | None = None
  44. self._session_context: ClientSession | None = None
  45. self._exit_stack = ExitStack()
  46. # Whether the client has been initialized
  47. self._initialized = False
  48. def __enter__(self):
  49. self._initialize()
  50. self._initialized = True
  51. return self
  52. def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
  53. self.cleanup()
  54. def _initialize(
  55. self,
  56. ):
  57. """Initialize the client with fallback to SSE if streamable connection fails"""
  58. connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
  59. "mcp": streamablehttp_client,
  60. "sse": sse_client,
  61. }
  62. parsed_url = urlparse(self.server_url)
  63. path = parsed_url.path or ""
  64. method_name = path.rstrip("/").split("/")[-1] if path else ""
  65. if method_name in connection_methods:
  66. client_factory = connection_methods[method_name]
  67. self.connect_server(client_factory, method_name)
  68. else:
  69. try:
  70. logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
  71. self.connect_server(sse_client, "sse")
  72. except MCPConnectionError:
  73. logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
  74. self.connect_server(streamablehttp_client, "mcp")
  75. def connect_server(
  76. self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
  77. ):
  78. from core.mcp.auth.auth_flow import auth
  79. try:
  80. headers = (
  81. {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
  82. if self.authed and self.token
  83. else self.headers
  84. )
  85. self._streams_context = client_factory(
  86. url=self.server_url,
  87. headers=headers,
  88. timeout=self.timeout,
  89. sse_read_timeout=self.sse_read_timeout,
  90. )
  91. if not self._streams_context:
  92. raise MCPConnectionError("Failed to create connection context")
  93. # Use exit_stack to manage context managers properly
  94. if method_name == "mcp":
  95. read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
  96. streams = (read_stream, write_stream)
  97. else: # sse_client
  98. streams = self._exit_stack.enter_context(self._streams_context)
  99. self._session_context = ClientSession(*streams)
  100. self._session = self._exit_stack.enter_context(self._session_context)
  101. self._session.initialize()
  102. return
  103. except MCPAuthError:
  104. if not self.authed:
  105. raise
  106. try:
  107. auth(self.provider, self.server_url, self.authorization_code)
  108. except Exception as e:
  109. raise ValueError(f"Failed to authenticate: {e}")
  110. self.token = self.provider.tokens()
  111. if first_try:
  112. return self.connect_server(client_factory, method_name, first_try=False)
  113. def list_tools(self) -> list[Tool]:
  114. """Connect to an MCP server running with SSE transport"""
  115. # List available tools to verify connection
  116. if not self._initialized or not self._session:
  117. raise ValueError("Session not initialized.")
  118. response = self._session.list_tools()
  119. tools = response.tools
  120. return tools
  121. def invoke_tool(self, tool_name: str, tool_args: dict):
  122. """Call a tool"""
  123. if not self._initialized or not self._session:
  124. raise ValueError("Session not initialized.")
  125. return self._session.call_tool(tool_name, tool_args)
  126. def cleanup(self):
  127. """Clean up resources"""
  128. try:
  129. # ExitStack will handle proper cleanup of all managed context managers
  130. self._exit_stack.close()
  131. except Exception as e:
  132. logger.exception("Error during cleanup")
  133. raise ValueError(f"Error during cleanup: {e}")
  134. finally:
  135. self._session = None
  136. self._session_context = None
  137. self._streams_context = None
  138. self._initialized = False