You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sse_client.py 12KB


  1. import logging
  2. import queue
  3. from collections.abc import Generator
  4. from concurrent.futures import ThreadPoolExecutor
  5. from contextlib import contextmanager
  6. from typing import Any, TypeAlias, final
  7. from urllib.parse import urljoin, urlparse
  8. import httpx
  9. from httpx_sse import EventSource, ServerSentEvent
  10. from sseclient import SSEClient
  11. from core.mcp import types
  12. from core.mcp.error import MCPAuthError, MCPConnectionError
  13. from core.mcp.types import SessionMessage
  14. from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
  15. logger = logging.getLogger(__name__)
  16. DEFAULT_QUEUE_READ_TIMEOUT = 3
  17. @final
  18. class _StatusReady:
  19. def __init__(self, endpoint_url: str):
  20. self.endpoint_url = endpoint_url
  21. @final
  22. class _StatusError:
  23. def __init__(self, exc: Exception):
  24. self.exc = exc
  25. # Type aliases for better readability
  26. ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
  27. WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
  28. StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
  29. class SSETransport:
  30. """SSE client transport implementation."""
  31. def __init__(
  32. self,
  33. url: str,
  34. headers: dict[str, Any] | None = None,
  35. timeout: float = 5.0,
  36. sse_read_timeout: float = 5 * 60,
  37. ):
  38. """Initialize the SSE transport.
  39. Args:
  40. url: The SSE endpoint URL.
  41. headers: Optional headers to include in requests.
  42. timeout: HTTP timeout for regular operations.
  43. sse_read_timeout: Timeout for SSE read operations.
  44. """
  45. self.url = url
  46. self.headers = headers or {}
  47. self.timeout = timeout
  48. self.sse_read_timeout = sse_read_timeout
  49. self.endpoint_url: str | None = None
  50. def _validate_endpoint_url(self, endpoint_url: str) -> bool:
  51. """Validate that the endpoint URL matches the connection origin.
  52. Args:
  53. endpoint_url: The endpoint URL to validate.
  54. Returns:
  55. True if valid, False otherwise.
  56. """
  57. url_parsed = urlparse(self.url)
  58. endpoint_parsed = urlparse(endpoint_url)
  59. return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
  60. def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
  61. """Handle an 'endpoint' SSE event.
  62. Args:
  63. sse_data: The SSE event data.
  64. status_queue: Queue to put status updates.
  65. """
  66. endpoint_url = urljoin(self.url, sse_data)
  67. logger.info("Received endpoint URL: %s", endpoint_url)
  68. if not self._validate_endpoint_url(endpoint_url):
  69. error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
  70. logger.error(error_msg)
  71. status_queue.put(_StatusError(ValueError(error_msg)))
  72. return
  73. status_queue.put(_StatusReady(endpoint_url))
  74. def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
  75. """Handle a 'message' SSE event.
  76. Args:
  77. sse_data: The SSE event data.
  78. read_queue: Queue to put parsed messages.
  79. """
  80. try:
  81. message = types.JSONRPCMessage.model_validate_json(sse_data)
  82. logger.debug("Received server message: %s", message)
  83. session_message = SessionMessage(message)
  84. read_queue.put(session_message)
  85. except Exception as exc:
  86. logger.exception("Error parsing server message")
  87. read_queue.put(exc)
  88. def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
  89. """Handle a single SSE event.
  90. Args:
  91. sse: The SSE event object.
  92. read_queue: Queue for message events.
  93. status_queue: Queue for status events.
  94. """
  95. match sse.event:
  96. case "endpoint":
  97. self._handle_endpoint_event(sse.data, status_queue)
  98. case "message":
  99. self._handle_message_event(sse.data, read_queue)
  100. case _:
  101. logger.warning("Unknown SSE event: %s", sse.event)
  102. def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
  103. """Read and process SSE events.
  104. Args:
  105. event_source: The SSE event source.
  106. read_queue: Queue to put received messages.
  107. status_queue: Queue to put status updates.
  108. """
  109. try:
  110. for sse in event_source.iter_sse():
  111. self._handle_sse_event(sse, read_queue, status_queue)
  112. except httpx.ReadError as exc:
  113. logger.debug("SSE reader shutting down normally: %s", exc)
  114. except Exception as exc:
  115. read_queue.put(exc)
  116. finally:
  117. read_queue.put(None)
  118. def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
  119. """Send a single message to the server.
  120. Args:
  121. client: HTTP client to use.
  122. endpoint_url: The endpoint URL to send to.
  123. message: The message to send.
  124. """
  125. response = client.post(
  126. endpoint_url,
  127. json=message.message.model_dump(
  128. by_alias=True,
  129. mode="json",
  130. exclude_none=True,
  131. ),
  132. )
  133. response.raise_for_status()
  134. logger.debug("Client message sent successfully: %s", response.status_code)
  135. def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
  136. """Handle writing messages to the server.
  137. Args:
  138. client: HTTP client to use.
  139. endpoint_url: The endpoint URL to send messages to.
  140. write_queue: Queue to read messages from.
  141. """
  142. try:
  143. while True:
  144. try:
  145. message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
  146. if message is None:
  147. break
  148. if isinstance(message, Exception):
  149. write_queue.put(message)
  150. continue
  151. self._send_message(client, endpoint_url, message)
  152. except queue.Empty:
  153. continue
  154. except httpx.ReadError as exc:
  155. logger.debug("Post writer shutting down normally: %s", exc)
  156. except Exception as exc:
  157. logger.exception("Error writing messages")
  158. write_queue.put(exc)
  159. finally:
  160. write_queue.put(None)
  161. def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
  162. """Wait for the endpoint URL from the status queue.
  163. Args:
  164. status_queue: Queue to read status from.
  165. Returns:
  166. The endpoint URL.
  167. Raises:
  168. ValueError: If endpoint URL is not received or there's an error.
  169. """
  170. try:
  171. status = status_queue.get(timeout=1)
  172. except queue.Empty:
  173. raise ValueError("failed to get endpoint URL")
  174. if isinstance(status, _StatusReady):
  175. return status.endpoint_url
  176. elif isinstance(status, _StatusError):
  177. raise status.exc
  178. else:
  179. raise ValueError("failed to get endpoint URL")
  180. def connect(
  181. self,
  182. executor: ThreadPoolExecutor,
  183. client: httpx.Client,
  184. event_source: EventSource,
  185. ) -> tuple[ReadQueue, WriteQueue]:
  186. """Establish connection and start worker threads.
  187. Args:
  188. executor: Thread pool executor.
  189. client: HTTP client.
  190. event_source: SSE event source.
  191. Returns:
  192. Tuple of (read_queue, write_queue).
  193. """
  194. read_queue: ReadQueue = queue.Queue()
  195. write_queue: WriteQueue = queue.Queue()
  196. status_queue: StatusQueue = queue.Queue()
  197. # Start SSE reader thread
  198. executor.submit(self.sse_reader, event_source, read_queue, status_queue)
  199. # Wait for endpoint URL
  200. endpoint_url = self._wait_for_endpoint(status_queue)
  201. self.endpoint_url = endpoint_url
  202. # Start post writer thread
  203. executor.submit(self.post_writer, client, endpoint_url, write_queue)
  204. return read_queue, write_queue
  205. @contextmanager
  206. def sse_client(
  207. url: str,
  208. headers: dict[str, Any] | None = None,
  209. timeout: float = 5.0,
  210. sse_read_timeout: float = 5 * 60,
  211. ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
  212. """
  213. Client transport for SSE.
  214. `sse_read_timeout` determines how long (in seconds) the client will wait for a new
  215. event before disconnecting. All other HTTP operations are controlled by `timeout`.
  216. Args:
  217. url: The SSE endpoint URL.
  218. headers: Optional headers to include in requests.
  219. timeout: HTTP timeout for regular operations.
  220. sse_read_timeout: Timeout for SSE read operations.
  221. Yields:
  222. Tuple of (read_queue, write_queue) for message communication.
  223. """
  224. transport = SSETransport(url, headers, timeout, sse_read_timeout)
  225. read_queue: ReadQueue | None = None
  226. write_queue: WriteQueue | None = None
  227. with ThreadPoolExecutor() as executor:
  228. try:
  229. with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
  230. with ssrf_proxy_sse_connect(
  231. url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
  232. ) as event_source:
  233. event_source.response.raise_for_status()
  234. read_queue, write_queue = transport.connect(executor, client, event_source)
  235. yield read_queue, write_queue
  236. except httpx.HTTPStatusError as exc:
  237. if exc.response.status_code == 401:
  238. raise MCPAuthError()
  239. raise MCPConnectionError()
  240. except Exception:
  241. logger.exception("Error connecting to SSE endpoint")
  242. raise
  243. finally:
  244. # Clean up queues
  245. if read_queue:
  246. read_queue.put(None)
  247. if write_queue:
  248. write_queue.put(None)
  249. def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
  250. """
  251. Send a message to the server using the provided HTTP client.
  252. Args:
  253. http_client: The HTTP client to use for sending
  254. endpoint_url: The endpoint URL to send the message to
  255. session_message: The message to send
  256. """
  257. try:
  258. response = http_client.post(
  259. endpoint_url,
  260. json=session_message.message.model_dump(
  261. by_alias=True,
  262. mode="json",
  263. exclude_none=True,
  264. ),
  265. )
  266. response.raise_for_status()
  267. logger.debug("Client message sent successfully: %s", response.status_code)
  268. except Exception:
  269. logger.exception("Error sending message")
  270. raise
  271. def read_messages(
  272. sse_client: SSEClient,
  273. ) -> Generator[SessionMessage | Exception, None, None]:
  274. """
  275. Read messages from the SSE client.
  276. Args:
  277. sse_client: The SSE client to read from
  278. Yields:
  279. SessionMessage or Exception for each event received
  280. """
  281. try:
  282. for sse in sse_client.events():
  283. if sse.event == "message":
  284. try:
  285. message = types.JSONRPCMessage.model_validate_json(sse.data)
  286. logger.debug("Received server message: %s", message)
  287. yield SessionMessage(message)
  288. except Exception as exc:
  289. logger.exception("Error parsing server message")
  290. yield exc
  291. else:
  292. logger.warning("Unknown SSE event: %s", sse.event)
  293. except Exception as exc:
  294. logger.exception("Error reading SSE messages")
  295. yield exc