Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

mcp_client.py 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import logging
  2. from collections.abc import Callable
  3. from contextlib import AbstractContextManager, ExitStack
  4. from types import TracebackType
  5. from typing import Any, Optional, cast
  6. from urllib.parse import urlparse
  7. from core.mcp.client.sse_client import sse_client
  8. from core.mcp.client.streamable_client import streamablehttp_client
  9. from core.mcp.error import MCPAuthError, MCPConnectionError
  10. from core.mcp.session.client_session import ClientSession
  11. from core.mcp.types import Tool
  12. logger = logging.getLogger(__name__)
  13. class MCPClient:
  14. def __init__(
  15. self,
  16. server_url: str,
  17. provider_id: str,
  18. tenant_id: str,
  19. authed: bool = True,
  20. authorization_code: Optional[str] = None,
  21. for_list: bool = False,
  22. ):
  23. # Initialize info
  24. self.provider_id = provider_id
  25. self.tenant_id = tenant_id
  26. self.client_type = "streamable"
  27. self.server_url = server_url
  28. # Authentication info
  29. self.authed = authed
  30. self.authorization_code = authorization_code
  31. if authed:
  32. from core.mcp.auth.auth_provider import OAuthClientProvider
  33. self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
  34. self.token = self.provider.tokens()
  35. # Initialize session and client objects
  36. self._session: Optional[ClientSession] = None
  37. self._streams_context: Optional[AbstractContextManager[Any]] = None
  38. self._session_context: Optional[ClientSession] = None
  39. self.exit_stack = ExitStack()
  40. # Whether the client has been initialized
  41. self._initialized = False
  42. def __enter__(self):
  43. self._initialize()
  44. self._initialized = True
  45. return self
  46. def __exit__(
  47. self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
  48. ):
  49. self.cleanup()
  50. def _initialize(
  51. self,
  52. ):
  53. """Initialize the client with fallback to SSE if streamable connection fails"""
  54. connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
  55. "mcp": streamablehttp_client,
  56. "sse": sse_client,
  57. }
  58. parsed_url = urlparse(self.server_url)
  59. path = parsed_url.path
  60. method_name = path.rstrip("/").split("/")[-1] if path else ""
  61. try:
  62. client_factory = connection_methods[method_name]
  63. self.connect_server(client_factory, method_name)
  64. except KeyError:
  65. try:
  66. self.connect_server(sse_client, "sse")
  67. except MCPConnectionError:
  68. self.connect_server(streamablehttp_client, "mcp")
  69. def connect_server(
  70. self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
  71. ):
  72. from core.mcp.auth.auth_flow import auth
  73. try:
  74. headers = (
  75. {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
  76. if self.authed and self.token
  77. else {}
  78. )
  79. self._streams_context = client_factory(url=self.server_url, headers=headers)
  80. if self._streams_context is None:
  81. raise MCPConnectionError("Failed to create connection context")
  82. # Use exit_stack to manage context managers properly
  83. if method_name == "mcp":
  84. read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
  85. streams = (read_stream, write_stream)
  86. else: # sse_client
  87. streams = self.exit_stack.enter_context(self._streams_context)
  88. self._session_context = ClientSession(*streams)
  89. self._session = self.exit_stack.enter_context(self._session_context)
  90. session = cast(ClientSession, self._session)
  91. session.initialize()
  92. return
  93. except MCPAuthError:
  94. if not self.authed:
  95. raise
  96. try:
  97. auth(self.provider, self.server_url, self.authorization_code)
  98. except Exception as e:
  99. raise ValueError(f"Failed to authenticate: {e}")
  100. self.token = self.provider.tokens()
  101. if first_try:
  102. return self.connect_server(client_factory, method_name, first_try=False)
  103. except MCPConnectionError:
  104. raise
  105. def list_tools(self) -> list[Tool]:
  106. """Connect to an MCP server running with SSE transport"""
  107. # List available tools to verify connection
  108. if not self._initialized or not self._session:
  109. raise ValueError("Session not initialized.")
  110. response = self._session.list_tools()
  111. tools = response.tools
  112. return tools
  113. def invoke_tool(self, tool_name: str, tool_args: dict):
  114. """Call a tool"""
  115. if not self._initialized or not self._session:
  116. raise ValueError("Session not initialized.")
  117. return self._session.call_tool(tool_name, tool_args)
  118. def cleanup(self):
  119. """Clean up resources"""
  120. try:
  121. # ExitStack will handle proper cleanup of all managed context managers
  122. self.exit_stack.close()
  123. self._session = None
  124. self._session_context = None
  125. self._streams_context = None
  126. self._initialized = False
  127. except Exception as e:
  128. logging.exception("Error during cleanup")
  129. raise ValueError(f"Error during cleanup: {e}")