You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_session.py 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. import logging
  2. import queue
  3. from collections.abc import Callable
  4. from concurrent.futures import ThreadPoolExecutor
  5. from contextlib import ExitStack
  6. from datetime import timedelta
  7. from types import TracebackType
  8. from typing import Any, Generic, Self, TypeVar
  9. from httpx import HTTPStatusError
  10. from pydantic import BaseModel
  11. from core.mcp.error import MCPAuthError, MCPConnectionError
  12. from core.mcp.types import (
  13. CancelledNotification,
  14. ClientNotification,
  15. ClientRequest,
  16. ClientResult,
  17. ErrorData,
  18. JSONRPCError,
  19. JSONRPCMessage,
  20. JSONRPCNotification,
  21. JSONRPCRequest,
  22. JSONRPCResponse,
  23. MessageMetadata,
  24. RequestId,
  25. RequestParams,
  26. ServerMessageMetadata,
  27. ServerNotification,
  28. ServerRequest,
  29. ServerResult,
  30. SessionMessage,
  31. )
  # --- Type parameters shared by RequestResponder and BaseSession ----------
  # "Send*" variables describe what this side of the connection emits.
  SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
  SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
  SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)

  # "Receive*" variables describe what this side accepts from its peer.
  ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
  ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
  # Results are validated into whatever pydantic model the caller of
  # send_request() passes as ``result_type``.
  ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)

  # Seconds to block on a queue before re-checking state. Used as the default
  # wait interval in send_request() and as the read interval of the receive loop.
  DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
  class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
      """Handles responding to one incoming MCP request and tracks its lifecycle.

      This class MUST be used as a context manager; ``respond`` and ``cancel``
      raise if they are called outside of it:

      Example:
          with request_responder as resp:
              resp.respond(result)

      The context manager ensures:
      1. ``respond``/``cancel`` are only usable while the block is active
      2. At most one response is sent for the request
      3. ``on_complete`` is invoked on exit if the request was answered
         (BaseSession uses it to drop the request from its in-flight table)
      """

      request: ReceiveRequestT
      _session: Any
      _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]

      def __init__(
          self,
          request_id: RequestId,
          request_meta: RequestParams.Meta | None,
          request: ReceiveRequestT,
          session: """BaseSession[
              SendRequestT,
              SendNotificationT,
              SendResultT,
              ReceiveRequestT,
              ReceiveNotificationT
          ]""",
          on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
      ) -> None:
          """
          Args:
              request_id: JSON-RPC id the response is sent under.
              request_meta: The request's ``params.meta`` (None when the request had no params).
              request: The validated incoming request.
              session: Session whose ``_send_response`` delivers the answer.
              on_complete: Called with this responder when the context exits
                  after the request has been answered or cancelled.
          """
          self.request_id = request_id
          self.request_meta = request_meta
          self.request = request
          self._session = session
          self._completed = False
          self._on_complete = on_complete
          self._entered = False  # True only while inside the ``with`` block

      def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
          """Enter the context manager, enabling respond()/cancel()."""
          self._entered = True
          return self

      def __exit__(
          self,
          exc_type: type[BaseException] | None,
          exc_val: BaseException | None,
          exc_tb: TracebackType | None,
      ) -> None:
          """Exit the context manager, notifying completion if the request was answered."""
          try:
              if self._completed:
                  self._on_complete(self)
          finally:
              self._entered = False

      def respond(self, response: SendResultT | ErrorData) -> None:
          """Send a result (or an ``ErrorData`` error) for this request.

          Must be called within a context manager block.

          Raises:
              RuntimeError: If not used within a context manager
              AssertionError: If request was already responded to
          """
          if not self._entered:
              raise RuntimeError("RequestResponder must be used as a context manager")
          # Explicit raise rather than ``assert`` so the guard survives ``python -O``;
          # otherwise a second JSON-RPC response for the same id could be emitted.
          if self._completed:
              raise AssertionError("Request already responded to")
          self._completed = True
          self._session._send_response(request_id=self.request_id, response=response)

      def cancel(self) -> None:
          """Cancel this request by sending a "Request cancelled" error response.

          If the request has already been answered (or cancelled) this is a
          no-op, so at most one response is ever sent for a request id.

          Raises:
              RuntimeError: If not used within a context manager
          """
          if not self._entered:
              raise RuntimeError("RequestResponder must be used as a context manager")
          if self._completed:
              return
          self._completed = True  # Mark as completed so it's removed from in_flight on exit
          # Send an error response to indicate cancellation
          self._session._send_response(
              request_id=self.request_id,
              response=ErrorData(code=0, message="Request cancelled", data=None),
          )
  class BaseSession(
      Generic[
          SendRequestT,
          SendNotificationT,
          SendResultT,
          ReceiveRequestT,
          ReceiveNotificationT,
      ],
  ):
      """
      Implements an MCP "session" on top of read/write queues, including
      request/response linking and notification dispatch.

      This class is a context manager: entering it starts ``_receive_loop`` on a
      ThreadPoolExecutor thread; exiting it puts ``None`` sentinels on both
      queues, waits (bounded) for the receive loop to stop and shuts the
      executor down.
      """

      _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
      _request_id: int
      _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
      _receive_request_type: type[ReceiveRequestT]
      _receive_notification_type: type[ReceiveNotificationT]

      # Maximum number of seconds __exit__ waits for the receive loop to finish.
      _RECEIVER_SHUTDOWN_TIMEOUT = 5.0

      def __init__(
          self,
          read_stream: queue.Queue,
          write_stream: queue.Queue,
          receive_request_type: type[ReceiveRequestT],
          receive_notification_type: type[ReceiveNotificationT],
          # Used by send_request() as the interval between receive-loop health
          # checks while waiting; None falls back to DEFAULT_RESPONSE_READ_TIMEOUT.
          read_timeout_seconds: timedelta | None = None,
      ) -> None:
          self._read_stream = read_stream
          self._write_stream = write_stream
          self._response_streams = {}
          self._request_id = 0
          self._receive_request_type = receive_request_type
          self._receive_notification_type = receive_notification_type
          self._session_read_timeout_seconds = read_timeout_seconds
          self._in_flight = {}
          self._exit_stack = ExitStack()
          # Both are created by __enter__; None means the session was never entered.
          self._executor: ThreadPoolExecutor | None = None
          self._receiver_future: Any = None

      def __enter__(self) -> Self:
          self._executor = ThreadPoolExecutor()
          self._receiver_future = self._executor.submit(self._receive_loop)
          return self

      def check_receiver_status(self) -> None:
          """Re-raise the receive loop's exception if it has terminated with one.

          Raises:
              RuntimeError: If the session has not been entered yet.
          """
          if self._receiver_future is None:
              raise RuntimeError("Session must be entered (used as a context manager) before use")
          if self._receiver_future.done():
              self._receiver_future.result()

      def __exit__(
          self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
      ) -> None:
          self._exit_stack.close()
          # ``None`` stops _receive_loop; presumably the write-side consumer
          # treats it as end-of-stream as well (not visible here).
          self._read_stream.put(None)
          self._write_stream.put(None)
          self._shutdown_receiver()

      def _shutdown_receiver(self) -> None:
          """Wait (bounded) for the receive loop to finish, then release the executor."""
          if self._executor is None:
              return  # never entered
          try:
              self._receiver_future.result(timeout=self._RECEIVER_SHUTDOWN_TIMEOUT)
          except TimeoutError:
              logging.warning(
                  "MCP receive loop did not stop within %.1f seconds", self._RECEIVER_SHUTDOWN_TIMEOUT
              )
          except Exception:
              # _receive_loop already logged this via logging.exception; do not let
              # it replace whatever exception (if any) is leaving the with-block.
              logging.debug("MCP receive loop had terminated with an error", exc_info=True)
          # wait=False so a stuck receiver can never block __exit__ indefinitely.
          self._executor.shutdown(wait=False)

      def send_request(
          self,
          request: SendRequestT,
          result_type: type[ReceiveResultT],
          request_read_timeout_seconds: timedelta | None = None,
          metadata: MessageMetadata = None,
      ) -> ReceiveResultT:
          """
          Sends a request and waits for its response, validated as ``result_type``.

          The effective timeout (request timeout, else session timeout, else
          DEFAULT_RESPONSE_READ_TIMEOUT) is the interval at which the receive
          loop's health is re-checked while waiting; it is NOT an overall deadline.

          Raises:
              MCPAuthError: If the peer answers with a 401 error.
              MCPConnectionError: For any other error response, or when the
                  receive loop stops before a response arrives.
              Exception: Whatever the receive loop died with, if it crashed.

          Do not use this method to emit notifications! Use send_notification()
          instead.
          """
          self.check_receiver_status()

          request_id = self._request_id
          self._request_id = request_id + 1

          response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
          self._response_streams[request_id] = response_queue

          try:
              jsonrpc_request = JSONRPCRequest(
                  jsonrpc="2.0",
                  id=request_id,
                  **request.model_dump(by_alias=True, mode="json", exclude_none=True),
              )
              self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

              timeout = DEFAULT_RESPONSE_READ_TIMEOUT
              if request_read_timeout_seconds is not None:
                  timeout = float(request_read_timeout_seconds.total_seconds())
              elif self._session_read_timeout_seconds is not None:
                  timeout = float(self._session_read_timeout_seconds.total_seconds())

              response_or_error = self._wait_for_response(response_queue, timeout)

              if response_or_error is None:
                  raise MCPConnectionError(
                      ErrorData(
                          code=500,
                          message="No response received",
                      )
                  )
              if isinstance(response_or_error, JSONRPCError):
                  error_data = ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
                  if response_or_error.error.code == 401:
                      raise MCPAuthError(error_data)
                  raise MCPConnectionError(error_data)
              return result_type.model_validate(response_or_error.result)
          finally:
              self._response_streams.pop(request_id, None)

      def _wait_for_response(
          self, response_queue: queue.Queue[JSONRPCResponse | JSONRPCError], poll_interval: float
      ) -> JSONRPCResponse | JSONRPCError | None:
          """Block until a response arrives, the receive loop crashes, or it stops.

          Returns None if the receive loop stopped cleanly without delivering a
          response (previously this case waited forever).
          """
          while True:
              try:
                  return response_queue.get(timeout=poll_interval)
              except queue.Empty:
                  pass
              # Re-raises the receive loop's exception if it died with one.
              self.check_receiver_status()
              if self._receiver_future.done():
                  # Loop exited cleanly; pick up a response delivered just before it stopped.
                  try:
                      return response_queue.get_nowait()
                  except queue.Empty:
                      return None

      def send_notification(
          self,
          notification: SendNotificationT,
          related_request_id: RequestId | None = None,
      ) -> None:
          """
          Emits a notification, which is a one-way message that does not expect
          a response.
          """
          self.check_receiver_status()

          # Some transport implementations may need to set the related_request_id
          # to attribute to the notifications to the request that triggered them.
          jsonrpc_notification = JSONRPCNotification(
              jsonrpc="2.0",
              **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
          )
          # ``is not None`` (not truthiness): request id 0 is a valid id and is the
          # first one this class hands out.
          metadata = (
              ServerMessageMetadata(related_request_id=related_request_id) if related_request_id is not None else None
          )
          self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_notification), metadata=metadata))

      def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
          """Write a JSON-RPC error (for ``ErrorData``) or result message for ``request_id``."""
          if isinstance(response, ErrorData):
              payload = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
          else:
              payload = JSONRPCResponse(
                  jsonrpc="2.0",
                  id=request_id,
                  result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
              )
          self._write_stream.put(SessionMessage(message=JSONRPCMessage(payload)))

      def _receive_loop(self) -> None:
          """
          Main message processing loop, run on the executor thread started by
          __enter__.

          Reads items from the read queue until a ``None`` sentinel arrives. An
          item is an ``HTTPStatusError``, any other ``Exception``, or a message
          whose ``.message.root`` is a JSON-RPC request, notification, response
          or error. Any other exception is logged and re-raised, which ends the
          loop; check_receiver_status() then surfaces it to callers.
          """
          while True:
              try:
                  message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
                  if message is None:
                      break

                  if isinstance(message, HTTPStatusError):
                      self._route_http_error(message)
                  elif isinstance(message, Exception):
                      self._handle_incoming(message)
                  elif isinstance(message.message.root, JSONRPCRequest):
                      self._dispatch_request(message.message.root)
                  elif isinstance(message.message.root, JSONRPCNotification):
                      self._dispatch_notification(message.message.root)
                  else:  # Response or error
                      self._route_response(message)
              except queue.Empty:
                  continue
              except Exception:
                  logging.exception("Error in message processing loop")
                  raise

      def _route_http_error(self, error: HTTPStatusError) -> None:
          """Deliver a transport-level HTTP error to the most recently issued request."""
          # NOTE(review): the error is attributed to the last request id handed
          # out, which may be the wrong request if several are outstanding.
          last_request_id = self._request_id - 1
          response_queue = self._response_streams.get(last_request_id)
          if response_queue is None:
              self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {error}"))
              return
          response_queue.put(
              JSONRPCError(
                  jsonrpc="2.0",
                  id=last_request_id,
                  error=ErrorData(code=error.response.status_code, message=error.args[0]),
              )
          )

      def _dispatch_request(self, raw_request: JSONRPCRequest) -> None:
          """Validate an incoming request, track it as in-flight and pass it to the hooks."""
          # NOTE(review): a validation failure here propagates out of the loop and
          # stops the receiver; consider answering with a JSON-RPC error instead.
          validated_request = self._receive_request_type.model_validate(
              raw_request.model_dump(by_alias=True, mode="json", exclude_none=True)
          )
          responder = RequestResponder(
              request_id=raw_request.id,
              request_meta=validated_request.root.params.meta if validated_request.root.params else None,
              request=validated_request,
              session=self,
              on_complete=lambda r: self._in_flight.pop(r.request_id, None),
          )
          self._in_flight[responder.request_id] = responder
          self._received_request(responder)
          # Requests already answered by _received_request are not forwarded.
          if not responder._completed:
              self._handle_incoming(responder)

      def _dispatch_notification(self, raw_notification: JSONRPCNotification) -> None:
          """Validate an incoming notification; handle cancellations, forward the rest."""
          try:
              notification = self._receive_notification_type.model_validate(
                  raw_notification.model_dump(by_alias=True, mode="json", exclude_none=True)
              )
              # Handle cancellation notifications
              if isinstance(notification.root, CancelledNotification):
                  cancelled_id = notification.root.params.requestId
                  if cancelled_id in self._in_flight:
                      self._in_flight[cancelled_id].cancel()
              else:
                  self._received_notification(notification)
                  self._handle_incoming(notification)
          except Exception as e:
              # Validation (or handler) failures are logged; the loop keeps running.
              logging.warning("Failed to validate notification: %s. Message was: %s", e, raw_notification)

      def _route_response(self, message: SessionMessage) -> None:
          """Hand a response/error to the send_request() call waiting on its id."""
          response_queue = self._response_streams.get(message.message.root.id)
          if response_queue is not None:
              response_queue.put(message.message.root)
          else:
              self._handle_incoming(RuntimeError(f"Server Error: {message}"))

      def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
          """
          Can be overridden by subclasses to handle a request without needing to
          listen on the message stream.

          If the request is responded to within this method, it will not be
          forwarded on to the message stream.
          """
          pass

      def _received_notification(self, notification: ReceiveNotificationT) -> None:
          """
          Can be overridden by subclasses to handle a notification without needing
          to listen on the message stream.
          """
          pass

      def send_progress_notification(
          self, progress_token: str | int, progress: float, total: float | None = None
      ) -> None:
          """
          Sends a progress notification for a request that is currently being
          processed. No-op in this base class; subclasses may override it.
          """
          pass

      def _handle_incoming(
          self,
          req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
      ) -> None:
          """A generic handler for incoming messages. Overridden by subclasses."""
          pass