You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sse_client.py 12KB


  1. import logging
  2. import queue
  3. from collections.abc import Generator
  4. from concurrent.futures import ThreadPoolExecutor
  5. from contextlib import contextmanager
  6. from typing import Any, TypeAlias, final
  7. from urllib.parse import urljoin, urlparse
  8. import httpx
  9. from sseclient import SSEClient
  10. from core.mcp import types
  11. from core.mcp.error import MCPAuthError, MCPConnectionError
  12. from core.mcp.types import SessionMessage
  13. from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
  14. logger = logging.getLogger(__name__)
  15. DEFAULT_QUEUE_READ_TIMEOUT = 3
  16. @final
  17. class _StatusReady:
  18. def __init__(self, endpoint_url: str):
  19. self._endpoint_url = endpoint_url
  20. @final
  21. class _StatusError:
  22. def __init__(self, exc: Exception):
  23. self._exc = exc
  24. # Type aliases for better readability
  25. ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
  26. WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
  27. StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
  28. def remove_request_params(url: str) -> str:
  29. """Remove request parameters from URL, keeping only the path."""
  30. return urljoin(url, urlparse(url).path)
  31. class SSETransport:
  32. """SSE client transport implementation."""
  33. def __init__(
  34. self,
  35. url: str,
  36. headers: dict[str, Any] | None = None,
  37. timeout: float = 5.0,
  38. sse_read_timeout: float = 5 * 60,
  39. ) -> None:
  40. """Initialize the SSE transport.
  41. Args:
  42. url: The SSE endpoint URL.
  43. headers: Optional headers to include in requests.
  44. timeout: HTTP timeout for regular operations.
  45. sse_read_timeout: Timeout for SSE read operations.
  46. """
  47. self.url = url
  48. self.headers = headers or {}
  49. self.timeout = timeout
  50. self.sse_read_timeout = sse_read_timeout
  51. self.endpoint_url: str | None = None
  52. def _validate_endpoint_url(self, endpoint_url: str) -> bool:
  53. """Validate that the endpoint URL matches the connection origin.
  54. Args:
  55. endpoint_url: The endpoint URL to validate.
  56. Returns:
  57. True if valid, False otherwise.
  58. """
  59. url_parsed = urlparse(self.url)
  60. endpoint_parsed = urlparse(endpoint_url)
  61. return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
  62. def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
  63. """Handle an 'endpoint' SSE event.
  64. Args:
  65. sse_data: The SSE event data.
  66. status_queue: Queue to put status updates.
  67. """
  68. endpoint_url = urljoin(self.url, sse_data)
  69. logger.info(f"Received endpoint URL: {endpoint_url}")
  70. if not self._validate_endpoint_url(endpoint_url):
  71. error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
  72. logger.error(error_msg)
  73. status_queue.put(_StatusError(ValueError(error_msg)))
  74. return
  75. status_queue.put(_StatusReady(endpoint_url))
  76. def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
  77. """Handle a 'message' SSE event.
  78. Args:
  79. sse_data: The SSE event data.
  80. read_queue: Queue to put parsed messages.
  81. """
  82. try:
  83. message = types.JSONRPCMessage.model_validate_json(sse_data)
  84. logger.debug(f"Received server message: {message}")
  85. session_message = SessionMessage(message)
  86. read_queue.put(session_message)
  87. except Exception as exc:
  88. logger.exception("Error parsing server message")
  89. read_queue.put(exc)
  90. def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
  91. """Handle a single SSE event.
  92. Args:
  93. sse: The SSE event object.
  94. read_queue: Queue for message events.
  95. status_queue: Queue for status events.
  96. """
  97. match sse.event:
  98. case "endpoint":
  99. self._handle_endpoint_event(sse.data, status_queue)
  100. case "message":
  101. self._handle_message_event(sse.data, read_queue)
  102. case _:
  103. logger.warning(f"Unknown SSE event: {sse.event}")
  104. def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
  105. """Read and process SSE events.
  106. Args:
  107. event_source: The SSE event source.
  108. read_queue: Queue to put received messages.
  109. status_queue: Queue to put status updates.
  110. """
  111. try:
  112. for sse in event_source.iter_sse():
  113. self._handle_sse_event(sse, read_queue, status_queue)
  114. except httpx.ReadError as exc:
  115. logger.debug(f"SSE reader shutting down normally: {exc}")
  116. except Exception as exc:
  117. read_queue.put(exc)
  118. finally:
  119. read_queue.put(None)
  120. def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
  121. """Send a single message to the server.
  122. Args:
  123. client: HTTP client to use.
  124. endpoint_url: The endpoint URL to send to.
  125. message: The message to send.
  126. """
  127. response = client.post(
  128. endpoint_url,
  129. json=message.message.model_dump(
  130. by_alias=True,
  131. mode="json",
  132. exclude_none=True,
  133. ),
  134. )
  135. response.raise_for_status()
  136. logger.debug(f"Client message sent successfully: {response.status_code}")
  137. def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
  138. """Handle writing messages to the server.
  139. Args:
  140. client: HTTP client to use.
  141. endpoint_url: The endpoint URL to send messages to.
  142. write_queue: Queue to read messages from.
  143. """
  144. try:
  145. while True:
  146. try:
  147. message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
  148. if message is None:
  149. break
  150. if isinstance(message, Exception):
  151. write_queue.put(message)
  152. continue
  153. self._send_message(client, endpoint_url, message)
  154. except queue.Empty:
  155. continue
  156. except httpx.ReadError as exc:
  157. logger.debug(f"Post writer shutting down normally: {exc}")
  158. except Exception as exc:
  159. logger.exception("Error writing messages")
  160. write_queue.put(exc)
  161. finally:
  162. write_queue.put(None)
  163. def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
  164. """Wait for the endpoint URL from the status queue.
  165. Args:
  166. status_queue: Queue to read status from.
  167. Returns:
  168. The endpoint URL.
  169. Raises:
  170. ValueError: If endpoint URL is not received or there's an error.
  171. """
  172. try:
  173. status = status_queue.get(timeout=1)
  174. except queue.Empty:
  175. raise ValueError("failed to get endpoint URL")
  176. if isinstance(status, _StatusReady):
  177. return status._endpoint_url
  178. elif isinstance(status, _StatusError):
  179. raise status._exc
  180. else:
  181. raise ValueError("failed to get endpoint URL")
  182. def connect(
  183. self,
  184. executor: ThreadPoolExecutor,
  185. client: httpx.Client,
  186. event_source,
  187. ) -> tuple[ReadQueue, WriteQueue]:
  188. """Establish connection and start worker threads.
  189. Args:
  190. executor: Thread pool executor.
  191. client: HTTP client.
  192. event_source: SSE event source.
  193. Returns:
  194. Tuple of (read_queue, write_queue).
  195. """
  196. read_queue: ReadQueue = queue.Queue()
  197. write_queue: WriteQueue = queue.Queue()
  198. status_queue: StatusQueue = queue.Queue()
  199. # Start SSE reader thread
  200. executor.submit(self.sse_reader, event_source, read_queue, status_queue)
  201. # Wait for endpoint URL
  202. endpoint_url = self._wait_for_endpoint(status_queue)
  203. self.endpoint_url = endpoint_url
  204. # Start post writer thread
  205. executor.submit(self.post_writer, client, endpoint_url, write_queue)
  206. return read_queue, write_queue
  207. @contextmanager
  208. def sse_client(
  209. url: str,
  210. headers: dict[str, Any] | None = None,
  211. timeout: float = 5.0,
  212. sse_read_timeout: float = 5 * 60,
  213. ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
  214. """
  215. Client transport for SSE.
  216. `sse_read_timeout` determines how long (in seconds) the client will wait for a new
  217. event before disconnecting. All other HTTP operations are controlled by `timeout`.
  218. Args:
  219. url: The SSE endpoint URL.
  220. headers: Optional headers to include in requests.
  221. timeout: HTTP timeout for regular operations.
  222. sse_read_timeout: Timeout for SSE read operations.
  223. Yields:
  224. Tuple of (read_queue, write_queue) for message communication.
  225. """
  226. transport = SSETransport(url, headers, timeout, sse_read_timeout)
  227. read_queue: ReadQueue | None = None
  228. write_queue: WriteQueue | None = None
  229. with ThreadPoolExecutor() as executor:
  230. try:
  231. with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
  232. with ssrf_proxy_sse_connect(
  233. url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
  234. ) as event_source:
  235. event_source.response.raise_for_status()
  236. read_queue, write_queue = transport.connect(executor, client, event_source)
  237. yield read_queue, write_queue
  238. except httpx.HTTPStatusError as exc:
  239. if exc.response.status_code == 401:
  240. raise MCPAuthError()
  241. raise MCPConnectionError()
  242. except Exception:
  243. logger.exception("Error connecting to SSE endpoint")
  244. raise
  245. finally:
  246. # Clean up queues
  247. if read_queue:
  248. read_queue.put(None)
  249. if write_queue:
  250. write_queue.put(None)
  251. def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
  252. """
  253. Send a message to the server using the provided HTTP client.
  254. Args:
  255. http_client: The HTTP client to use for sending
  256. endpoint_url: The endpoint URL to send the message to
  257. session_message: The message to send
  258. """
  259. try:
  260. response = http_client.post(
  261. endpoint_url,
  262. json=session_message.message.model_dump(
  263. by_alias=True,
  264. mode="json",
  265. exclude_none=True,
  266. ),
  267. )
  268. response.raise_for_status()
  269. logger.debug(f"Client message sent successfully: {response.status_code}")
  270. except Exception as exc:
  271. logger.exception("Error sending message")
  272. raise
  273. def read_messages(
  274. sse_client: SSEClient,
  275. ) -> Generator[SessionMessage | Exception, None, None]:
  276. """
  277. Read messages from the SSE client.
  278. Args:
  279. sse_client: The SSE client to read from
  280. Yields:
  281. SessionMessage or Exception for each event received
  282. """
  283. try:
  284. for sse in sse_client.events():
  285. if sse.event == "message":
  286. try:
  287. message = types.JSONRPCMessage.model_validate_json(sse.data)
  288. logger.debug(f"Received server message: {message}")
  289. yield SessionMessage(message)
  290. except Exception as exc:
  291. logger.exception("Error parsing server message")
  292. yield exc
  293. else:
  294. logger.warning(f"Unknown SSE event: {sse.event}")
  295. except Exception as exc:
  296. logger.exception("Error reading SSE messages")
  297. yield exc