您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

sse_client.py 12KB


  1. import logging
  2. import queue
  3. from collections.abc import Generator
  4. from concurrent.futures import ThreadPoolExecutor
  5. from contextlib import contextmanager
  6. from typing import Any, TypeAlias, final
  7. from urllib.parse import urljoin, urlparse
  8. import httpx
  9. from httpx_sse import EventSource, ServerSentEvent
  10. from sseclient import SSEClient
  11. from core.mcp import types
  12. from core.mcp.error import MCPAuthError, MCPConnectionError
  13. from core.mcp.types import SessionMessage
  14. from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
  15. logger = logging.getLogger(__name__)
  16. DEFAULT_QUEUE_READ_TIMEOUT = 3
  17. @final
  18. class _StatusReady:
  19. def __init__(self, endpoint_url: str):
  20. self._endpoint_url = endpoint_url
  21. @final
  22. class _StatusError:
  23. def __init__(self, exc: Exception):
  24. self._exc = exc
# Type aliases for better readability.
# Items produced by `SSETransport.sse_reader`: parsed server messages, exceptions raised
# while reading or parsing, and a final `None` once the reader stops.
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
# Items consumed by `SSETransport.post_writer`: messages to POST; `None` makes the writer stop.
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
# Carries the reader's verdict on the server's 'endpoint' event (accepted URL or error);
# `SSETransport.connect` waits on it before starting the writer.
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
  29. def remove_request_params(url: str) -> str:
  30. """Remove request parameters from URL, keeping only the path."""
  31. return urljoin(url, urlparse(url).path)
  32. class SSETransport:
  33. """SSE client transport implementation."""
  34. def __init__(
  35. self,
  36. url: str,
  37. headers: dict[str, Any] | None = None,
  38. timeout: float = 5.0,
  39. sse_read_timeout: float = 5 * 60,
  40. ) -> None:
  41. """Initialize the SSE transport.
  42. Args:
  43. url: The SSE endpoint URL.
  44. headers: Optional headers to include in requests.
  45. timeout: HTTP timeout for regular operations.
  46. sse_read_timeout: Timeout for SSE read operations.
  47. """
  48. self.url = url
  49. self.headers = headers or {}
  50. self.timeout = timeout
  51. self.sse_read_timeout = sse_read_timeout
  52. self.endpoint_url: str | None = None
  53. def _validate_endpoint_url(self, endpoint_url: str) -> bool:
  54. """Validate that the endpoint URL matches the connection origin.
  55. Args:
  56. endpoint_url: The endpoint URL to validate.
  57. Returns:
  58. True if valid, False otherwise.
  59. """
  60. url_parsed = urlparse(self.url)
  61. endpoint_parsed = urlparse(endpoint_url)
  62. return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
  63. def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
  64. """Handle an 'endpoint' SSE event.
  65. Args:
  66. sse_data: The SSE event data.
  67. status_queue: Queue to put status updates.
  68. """
  69. endpoint_url = urljoin(self.url, sse_data)
  70. logger.info("Received endpoint URL: %s", endpoint_url)
  71. if not self._validate_endpoint_url(endpoint_url):
  72. error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
  73. logger.error(error_msg)
  74. status_queue.put(_StatusError(ValueError(error_msg)))
  75. return
  76. status_queue.put(_StatusReady(endpoint_url))
  77. def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
  78. """Handle a 'message' SSE event.
  79. Args:
  80. sse_data: The SSE event data.
  81. read_queue: Queue to put parsed messages.
  82. """
  83. try:
  84. message = types.JSONRPCMessage.model_validate_json(sse_data)
  85. logger.debug("Received server message: %s", message)
  86. session_message = SessionMessage(message)
  87. read_queue.put(session_message)
  88. except Exception as exc:
  89. logger.exception("Error parsing server message")
  90. read_queue.put(exc)
  91. def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
  92. """Handle a single SSE event.
  93. Args:
  94. sse: The SSE event object.
  95. read_queue: Queue for message events.
  96. status_queue: Queue for status events.
  97. """
  98. match sse.event:
  99. case "endpoint":
  100. self._handle_endpoint_event(sse.data, status_queue)
  101. case "message":
  102. self._handle_message_event(sse.data, read_queue)
  103. case _:
  104. logger.warning("Unknown SSE event: %s", sse.event)
  105. def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
  106. """Read and process SSE events.
  107. Args:
  108. event_source: The SSE event source.
  109. read_queue: Queue to put received messages.
  110. status_queue: Queue to put status updates.
  111. """
  112. try:
  113. for sse in event_source.iter_sse():
  114. self._handle_sse_event(sse, read_queue, status_queue)
  115. except httpx.ReadError as exc:
  116. logger.debug("SSE reader shutting down normally: %s", exc)
  117. except Exception as exc:
  118. read_queue.put(exc)
  119. finally:
  120. read_queue.put(None)
  121. def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
  122. """Send a single message to the server.
  123. Args:
  124. client: HTTP client to use.
  125. endpoint_url: The endpoint URL to send to.
  126. message: The message to send.
  127. """
  128. response = client.post(
  129. endpoint_url,
  130. json=message.message.model_dump(
  131. by_alias=True,
  132. mode="json",
  133. exclude_none=True,
  134. ),
  135. )
  136. response.raise_for_status()
  137. logger.debug("Client message sent successfully: %s", response.status_code)
  138. def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
  139. """Handle writing messages to the server.
  140. Args:
  141. client: HTTP client to use.
  142. endpoint_url: The endpoint URL to send messages to.
  143. write_queue: Queue to read messages from.
  144. """
  145. try:
  146. while True:
  147. try:
  148. message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
  149. if message is None:
  150. break
  151. if isinstance(message, Exception):
  152. write_queue.put(message)
  153. continue
  154. self._send_message(client, endpoint_url, message)
  155. except queue.Empty:
  156. continue
  157. except httpx.ReadError as exc:
  158. logger.debug("Post writer shutting down normally: %s", exc)
  159. except Exception as exc:
  160. logger.exception("Error writing messages")
  161. write_queue.put(exc)
  162. finally:
  163. write_queue.put(None)
  164. def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
  165. """Wait for the endpoint URL from the status queue.
  166. Args:
  167. status_queue: Queue to read status from.
  168. Returns:
  169. The endpoint URL.
  170. Raises:
  171. ValueError: If endpoint URL is not received or there's an error.
  172. """
  173. try:
  174. status = status_queue.get(timeout=1)
  175. except queue.Empty:
  176. raise ValueError("failed to get endpoint URL")
  177. if isinstance(status, _StatusReady):
  178. return status._endpoint_url
  179. elif isinstance(status, _StatusError):
  180. raise status._exc
  181. else:
  182. raise ValueError("failed to get endpoint URL")
  183. def connect(
  184. self,
  185. executor: ThreadPoolExecutor,
  186. client: httpx.Client,
  187. event_source: EventSource,
  188. ) -> tuple[ReadQueue, WriteQueue]:
  189. """Establish connection and start worker threads.
  190. Args:
  191. executor: Thread pool executor.
  192. client: HTTP client.
  193. event_source: SSE event source.
  194. Returns:
  195. Tuple of (read_queue, write_queue).
  196. """
  197. read_queue: ReadQueue = queue.Queue()
  198. write_queue: WriteQueue = queue.Queue()
  199. status_queue: StatusQueue = queue.Queue()
  200. # Start SSE reader thread
  201. executor.submit(self.sse_reader, event_source, read_queue, status_queue)
  202. # Wait for endpoint URL
  203. endpoint_url = self._wait_for_endpoint(status_queue)
  204. self.endpoint_url = endpoint_url
  205. # Start post writer thread
  206. executor.submit(self.post_writer, client, endpoint_url, write_queue)
  207. return read_queue, write_queue
  208. @contextmanager
  209. def sse_client(
  210. url: str,
  211. headers: dict[str, Any] | None = None,
  212. timeout: float = 5.0,
  213. sse_read_timeout: float = 5 * 60,
  214. ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
  215. """
  216. Client transport for SSE.
  217. `sse_read_timeout` determines how long (in seconds) the client will wait for a new
  218. event before disconnecting. All other HTTP operations are controlled by `timeout`.
  219. Args:
  220. url: The SSE endpoint URL.
  221. headers: Optional headers to include in requests.
  222. timeout: HTTP timeout for regular operations.
  223. sse_read_timeout: Timeout for SSE read operations.
  224. Yields:
  225. Tuple of (read_queue, write_queue) for message communication.
  226. """
  227. transport = SSETransport(url, headers, timeout, sse_read_timeout)
  228. read_queue: ReadQueue | None = None
  229. write_queue: WriteQueue | None = None
  230. with ThreadPoolExecutor() as executor:
  231. try:
  232. with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
  233. with ssrf_proxy_sse_connect(
  234. url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
  235. ) as event_source:
  236. event_source.response.raise_for_status()
  237. read_queue, write_queue = transport.connect(executor, client, event_source)
  238. yield read_queue, write_queue
  239. except httpx.HTTPStatusError as exc:
  240. if exc.response.status_code == 401:
  241. raise MCPAuthError()
  242. raise MCPConnectionError()
  243. except Exception:
  244. logger.exception("Error connecting to SSE endpoint")
  245. raise
  246. finally:
  247. # Clean up queues
  248. if read_queue:
  249. read_queue.put(None)
  250. if write_queue:
  251. write_queue.put(None)
  252. def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
  253. """
  254. Send a message to the server using the provided HTTP client.
  255. Args:
  256. http_client: The HTTP client to use for sending
  257. endpoint_url: The endpoint URL to send the message to
  258. session_message: The message to send
  259. """
  260. try:
  261. response = http_client.post(
  262. endpoint_url,
  263. json=session_message.message.model_dump(
  264. by_alias=True,
  265. mode="json",
  266. exclude_none=True,
  267. ),
  268. )
  269. response.raise_for_status()
  270. logger.debug("Client message sent successfully: %s", response.status_code)
  271. except Exception:
  272. logger.exception("Error sending message")
  273. raise
  274. def read_messages(
  275. sse_client: SSEClient,
  276. ) -> Generator[SessionMessage | Exception, None, None]:
  277. """
  278. Read messages from the SSE client.
  279. Args:
  280. sse_client: The SSE client to read from
  281. Yields:
  282. SessionMessage or Exception for each event received
  283. """
  284. try:
  285. for sse in sse_client.events():
  286. if sse.event == "message":
  287. try:
  288. message = types.JSONRPCMessage.model_validate_json(sse.data)
  289. logger.debug("Received server message: %s", message)
  290. yield SessionMessage(message)
  291. except Exception as exc:
  292. logger.exception("Error parsing server message")
  293. yield exc
  294. else:
  295. logger.warning("Unknown SSE event: %s", sse.event)
  296. except Exception as exc:
  297. logger.exception("Error reading SSE messages")
  298. yield exc