| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 | 
							- import logging
 - import queue
 - from collections.abc import Generator
 - from concurrent.futures import ThreadPoolExecutor
 - from contextlib import contextmanager
 - from typing import Any, TypeAlias, final
 - from urllib.parse import urljoin, urlparse
 - 
 - import httpx
 - from httpx_sse import EventSource, ServerSentEvent
 - from sseclient import SSEClient
 - 
 - from core.mcp import types
 - from core.mcp.error import MCPAuthError, MCPConnectionError
 - from core.mcp.types import SessionMessage
 - from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
 - 
 - logger = logging.getLogger(__name__)
 - 
 - DEFAULT_QUEUE_READ_TIMEOUT = 3
 - 
 - 
 - @final
 - class _StatusReady:
 -     def __init__(self, endpoint_url: str):
 -         self._endpoint_url = endpoint_url
 - 
 - 
 - @final
 - class _StatusError:
 -     def __init__(self, exc: Exception):
 -         self._exc = exc
 - 
 - 
 - # Type aliases for better readability
 - ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
 - WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
 - StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
 - 
 - 
 - class SSETransport:
 -     """SSE client transport implementation."""
 - 
 -     def __init__(
 -         self,
 -         url: str,
 -         headers: dict[str, Any] | None = None,
 -         timeout: float = 5.0,
 -         sse_read_timeout: float = 5 * 60,
 -     ) -> None:
 -         """Initialize the SSE transport.
 - 
 -         Args:
 -             url: The SSE endpoint URL.
 -             headers: Optional headers to include in requests.
 -             timeout: HTTP timeout for regular operations.
 -             sse_read_timeout: Timeout for SSE read operations.
 -         """
 -         self.url = url
 -         self.headers = headers or {}
 -         self.timeout = timeout
 -         self.sse_read_timeout = sse_read_timeout
 -         self.endpoint_url: str | None = None
 - 
 -     def _validate_endpoint_url(self, endpoint_url: str) -> bool:
 -         """Validate that the endpoint URL matches the connection origin.
 - 
 -         Args:
 -             endpoint_url: The endpoint URL to validate.
 - 
 -         Returns:
 -             True if valid, False otherwise.
 -         """
 -         url_parsed = urlparse(self.url)
 -         endpoint_parsed = urlparse(endpoint_url)
 - 
 -         return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
 - 
 -     def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
 -         """Handle an 'endpoint' SSE event.
 - 
 -         Args:
 -             sse_data: The SSE event data.
 -             status_queue: Queue to put status updates.
 -         """
 -         endpoint_url = urljoin(self.url, sse_data)
 -         logger.info("Received endpoint URL: %s", endpoint_url)
 - 
 -         if not self._validate_endpoint_url(endpoint_url):
 -             error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
 -             logger.error(error_msg)
 -             status_queue.put(_StatusError(ValueError(error_msg)))
 -             return
 - 
 -         status_queue.put(_StatusReady(endpoint_url))
 - 
 -     def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
 -         """Handle a 'message' SSE event.
 - 
 -         Args:
 -             sse_data: The SSE event data.
 -             read_queue: Queue to put parsed messages.
 -         """
 -         try:
 -             message = types.JSONRPCMessage.model_validate_json(sse_data)
 -             logger.debug("Received server message: %s", message)
 -             session_message = SessionMessage(message)
 -             read_queue.put(session_message)
 -         except Exception as exc:
 -             logger.exception("Error parsing server message")
 -             read_queue.put(exc)
 - 
 -     def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
 -         """Handle a single SSE event.
 - 
 -         Args:
 -             sse: The SSE event object.
 -             read_queue: Queue for message events.
 -             status_queue: Queue for status events.
 -         """
 -         match sse.event:
 -             case "endpoint":
 -                 self._handle_endpoint_event(sse.data, status_queue)
 -             case "message":
 -                 self._handle_message_event(sse.data, read_queue)
 -             case _:
 -                 logger.warning("Unknown SSE event: %s", sse.event)
 - 
 -     def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
 -         """Read and process SSE events.
 - 
 -         Args:
 -             event_source: The SSE event source.
 -             read_queue: Queue to put received messages.
 -             status_queue: Queue to put status updates.
 -         """
 -         try:
 -             for sse in event_source.iter_sse():
 -                 self._handle_sse_event(sse, read_queue, status_queue)
 -         except httpx.ReadError as exc:
 -             logger.debug("SSE reader shutting down normally: %s", exc)
 -         except Exception as exc:
 -             read_queue.put(exc)
 -         finally:
 -             read_queue.put(None)
 - 
 -     def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
 -         """Send a single message to the server.
 - 
 -         Args:
 -             client: HTTP client to use.
 -             endpoint_url: The endpoint URL to send to.
 -             message: The message to send.
 -         """
 -         response = client.post(
 -             endpoint_url,
 -             json=message.message.model_dump(
 -                 by_alias=True,
 -                 mode="json",
 -                 exclude_none=True,
 -             ),
 -         )
 -         response.raise_for_status()
 -         logger.debug("Client message sent successfully: %s", response.status_code)
 - 
 -     def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
 -         """Handle writing messages to the server.
 - 
 -         Args:
 -             client: HTTP client to use.
 -             endpoint_url: The endpoint URL to send messages to.
 -             write_queue: Queue to read messages from.
 -         """
 -         try:
 -             while True:
 -                 try:
 -                     message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
 -                     if message is None:
 -                         break
 -                     if isinstance(message, Exception):
 -                         write_queue.put(message)
 -                         continue
 - 
 -                     self._send_message(client, endpoint_url, message)
 - 
 -                 except queue.Empty:
 -                     continue
 -         except httpx.ReadError as exc:
 -             logger.debug("Post writer shutting down normally: %s", exc)
 -         except Exception as exc:
 -             logger.exception("Error writing messages")
 -             write_queue.put(exc)
 -         finally:
 -             write_queue.put(None)
 - 
 -     def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
 -         """Wait for the endpoint URL from the status queue.
 - 
 -         Args:
 -             status_queue: Queue to read status from.
 - 
 -         Returns:
 -             The endpoint URL.
 - 
 -         Raises:
 -             ValueError: If endpoint URL is not received or there's an error.
 -         """
 -         try:
 -             status = status_queue.get(timeout=1)
 -         except queue.Empty:
 -             raise ValueError("failed to get endpoint URL")
 - 
 -         if isinstance(status, _StatusReady):
 -             return status._endpoint_url
 -         elif isinstance(status, _StatusError):
 -             raise status._exc
 -         else:
 -             raise ValueError("failed to get endpoint URL")
 - 
 -     def connect(
 -         self,
 -         executor: ThreadPoolExecutor,
 -         client: httpx.Client,
 -         event_source: EventSource,
 -     ) -> tuple[ReadQueue, WriteQueue]:
 -         """Establish connection and start worker threads.
 - 
 -         Args:
 -             executor: Thread pool executor.
 -             client: HTTP client.
 -             event_source: SSE event source.
 - 
 -         Returns:
 -             Tuple of (read_queue, write_queue).
 -         """
 -         read_queue: ReadQueue = queue.Queue()
 -         write_queue: WriteQueue = queue.Queue()
 -         status_queue: StatusQueue = queue.Queue()
 - 
 -         # Start SSE reader thread
 -         executor.submit(self.sse_reader, event_source, read_queue, status_queue)
 - 
 -         # Wait for endpoint URL
 -         endpoint_url = self._wait_for_endpoint(status_queue)
 -         self.endpoint_url = endpoint_url
 - 
 -         # Start post writer thread
 -         executor.submit(self.post_writer, client, endpoint_url, write_queue)
 - 
 -         return read_queue, write_queue
 - 
 - 
 - @contextmanager
 - def sse_client(
 -     url: str,
 -     headers: dict[str, Any] | None = None,
 -     timeout: float = 5.0,
 -     sse_read_timeout: float = 5 * 60,
 - ) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
 -     """
 -     Client transport for SSE.
 -     `sse_read_timeout` determines how long (in seconds) the client will wait for a new
 -     event before disconnecting. All other HTTP operations are controlled by `timeout`.
 - 
 -     Args:
 -         url: The SSE endpoint URL.
 -         headers: Optional headers to include in requests.
 -         timeout: HTTP timeout for regular operations.
 -         sse_read_timeout: Timeout for SSE read operations.
 - 
 -     Yields:
 -         Tuple of (read_queue, write_queue) for message communication.
 -     """
 -     transport = SSETransport(url, headers, timeout, sse_read_timeout)
 - 
 -     read_queue: ReadQueue | None = None
 -     write_queue: WriteQueue | None = None
 - 
 -     with ThreadPoolExecutor() as executor:
 -         try:
 -             with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
 -                 with ssrf_proxy_sse_connect(
 -                     url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
 -                 ) as event_source:
 -                     event_source.response.raise_for_status()
 - 
 -                     read_queue, write_queue = transport.connect(executor, client, event_source)
 - 
 -                     yield read_queue, write_queue
 - 
 -         except httpx.HTTPStatusError as exc:
 -             if exc.response.status_code == 401:
 -                 raise MCPAuthError()
 -             raise MCPConnectionError()
 -         except Exception:
 -             logger.exception("Error connecting to SSE endpoint")
 -             raise
 -         finally:
 -             # Clean up queues
 -             if read_queue:
 -                 read_queue.put(None)
 -             if write_queue:
 -                 write_queue.put(None)
 - 
 - 
 - def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
 -     """
 -     Send a message to the server using the provided HTTP client.
 - 
 -     Args:
 -         http_client: The HTTP client to use for sending
 -         endpoint_url: The endpoint URL to send the message to
 -         session_message: The message to send
 -     """
 -     try:
 -         response = http_client.post(
 -             endpoint_url,
 -             json=session_message.message.model_dump(
 -                 by_alias=True,
 -                 mode="json",
 -                 exclude_none=True,
 -             ),
 -         )
 -         response.raise_for_status()
 -         logger.debug("Client message sent successfully: %s", response.status_code)
 -     except Exception:
 -         logger.exception("Error sending message")
 -         raise
 - 
 - 
 - def read_messages(
 -     sse_client: SSEClient,
 - ) -> Generator[SessionMessage | Exception, None, None]:
 -     """
 -     Read messages from the SSE client.
 - 
 -     Args:
 -         sse_client: The SSE client to read from
 - 
 -     Yields:
 -         SessionMessage or Exception for each event received
 -     """
 -     try:
 -         for sse in sse_client.events():
 -             if sse.event == "message":
 -                 try:
 -                     message = types.JSONRPCMessage.model_validate_json(sse.data)
 -                     logger.debug("Received server message: %s", message)
 -                     yield SessionMessage(message)
 -                 except Exception as exc:
 -                     logger.exception("Error parsing server message")
 -                     yield exc
 -             else:
 -                 logger.warning("Unknown SSE event: %s", sse.event)
 -     except Exception as exc:
 -         logger.exception("Error reading SSE messages")
 -         yield exc
 
 
  |