您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

mcp_client.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import logging
  2. from collections.abc import Callable
  3. from contextlib import AbstractContextManager, ExitStack
  4. from types import TracebackType
  5. from typing import Any, Optional, cast
  6. from urllib.parse import urlparse
  7. from core.mcp.client.sse_client import sse_client
  8. from core.mcp.client.streamable_client import streamablehttp_client
  9. from core.mcp.error import MCPAuthError, MCPConnectionError
  10. from core.mcp.session.client_session import ClientSession
  11. from core.mcp.types import Tool
  12. logger = logging.getLogger(__name__)
  13. class MCPClient:
  14. def __init__(
  15. self,
  16. server_url: str,
  17. provider_id: str,
  18. tenant_id: str,
  19. authed: bool = True,
  20. authorization_code: Optional[str] = None,
  21. for_list: bool = False,
  22. headers: Optional[dict[str, str]] = None,
  23. timeout: Optional[float] = None,
  24. sse_read_timeout: Optional[float] = None,
  25. ):
  26. # Initialize info
  27. self.provider_id = provider_id
  28. self.tenant_id = tenant_id
  29. self.client_type = "streamable"
  30. self.server_url = server_url
  31. self.headers = headers or {}
  32. self.timeout = timeout
  33. self.sse_read_timeout = sse_read_timeout
  34. # Authentication info
  35. self.authed = authed
  36. self.authorization_code = authorization_code
  37. if authed:
  38. from core.mcp.auth.auth_provider import OAuthClientProvider
  39. self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
  40. self.token = self.provider.tokens()
  41. # Initialize session and client objects
  42. self._session: Optional[ClientSession] = None
  43. self._streams_context: Optional[AbstractContextManager[Any]] = None
  44. self._session_context: Optional[ClientSession] = None
  45. self._exit_stack = ExitStack()
  46. # Whether the client has been initialized
  47. self._initialized = False
  48. def __enter__(self):
  49. self._initialize()
  50. self._initialized = True
  51. return self
  52. def __exit__(
  53. self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
  54. ):
  55. self.cleanup()
  56. def _initialize(
  57. self,
  58. ):
  59. """Initialize the client with fallback to SSE if streamable connection fails"""
  60. connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
  61. "mcp": streamablehttp_client,
  62. "sse": sse_client,
  63. }
  64. parsed_url = urlparse(self.server_url)
  65. path = parsed_url.path or ""
  66. method_name = path.rstrip("/").split("/")[-1] if path else ""
  67. if method_name in connection_methods:
  68. client_factory = connection_methods[method_name]
  69. self.connect_server(client_factory, method_name)
  70. else:
  71. try:
  72. logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
  73. self.connect_server(sse_client, "sse")
  74. except MCPConnectionError:
  75. logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
  76. self.connect_server(streamablehttp_client, "mcp")
  77. def connect_server(
  78. self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
  79. ):
  80. from core.mcp.auth.auth_flow import auth
  81. try:
  82. headers = (
  83. {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
  84. if self.authed and self.token
  85. else self.headers
  86. )
  87. self._streams_context = client_factory(
  88. url=self.server_url,
  89. headers=headers,
  90. timeout=self.timeout,
  91. sse_read_timeout=self.sse_read_timeout,
  92. )
  93. if not self._streams_context:
  94. raise MCPConnectionError("Failed to create connection context")
  95. # Use exit_stack to manage context managers properly
  96. if method_name == "mcp":
  97. read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
  98. streams = (read_stream, write_stream)
  99. else: # sse_client
  100. streams = self._exit_stack.enter_context(self._streams_context)
  101. self._session_context = ClientSession(*streams)
  102. self._session = self._exit_stack.enter_context(self._session_context)
  103. session = cast(ClientSession, self._session)
  104. session.initialize()
  105. return
  106. except MCPAuthError:
  107. if not self.authed:
  108. raise
  109. try:
  110. auth(self.provider, self.server_url, self.authorization_code)
  111. except Exception as e:
  112. raise ValueError(f"Failed to authenticate: {e}")
  113. self.token = self.provider.tokens()
  114. if first_try:
  115. return self.connect_server(client_factory, method_name, first_try=False)
  116. def list_tools(self) -> list[Tool]:
  117. """Connect to an MCP server running with SSE transport"""
  118. # List available tools to verify connection
  119. if not self._initialized or not self._session:
  120. raise ValueError("Session not initialized.")
  121. response = self._session.list_tools()
  122. tools = response.tools
  123. return tools
  124. def invoke_tool(self, tool_name: str, tool_args: dict):
  125. """Call a tool"""
  126. if not self._initialized or not self._session:
  127. raise ValueError("Session not initialized.")
  128. return self._session.call_tool(tool_name, tool_args)
  129. def cleanup(self):
  130. """Clean up resources"""
  131. try:
  132. # ExitStack will handle proper cleanup of all managed context managers
  133. self._exit_stack.close()
  134. except Exception as e:
  135. logging.exception("Error during cleanup")
  136. raise ValueError(f"Error during cleanup: {e}")
  137. finally:
  138. self._session = None
  139. self._session_context = None
  140. self._streams_context = None
  141. self._initialized = False