您最多可以选择 25 个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。

base_session.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import logging
  2. import queue
  3. from collections.abc import Callable
  4. from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
  5. from datetime import timedelta
  6. from types import TracebackType
  7. from typing import Any, Generic, Self, TypeVar
  8. from httpx import HTTPStatusError
  9. from pydantic import BaseModel
  10. from core.mcp.error import MCPAuthError, MCPConnectionError
  11. from core.mcp.types import (
  12. CancelledNotification,
  13. ClientNotification,
  14. ClientRequest,
  15. ClientResult,
  16. ErrorData,
  17. JSONRPCError,
  18. JSONRPCMessage,
  19. JSONRPCNotification,
  20. JSONRPCRequest,
  21. JSONRPCResponse,
  22. MessageMetadata,
  23. RequestId,
  24. RequestParams,
  25. ServerMessageMetadata,
  26. ServerNotification,
  27. ServerRequest,
  28. ServerResult,
  29. SessionMessage,
  30. )
  # Seconds to block on a queue `get` before re-checking state; used by the
  # session's request-sending and message-receiving paths.
  DEFAULT_RESPONSE_READ_TIMEOUT = 1.0

  # Type parameters for traffic this side SENDS ...
  SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
  SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
  SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)

  # ... and for traffic this side RECEIVES. Results are validated against any
  # pydantic model supplied by the caller, hence the `bound=` form.
  ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
  ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
  ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
  class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
      """Handles responding to a single incoming MCP request.

      This class MUST be used as a context manager; ``respond()`` and ``cancel()``
      raise ``RuntimeError`` outside of a ``with`` block.

      Example:
          with request_responder as resp:
              resp.respond(result)

      Entering marks the responder as active. Leaving it calls ``on_complete(self)``
      if the request was completed (responded to or cancelled) inside the block;
      for example, ``BaseSession`` passes a callback that removes the responder from
      its in-flight table.
      """

      request: ReceiveRequestT
      _session: Any
      _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]

      def __init__(
          self,
          request_id: RequestId,
          request_meta: RequestParams.Meta | None,
          request: ReceiveRequestT,
          session: """BaseSession[
              SendRequestT,
              SendNotificationT,
              SendResultT,
              ReceiveRequestT,
              ReceiveNotificationT
          ]""",
          on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
      ) -> None:
          """Store the request and the session used to send the reply.

          Args:
              request_id: JSON-RPC id of the incoming request; echoed in the response.
              request_meta: The request's ``params.meta``, if any.
              request: The validated request model.
              session: Session whose ``_send_response`` delivers the reply.
              on_complete: Called with this responder when a completed request's
                  ``with`` block exits.
          """
          self.request_id = request_id
          self.request_meta = request_meta
          self.request = request
          self._session = session
          self._completed = False  # True once a response (or cancellation) has been sent
          self._on_complete = on_complete
          self._entered = False  # Track if we're in a context manager

      def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
          """Enter the context manager, allowing respond()/cancel() to be called."""
          self._entered = True
          return self

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> None:
          """Exit the context manager, notifying completion if a response was sent."""
          try:
              if self._completed:
                  self._on_complete(self)
          finally:
              self._entered = False

      def respond(self, response: SendResultT | ErrorData) -> None:
          """Send a response for this request.

          Must be called within a context manager block.

          Args:
              response: The result to send, or an ``ErrorData`` to send a JSON-RPC error.

          Raises:
              RuntimeError: If not used within a context manager.
              AssertionError: If the request was already responded to or cancelled.
          """
          if not self._entered:
              raise RuntimeError("RequestResponder must be used as a context manager")
          # Raised explicitly rather than via `assert` so the double-response guard
          # still holds when Python runs with -O (which strips assert statements).
          if self._completed:
              raise AssertionError("Request already responded to")
          self._completed = True
          self._session._send_response(request_id=self.request_id, response=response)

      def cancel(self) -> None:
          """Cancel this request and mark it as completed.

          Sends an error response (code 0, "Request cancelled"). If the request was
          already completed, this is a no-op, so at most one response is ever sent
          for a given request id.

          Raises:
              RuntimeError: If not used within a context manager.
          """
          if not self._entered:
              raise RuntimeError("RequestResponder must be used as a context manager")
          if self._completed:
              # A response (or an earlier cancellation) already went out; sending a
              # second response with the same id would violate JSON-RPC.
              return
          self._completed = True  # Mark as completed so __exit__ reports completion
          # NOTE(review): the MCP cancellation spec says receivers SHOULD NOT respond
          # to a cancelled request; this error response is kept for compatibility.
          self._session._send_response(
              request_id=self.request_id,
              response=ErrorData(code=0, message="Request cancelled", data=None),
          )
  class BaseSession(
      Generic[
          SendRequestT,
          SendNotificationT,
          SendResultT,
          ReceiveRequestT,
          ReceiveNotificationT,
      ],
  ):
      """
      Implements an MCP "session" on top of read/write streams, including features
      like request/response linking, notifications, and progress.

      The streams are ``queue.Queue`` objects. Outgoing ``SessionMessage`` objects are
      put on the write queue. Incoming items (``SessionMessage`` objects, exceptions,
      or a ``None`` stop sentinel) are taken from the read queue.

      This class is a context manager. Entering starts ``_receive_loop`` on a
      single worker thread, and exiting stops it.
      """

      _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
      _request_id: int
      _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
      _receive_request_type: type[ReceiveRequestT]
      _receive_notification_type: type[ReceiveNotificationT]

      def __init__(
          self,
          read_stream: queue.Queue,
          write_stream: queue.Queue,
          receive_request_type: type[ReceiveRequestT],
          receive_notification_type: type[ReceiveNotificationT],
          # If none, reading will never time out
          read_timeout_seconds: timedelta | None = None,
      ) -> None:
          self._read_stream = read_stream
          self._write_stream = write_stream
          self._response_streams = {}
          self._request_id = 0
          self._receive_request_type = receive_request_type
          self._receive_notification_type = receive_notification_type
          self._session_read_timeout_seconds = read_timeout_seconds
          self._in_flight = {}
          # Initialize executor and future to None for proper cleanup checks
          self._executor: ThreadPoolExecutor | None = None
          self._receiver_future: Future | None = None

      def __enter__(self) -> Self:
          # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
          # ensures no unnecessary threads are created.
          self._executor = ThreadPoolExecutor(max_workers=1)
          self._receiver_future = self._executor.submit(self._receive_loop)
          return self

      def check_receiver_status(self) -> None:
          """`check_receiver_status` ensures that any exceptions raised during the
          execution of `_receive_loop` are retrieved and propagated."""
          if self._receiver_future and self._receiver_future.done():
              self._receiver_future.result()

      def __exit__(
          self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
      ) -> None:
          """Stop the session: signal both queues, wait for the receive loop, release its thread.

          Waits up to 5 seconds for the receive loop. If the loop ended with an
          exception, that exception is re-raised here, after the executor has been
          shut down. If the loop is still running after the wait, the executor is
          shut down without blocking, so leaving the ``with`` block cannot hang on a
          stuck handler.
          """
          # `None` is the stop sentinel for the receive loop (and, presumably, for
          # whatever consumes the write queue).
          self._read_stream.put(None)
          self._write_stream.put(None)

          receiver_finished = True
          try:
              if self._receiver_future:
                  try:
                      self._receiver_future.result(timeout=5.0)  # Wait up to 5 seconds
                  except TimeoutError:
                      receiver_finished = False
                      logging.warning("MCP session receive loop did not stop within 5 seconds; not waiting for it")
          finally:
              # Always release the executor, including when the receive loop raised.
              if self._executor:
                  self._executor.shutdown(wait=receiver_finished)

      def send_request(
          self,
          request: SendRequestT,
          result_type: type[ReceiveResultT],
          request_read_timeout_seconds: timedelta | None = None,
          metadata: MessageMetadata = None,
      ) -> ReceiveResultT:
          """
          Sends a request and waits for a response.

          If a request read timeout is provided, it takes precedence over the
          session read timeout. If neither is set, waits until a response arrives
          or the receive loop dies.

          Args:
              request: The request to send.
              result_type: Model used to validate the ``result`` of a successful response.
              request_read_timeout_seconds: Optional per-request read timeout.
              metadata: Transport metadata attached to the outgoing message.

          Returns:
              The response's ``result`` validated as ``result_type``.

          Raises:
              MCPAuthError: If the peer answers with a JSON-RPC error with code 401.
              MCPConnectionError: If the peer answers with any other JSON-RPC error,
                  or no response arrives within the read timeout (code 408).

          Any exception that terminated the receive loop is re-raised as well.

          Do not use this method to emit notifications! Use send_notification()
          instead.
          """
          import time  # only needed to track the read-timeout deadline

          self.check_receiver_status()

          request_id = self._request_id
          self._request_id = request_id + 1

          response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
          self._response_streams[request_id] = response_queue

          try:
              jsonrpc_request = JSONRPCRequest(
                  jsonrpc="2.0",
                  id=request_id,
                  **request.model_dump(by_alias=True, mode="json", exclude_none=True),
              )
              self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

              # The request-level timeout wins over the session-level one; None = wait forever.
              timeout: float | None = None
              if request_read_timeout_seconds is not None:
                  timeout = float(request_read_timeout_seconds.total_seconds())
              elif self._session_read_timeout_seconds is not None:
                  timeout = float(self._session_read_timeout_seconds.total_seconds())
              deadline = None if timeout is None else time.monotonic() + timeout

              while True:
                  # Wait in short slices so a crashed receive loop is noticed promptly
                  # instead of leaving this call blocked on a response that never comes.
                  poll_interval = DEFAULT_RESPONSE_READ_TIMEOUT
                  if deadline is not None:
                      remaining = deadline - time.monotonic()
                      if remaining <= 0:
                          self.check_receiver_status()
                          raise MCPConnectionError(
                              ErrorData(
                                  code=408,
                                  message=(
                                      f"Timed out while waiting for response to "
                                      f"{request.__class__.__name__}. Waited {timeout} seconds."
                                  ),
                              )
                          )
                      poll_interval = min(poll_interval, remaining)
                  try:
                      response_or_error = response_queue.get(timeout=poll_interval)
                      break
                  except queue.Empty:
                      self.check_receiver_status()

              if response_or_error is None:
                  raise MCPConnectionError(
                      ErrorData(
                          code=500,
                          message="No response received",
                      )
                  )
              elif isinstance(response_or_error, JSONRPCError):
                  if response_or_error.error.code == 401:
                      raise MCPAuthError(
                          ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
                      )
                  else:
                      raise MCPConnectionError(
                          ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
                      )
              else:
                  return result_type.model_validate(response_or_error.result)
          finally:
              self._response_streams.pop(request_id, None)

      def send_notification(
          self,
          notification: SendNotificationT,
          related_request_id: RequestId | None = None,
      ) -> None:
          """
          Emits a notification, which is a one-way message that does not expect
          a response.

          Args:
              notification: The notification to send.
              related_request_id: Id of the request that triggered this notification,
                  if any. It is attached as ``ServerMessageMetadata``; note that ``0``
                  is a valid id.
          """
          self.check_receiver_status()

          # Some transport implementations may need to set the related_request_id
          # to attribute to the notifications to the request that triggered them.
          jsonrpc_notification = JSONRPCNotification(
              jsonrpc="2.0",
              **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
          )
          session_message = SessionMessage(
              message=JSONRPCMessage(jsonrpc_notification),
              # `is not None` rather than truthiness: request ids start at 0.
              metadata=(
                  ServerMessageMetadata(related_request_id=related_request_id)
                  if related_request_id is not None
                  else None
              ),
          )
          self._write_stream.put(session_message)

      def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
          """Queue a JSON-RPC response for ``request_id``; ``ErrorData`` becomes a JSON-RPC error."""
          if isinstance(response, ErrorData):
              jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
              session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
          else:
              jsonrpc_response = JSONRPCResponse(
                  jsonrpc="2.0",
                  id=request_id,
                  result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
              )
              session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
          self._write_stream.put(session_message)

      def _receive_loop(self) -> None:
          """
          Main message processing loop; runs on the session's worker thread.

          Takes items from the read queue until a ``None`` sentinel arrives:
          - ``HTTPStatusError``: converted to a JSON-RPC error for the most recently
            issued request id.
          - other exceptions: forwarded to ``_handle_incoming``.
          - JSON-RPC requests: wrapped in a ``RequestResponder`` and offered to
            ``_received_request``, then to ``_handle_incoming`` if still unanswered.
          - JSON-RPC notifications: cancellations cancel the matching in-flight
            request; others go to ``_received_notification`` and ``_handle_incoming``.
          - JSON-RPC responses/errors: routed to the waiting ``send_request`` call.

          An exception raised while dispatching is logged and re-raised, which ends
          the loop; ``check_receiver_status`` surfaces it to callers.
          """
          while True:
              try:
                  message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
              except queue.Empty:
                  # Nothing arrived yet; keep waiting. Only the `get` is covered here so a
                  # queue.Empty raised by a handler is not silently swallowed.
                  continue

              if message is None:
                  break

              try:
                  if isinstance(message, HTTPStatusError):
                      # HTTP failures carry no JSON-RPC id; attribute them to the latest request.
                      response_queue = self._response_streams.get(self._request_id - 1)
                      if response_queue is not None:
                          response_queue.put(
                              JSONRPCError(
                                  jsonrpc="2.0",
                                  id=self._request_id - 1,
                                  error=ErrorData(code=message.response.status_code, message=message.args[0]),
                              )
                          )
                      else:
                          self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
                  elif isinstance(message, Exception):
                      self._handle_incoming(message)
                  elif isinstance(message.message.root, JSONRPCRequest):
                      validated_request = self._receive_request_type.model_validate(
                          message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
                      )
                      responder = RequestResponder(
                          request_id=message.message.root.id,
                          request_meta=validated_request.root.params.meta if validated_request.root.params else None,
                          request=validated_request,
                          session=self,
                          on_complete=lambda r: self._in_flight.pop(r.request_id, None),
                      )
                      self._in_flight[responder.request_id] = responder
                      self._received_request(responder)
                      if not responder._completed:
                          self._handle_incoming(responder)
                  elif isinstance(message.message.root, JSONRPCNotification):
                      try:
                          notification = self._receive_notification_type.model_validate(
                              message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
                          )
                          # Handle cancellation notifications
                          if isinstance(notification.root, CancelledNotification):
                              cancelled_id = notification.root.params.requestId
                              if cancelled_id in self._in_flight:
                                  self._in_flight[cancelled_id].cancel()
                          else:
                              self._received_notification(notification)
                              self._handle_incoming(notification)
                      except Exception as e:
                          # For other validation errors, log and continue
                          logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
                  else:  # Response or error
                      response_queue = self._response_streams.get(message.message.root.id)
                      if response_queue is not None:
                          response_queue.put(message.message.root)
                      else:
                          self._handle_incoming(RuntimeError(f"Server Error: {message}"))
              except Exception:
                  logging.exception("Error in message processing loop")
                  raise

      def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
          """
          Can be overridden by subclasses to handle a request without needing to
          listen on the message stream.

          If the request is responded to within this method, it will not be
          forwarded on to the message stream.
          """

      def _received_notification(self, notification: ReceiveNotificationT) -> None:
          """
          Can be overridden by subclasses to handle a notification without needing
          to listen on the message stream.
          """

      def send_progress_notification(
          self, progress_token: str | int, progress: float, total: float | None = None
      ) -> None:
          """
          Sends a progress notification for a request that is currently being
          processed.

          No-op in this base class; presumably overridden by subclasses.
          """

      def _handle_incoming(
          self,
          req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
      ) -> None:
          """A generic handler for incoming messages. Overwritten by subclasses."""