| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361 |
- import logging
- import queue
- from collections.abc import Generator
- from concurrent.futures import ThreadPoolExecutor
- from contextlib import contextmanager
- from typing import Any, TypeAlias, final
- from urllib.parse import urljoin, urlparse
-
- import httpx
- from sseclient import SSEClient
-
- from core.mcp import types
- from core.mcp.error import MCPAuthError, MCPConnectionError
- from core.mcp.types import SessionMessage
- from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
-
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Timeout, in seconds, for each blocking read of the write queue in
# SSETransport.post_writer; when it expires the writer simply polls again.
DEFAULT_QUEUE_READ_TIMEOUT = 3
-
-
- @final
- class _StatusReady:
- def __init__(self, endpoint_url: str):
- self._endpoint_url = endpoint_url
-
-
- @final
- class _StatusError:
- def __init__(self, exc: Exception):
- self._exc = exc
-
-
# Type aliases for the queues that connect the transport's worker threads.
# ReadQueue carries what the SSE reader produces: parsed messages, exceptions
# raised while reading or parsing, and None once the reader has stopped.
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
# WriteQueue carries outgoing messages for the POST writer; None tells the
# writer to stop.
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
# StatusQueue reports the outcome of endpoint discovery to SSETransport.connect.
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
-
-
def remove_request_params(url: str) -> str:
    """Remove request parameters from URL, keeping only the path.

    The scheme, network location and path are preserved as written; the
    ``;params``, ``?query`` and ``#fragment`` components are dropped.

    The URL is rebuilt from its parsed components rather than joined with
    its own path, because ``urljoin`` returns the base unchanged for an
    empty path (leaving the query in place) and interprets a path that
    starts with ``//`` as a new network location.

    Args:
        url: The URL to strip.

    Returns:
        The URL without parameters, query string or fragment.
    """
    return urlparse(url)._replace(params="", query="", fragment="").geturl()
-
-
- class SSETransport:
- """SSE client transport implementation."""
-
- def __init__(
- self,
- url: str,
- headers: dict[str, Any] | None = None,
- timeout: float = 5.0,
- sse_read_timeout: float = 5 * 60,
- ) -> None:
- """Initialize the SSE transport.
-
- Args:
- url: The SSE endpoint URL.
- headers: Optional headers to include in requests.
- timeout: HTTP timeout for regular operations.
- sse_read_timeout: Timeout for SSE read operations.
- """
- self.url = url
- self.headers = headers or {}
- self.timeout = timeout
- self.sse_read_timeout = sse_read_timeout
- self.endpoint_url: str | None = None
-
- def _validate_endpoint_url(self, endpoint_url: str) -> bool:
- """Validate that the endpoint URL matches the connection origin.
-
- Args:
- endpoint_url: The endpoint URL to validate.
-
- Returns:
- True if valid, False otherwise.
- """
- url_parsed = urlparse(self.url)
- endpoint_parsed = urlparse(endpoint_url)
-
- return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
-
- def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
- """Handle an 'endpoint' SSE event.
-
- Args:
- sse_data: The SSE event data.
- status_queue: Queue to put status updates.
- """
- endpoint_url = urljoin(self.url, sse_data)
- logger.info("Received endpoint URL: %s", endpoint_url)
-
- if not self._validate_endpoint_url(endpoint_url):
- error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
- logger.error(error_msg)
- status_queue.put(_StatusError(ValueError(error_msg)))
- return
-
- status_queue.put(_StatusReady(endpoint_url))
-
- def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
- """Handle a 'message' SSE event.
-
- Args:
- sse_data: The SSE event data.
- read_queue: Queue to put parsed messages.
- """
- try:
- message = types.JSONRPCMessage.model_validate_json(sse_data)
- logger.debug("Received server message: %s", message)
- session_message = SessionMessage(message)
- read_queue.put(session_message)
- except Exception as exc:
- logger.exception("Error parsing server message")
- read_queue.put(exc)
-
- def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
- """Handle a single SSE event.
-
- Args:
- sse: The SSE event object.
- read_queue: Queue for message events.
- status_queue: Queue for status events.
- """
- match sse.event:
- case "endpoint":
- self._handle_endpoint_event(sse.data, status_queue)
- case "message":
- self._handle_message_event(sse.data, read_queue)
- case _:
- logger.warning("Unknown SSE event: %s", sse.event)
-
- def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
- """Read and process SSE events.
-
- Args:
- event_source: The SSE event source.
- read_queue: Queue to put received messages.
- status_queue: Queue to put status updates.
- """
- try:
- for sse in event_source.iter_sse():
- self._handle_sse_event(sse, read_queue, status_queue)
- except httpx.ReadError as exc:
- logger.debug("SSE reader shutting down normally: %s", exc)
- except Exception as exc:
- read_queue.put(exc)
- finally:
- read_queue.put(None)
-
- def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
- """Send a single message to the server.
-
- Args:
- client: HTTP client to use.
- endpoint_url: The endpoint URL to send to.
- message: The message to send.
- """
- response = client.post(
- endpoint_url,
- json=message.message.model_dump(
- by_alias=True,
- mode="json",
- exclude_none=True,
- ),
- )
- response.raise_for_status()
- logger.debug("Client message sent successfully: %s", response.status_code)
-
- def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
- """Handle writing messages to the server.
-
- Args:
- client: HTTP client to use.
- endpoint_url: The endpoint URL to send messages to.
- write_queue: Queue to read messages from.
- """
- try:
- while True:
- try:
- message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
- if message is None:
- break
- if isinstance(message, Exception):
- write_queue.put(message)
- continue
-
- self._send_message(client, endpoint_url, message)
-
- except queue.Empty:
- continue
- except httpx.ReadError as exc:
- logger.debug("Post writer shutting down normally: %s", exc)
- except Exception as exc:
- logger.exception("Error writing messages")
- write_queue.put(exc)
- finally:
- write_queue.put(None)
-
- def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
- """Wait for the endpoint URL from the status queue.
-
- Args:
- status_queue: Queue to read status from.
-
- Returns:
- The endpoint URL.
-
- Raises:
- ValueError: If endpoint URL is not received or there's an error.
- """
- try:
- status = status_queue.get(timeout=1)
- except queue.Empty:
- raise ValueError("failed to get endpoint URL")
-
- if isinstance(status, _StatusReady):
- return status._endpoint_url
- elif isinstance(status, _StatusError):
- raise status._exc
- else:
- raise ValueError("failed to get endpoint URL")
-
- def connect(
- self,
- executor: ThreadPoolExecutor,
- client: httpx.Client,
- event_source,
- ) -> tuple[ReadQueue, WriteQueue]:
- """Establish connection and start worker threads.
-
- Args:
- executor: Thread pool executor.
- client: HTTP client.
- event_source: SSE event source.
-
- Returns:
- Tuple of (read_queue, write_queue).
- """
- read_queue: ReadQueue = queue.Queue()
- write_queue: WriteQueue = queue.Queue()
- status_queue: StatusQueue = queue.Queue()
-
- # Start SSE reader thread
- executor.submit(self.sse_reader, event_source, read_queue, status_queue)
-
- # Wait for endpoint URL
- endpoint_url = self._wait_for_endpoint(status_queue)
- self.endpoint_url = endpoint_url
-
- # Start post writer thread
- executor.submit(self.post_writer, client, endpoint_url, write_queue)
-
- return read_queue, write_queue
-
-
@contextmanager
def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5.0,
    sse_read_timeout: float = 5 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
    """
    Client transport for SSE.
    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: The SSE endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations.
        sse_read_timeout: Timeout for SSE read operations.

    Yields:
        Tuple of (read_queue, write_queue) for message communication.

    Raises:
        MCPAuthError: If an HTTP status error with code 401 is raised.
        MCPConnectionError: If any other HTTP status error is raised.
    """
    transport = SSETransport(url, headers, timeout, sse_read_timeout)

    read_queue: ReadQueue | None = None
    write_queue: WriteQueue | None = None

    with ThreadPoolExecutor() as executor:
        try:
            with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
                with ssrf_proxy_sse_connect(
                    url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
                ) as event_source:
                    event_source.response.raise_for_status()

                    read_queue, write_queue = transport.connect(executor, client, event_source)

                    yield read_queue, write_queue

        except httpx.HTTPStatusError as exc:
            # Chain explicitly so the HTTP error is recorded as the cause of
            # the translated exception.
            if exc.response.status_code == 401:
                raise MCPAuthError() from exc
            raise MCPConnectionError() from exc
        except Exception:
            logger.exception("Error connecting to SSE endpoint")
            raise
        finally:
            # Put stop sentinels on whichever queues were created, before the
            # executor's context exit waits for the worker threads.
            if read_queue is not None:
                read_queue.put(None)
            if write_queue is not None:
                write_queue.put(None)
-
-
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
    """
    POST a single session message to the server as JSON.

    Any failure while serialising, posting or checking the response status
    is logged and then re-raised to the caller.

    Args:
        http_client: The HTTP client to use for sending
        endpoint_url: The endpoint URL to send the message to
        session_message: The message to send
    """
    try:
        payload = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
        response = http_client.post(endpoint_url, json=payload)
        response.raise_for_status()
        logger.debug("Client message sent successfully: %s", response.status_code)
    except Exception:
        logger.exception("Error sending message")
        raise
-
-
def read_messages(
    sse_client: SSEClient,
) -> Generator[SessionMessage | Exception, None, None]:
    """
    Iterate over the SSE client's events and yield the parsed messages.

    Only "message" events are parsed; any other event type is logged as a
    warning and skipped. Errors are yielded rather than raised: a parse
    failure yields the exception and iteration continues, while a failure
    of the event stream itself yields the exception and ends the generator.

    Args:
        sse_client: The SSE client to read from

    Yields:
        SessionMessage or Exception for each event received
    """
    try:
        for event in sse_client.events():
            if event.event != "message":
                logger.warning("Unknown SSE event: %s", event.event)
                continue

            try:
                parsed = types.JSONRPCMessage.model_validate_json(event.data)
                logger.debug("Received server message: %s", parsed)
                yield SessionMessage(parsed)
            except Exception as parse_error:
                logger.exception("Error parsing server message")
                yield parse_error
    except Exception as stream_error:
        logger.exception("Error reading SSE messages")
        yield stream_error
|