| 
                        123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 | 
                        - """
 - StreamableHTTP Client Transport Module
 - 
 - This module implements the StreamableHTTP transport for MCP clients,
 - providing support for HTTP POST requests with optional SSE streaming responses
 - and session management.
 - """
 - 
 - import logging
 - import queue
 - from collections.abc import Callable, Generator
 - from concurrent.futures import ThreadPoolExecutor
 - from contextlib import contextmanager
 - from dataclasses import dataclass
 - from datetime import timedelta
 - from typing import Any, cast
 - 
 - import httpx
 - from httpx_sse import EventSource, ServerSentEvent
 - 
 - from core.mcp.types import (
 -     ClientMessageMetadata,
 -     ErrorData,
 -     JSONRPCError,
 -     JSONRPCMessage,
 -     JSONRPCNotification,
 -     JSONRPCRequest,
 -     JSONRPCResponse,
 -     RequestId,
 -     SessionMessage,
 - )
 - from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
 - 
 - logger = logging.getLogger(__name__)
 - 
 - 
 - SessionMessageOrError = SessionMessage | Exception | None
 - # Queue types with clearer names for their roles
 - ServerToClientQueue = queue.Queue[SessionMessageOrError]  # Server to client messages
 - ClientToServerQueue = queue.Queue[SessionMessage | None]  # Client to server messages
 - GetSessionIdCallback = Callable[[], str | None]
 - 
 - MCP_SESSION_ID = "mcp-session-id"
 - LAST_EVENT_ID = "last-event-id"
 - CONTENT_TYPE = "content-type"
 - ACCEPT = "Accept"
 - 
 - 
 - JSON = "application/json"
 - SSE = "text/event-stream"
 - 
 - DEFAULT_QUEUE_READ_TIMEOUT = 3
 - 
 - 
 - class StreamableHTTPError(Exception):
 -     """Base exception for StreamableHTTP transport errors."""
 - 
 -     pass
 - 
 - 
 - class ResumptionError(StreamableHTTPError):
 -     """Raised when resumption request is invalid."""
 - 
 -     pass
 - 
 - 
 - @dataclass
 - class RequestContext:
 -     """Context for a request operation."""
 - 
 -     client: httpx.Client
 -     headers: dict[str, str]
 -     session_id: str | None
 -     session_message: SessionMessage
 -     metadata: ClientMessageMetadata | None
 -     server_to_client_queue: ServerToClientQueue  # Renamed for clarity
 -     sse_read_timeout: timedelta
 - 
 - 
 - class StreamableHTTPTransport:
 -     """StreamableHTTP client transport implementation."""
 - 
 -     def __init__(
 -         self,
 -         url: str,
 -         headers: dict[str, Any] | None = None,
 -         timeout: timedelta = timedelta(seconds=30),
 -         sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
 -     ) -> None:
 -         """Initialize the StreamableHTTP transport.
 - 
 -         Args:
 -             url: The endpoint URL.
 -             headers: Optional headers to include in requests.
 -             timeout: HTTP timeout for regular operations.
 -             sse_read_timeout: Timeout for SSE read operations.
 -         """
 -         self.url = url
 -         self.headers = headers or {}
 -         self.timeout = timeout
 -         self.sse_read_timeout = sse_read_timeout
 -         self.session_id: str | None = None
 -         self.request_headers = {
 -             ACCEPT: f"{JSON}, {SSE}",
 -             CONTENT_TYPE: JSON,
 -             **self.headers,
 -         }
 - 
 -     def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
 -         """Update headers with session ID if available."""
 -         headers = base_headers.copy()
 -         if self.session_id:
 -             headers[MCP_SESSION_ID] = self.session_id
 -         return headers
 - 
 -     def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
 -         """Check if the message is an initialization request."""
 -         return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
 - 
 -     def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
 -         """Check if the message is an initialized notification."""
 -         return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
 - 
 -     def _maybe_extract_session_id_from_response(
 -         self,
 -         response: httpx.Response,
 -     ) -> None:
 -         """Extract and store session ID from response headers."""
 -         new_session_id = response.headers.get(MCP_SESSION_ID)
 -         if new_session_id:
 -             self.session_id = new_session_id
 -             logger.info(f"Received session ID: {self.session_id}")
 - 
 -     def _handle_sse_event(
 -         self,
 -         sse: ServerSentEvent,
 -         server_to_client_queue: ServerToClientQueue,
 -         original_request_id: RequestId | None = None,
 -         resumption_callback: Callable[[str], None] | None = None,
 -     ) -> bool:
 -         """Handle an SSE event, returning True if the response is complete."""
 -         if sse.event == "message":
 -             try:
 -                 message = JSONRPCMessage.model_validate_json(sse.data)
 -                 logger.debug(f"SSE message: {message}")
 - 
 -                 # If this is a response and we have original_request_id, replace it
 -                 if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
 -                     message.root.id = original_request_id
 - 
 -                 session_message = SessionMessage(message)
 -                 # Put message in queue that goes to client
 -                 server_to_client_queue.put(session_message)
 - 
 -                 # Call resumption token callback if we have an ID
 -                 if sse.id and resumption_callback:
 -                     resumption_callback(sse.id)
 - 
 -                 # If this is a response or error return True indicating completion
 -                 # Otherwise, return False to continue listening
 -                 return isinstance(message.root, JSONRPCResponse | JSONRPCError)
 - 
 -             except Exception as exc:
 -                 # Put exception in queue that goes to client
 -                 server_to_client_queue.put(exc)
 -                 return False
 -         elif sse.event == "ping":
 -             logger.debug("Received ping event")
 -             return False
 -         else:
 -             logger.warning(f"Unknown SSE event: {sse.event}")
 -             return False
 - 
 -     def handle_get_stream(
 -         self,
 -         client: httpx.Client,
 -         server_to_client_queue: ServerToClientQueue,
 -     ) -> None:
 -         """Handle GET stream for server-initiated messages."""
 -         try:
 -             if not self.session_id:
 -                 return
 - 
 -             headers = self._update_headers_with_session(self.request_headers)
 - 
 -             with ssrf_proxy_sse_connect(
 -                 self.url,
 -                 headers=headers,
 -                 timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
 -                 client=client,
 -                 method="GET",
 -             ) as event_source:
 -                 event_source.response.raise_for_status()
 -                 logger.debug("GET SSE connection established")
 - 
 -                 for sse in event_source.iter_sse():
 -                     self._handle_sse_event(sse, server_to_client_queue)
 - 
 -         except Exception as exc:
 -             logger.debug(f"GET stream error (non-fatal): {exc}")
 - 
 -     def _handle_resumption_request(self, ctx: RequestContext) -> None:
 -         """Handle a resumption request using GET with SSE."""
 -         headers = self._update_headers_with_session(ctx.headers)
 -         if ctx.metadata and ctx.metadata.resumption_token:
 -             headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
 -         else:
 -             raise ResumptionError("Resumption request requires a resumption token")
 - 
 -         # Extract original request ID to map responses
 -         original_request_id = None
 -         if isinstance(ctx.session_message.message.root, JSONRPCRequest):
 -             original_request_id = ctx.session_message.message.root.id
 - 
 -         with ssrf_proxy_sse_connect(
 -             self.url,
 -             headers=headers,
 -             timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
 -             client=ctx.client,
 -             method="GET",
 -         ) as event_source:
 -             event_source.response.raise_for_status()
 -             logger.debug("Resumption GET SSE connection established")
 - 
 -             for sse in event_source.iter_sse():
 -                 is_complete = self._handle_sse_event(
 -                     sse,
 -                     ctx.server_to_client_queue,
 -                     original_request_id,
 -                     ctx.metadata.on_resumption_token_update if ctx.metadata else None,
 -                 )
 -                 if is_complete:
 -                     break
 - 
 -     def _handle_post_request(self, ctx: RequestContext) -> None:
 -         """Handle a POST request with response processing."""
 -         headers = self._update_headers_with_session(ctx.headers)
 -         message = ctx.session_message.message
 -         is_initialization = self._is_initialization_request(message)
 - 
 -         with ctx.client.stream(
 -             "POST",
 -             self.url,
 -             json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
 -             headers=headers,
 -         ) as response:
 -             if response.status_code == 202:
 -                 logger.debug("Received 202 Accepted")
 -                 return
 - 
 -             if response.status_code == 404:
 -                 if isinstance(message.root, JSONRPCRequest):
 -                     self._send_session_terminated_error(
 -                         ctx.server_to_client_queue,
 -                         message.root.id,
 -                     )
 -                 return
 - 
 -             response.raise_for_status()
 -             if is_initialization:
 -                 self._maybe_extract_session_id_from_response(response)
 - 
 -             content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
 - 
 -             if content_type.startswith(JSON):
 -                 self._handle_json_response(response, ctx.server_to_client_queue)
 -             elif content_type.startswith(SSE):
 -                 self._handle_sse_response(response, ctx)
 -             else:
 -                 self._handle_unexpected_content_type(
 -                     content_type,
 -                     ctx.server_to_client_queue,
 -                 )
 - 
 -     def _handle_json_response(
 -         self,
 -         response: httpx.Response,
 -         server_to_client_queue: ServerToClientQueue,
 -     ) -> None:
 -         """Handle JSON response from the server."""
 -         try:
 -             content = response.read()
 -             message = JSONRPCMessage.model_validate_json(content)
 -             session_message = SessionMessage(message)
 -             server_to_client_queue.put(session_message)
 -         except Exception as exc:
 -             server_to_client_queue.put(exc)
 - 
 -     def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
 -         """Handle SSE response from the server."""
 -         try:
 -             event_source = EventSource(response)
 -             for sse in event_source.iter_sse():
 -                 is_complete = self._handle_sse_event(
 -                     sse,
 -                     ctx.server_to_client_queue,
 -                     resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
 -                 )
 -                 if is_complete:
 -                     break
 -         except Exception as e:
 -             ctx.server_to_client_queue.put(e)
 - 
 -     def _handle_unexpected_content_type(
 -         self,
 -         content_type: str,
 -         server_to_client_queue: ServerToClientQueue,
 -     ) -> None:
 -         """Handle unexpected content type in response."""
 -         error_msg = f"Unexpected content type: {content_type}"
 -         logger.error(error_msg)
 -         server_to_client_queue.put(ValueError(error_msg))
 - 
 -     def _send_session_terminated_error(
 -         self,
 -         server_to_client_queue: ServerToClientQueue,
 -         request_id: RequestId,
 -     ) -> None:
 -         """Send a session terminated error response."""
 -         jsonrpc_error = JSONRPCError(
 -             jsonrpc="2.0",
 -             id=request_id,
 -             error=ErrorData(code=32600, message="Session terminated by server"),
 -         )
 -         session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
 -         server_to_client_queue.put(session_message)
 - 
 -     def post_writer(
 -         self,
 -         client: httpx.Client,
 -         client_to_server_queue: ClientToServerQueue,
 -         server_to_client_queue: ServerToClientQueue,
 -         start_get_stream: Callable[[], None],
 -     ) -> None:
 -         """Handle writing requests to the server.
 - 
 -         This method processes messages from the client_to_server_queue and sends them to the server.
 -         Responses are written to the server_to_client_queue.
 -         """
 -         while True:
 -             try:
 -                 # Read message from client queue with timeout to check stop_event periodically
 -                 session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
 -                 if session_message is None:
 -                     break
 - 
 -                 message = session_message.message
 -                 metadata = (
 -                     session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
 -                 )
 - 
 -                 # Check if this is a resumption request
 -                 is_resumption = bool(metadata and metadata.resumption_token)
 - 
 -                 logger.debug(f"Sending client message: {message}")
 - 
 -                 # Handle initialized notification
 -                 if self._is_initialized_notification(message):
 -                     start_get_stream()
 - 
 -                 ctx = RequestContext(
 -                     client=client,
 -                     headers=self.request_headers,
 -                     session_id=self.session_id,
 -                     session_message=session_message,
 -                     metadata=metadata,
 -                     server_to_client_queue=server_to_client_queue,  # Queue to write responses to client
 -                     sse_read_timeout=self.sse_read_timeout,
 -                 )
 - 
 -                 if is_resumption:
 -                     self._handle_resumption_request(ctx)
 -                 else:
 -                     self._handle_post_request(ctx)
 -             except queue.Empty:
 -                 continue
 -             except Exception as exc:
 -                 server_to_client_queue.put(exc)
 - 
 -     def terminate_session(self, client: httpx.Client) -> None:
 -         """Terminate the session by sending a DELETE request."""
 -         if not self.session_id:
 -             return
 - 
 -         try:
 -             headers = self._update_headers_with_session(self.request_headers)
 -             response = client.delete(self.url, headers=headers)
 - 
 -             if response.status_code == 405:
 -                 logger.debug("Server does not allow session termination")
 -             elif response.status_code != 200:
 -                 logger.warning(f"Session termination failed: {response.status_code}")
 -         except Exception as exc:
 -             logger.warning(f"Session termination failed: {exc}")
 - 
 -     def get_session_id(self) -> str | None:
 -         """Get the current session ID."""
 -         return self.session_id
 - 
 - 
 - @contextmanager
 - def streamablehttp_client(
 -     url: str,
 -     headers: dict[str, Any] | None = None,
 -     timeout: timedelta = timedelta(seconds=30),
 -     sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
 -     terminate_on_close: bool = True,
 - ) -> Generator[
 -     tuple[
 -         ServerToClientQueue,  # Queue for receiving messages FROM server
 -         ClientToServerQueue,  # Queue for sending messages TO server
 -         GetSessionIdCallback,
 -     ],
 -     None,
 -     None,
 - ]:
 -     """
 -     Client transport for StreamableHTTP.
 - 
 -     `sse_read_timeout` determines how long (in seconds) the client will wait for a new
 -     event before disconnecting. All other HTTP operations are controlled by `timeout`.
 - 
 -     Yields:
 -         Tuple containing:
 -             - server_to_client_queue: Queue for reading messages FROM the server
 -             - client_to_server_queue: Queue for sending messages TO the server
 -             - get_session_id_callback: Function to retrieve the current session ID
 -     """
 -     transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
 - 
 -     # Create queues with clear directional meaning
 -     server_to_client_queue: ServerToClientQueue = queue.Queue()  # For messages FROM server TO client
 -     client_to_server_queue: ClientToServerQueue = queue.Queue()  # For messages FROM client TO server
 - 
 -     with ThreadPoolExecutor(max_workers=2) as executor:
 -         try:
 -             with create_ssrf_proxy_mcp_http_client(
 -                 headers=transport.request_headers,
 -                 timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
 -             ) as client:
 -                 # Define callbacks that need access to thread pool
 -                 def start_get_stream() -> None:
 -                     """Start a worker thread to handle server-initiated messages."""
 -                     executor.submit(transport.handle_get_stream, client, server_to_client_queue)
 - 
 -                 # Start the post_writer worker thread
 -                 executor.submit(
 -                     transport.post_writer,
 -                     client,
 -                     client_to_server_queue,  # Queue for messages FROM client TO server
 -                     server_to_client_queue,  # Queue for messages FROM server TO client
 -                     start_get_stream,
 -                 )
 - 
 -                 try:
 -                     yield (
 -                         server_to_client_queue,  # Queue for receiving messages FROM server
 -                         client_to_server_queue,  # Queue for sending messages TO server
 -                         transport.get_session_id,
 -                     )
 -                 finally:
 -                     if transport.session_id and terminate_on_close:
 -                         transport.terminate_session(client)
 - 
 -                     # Signal threads to stop
 -                     client_to_server_queue.put(None)
 -         finally:
 -             # Clear any remaining items and add None sentinel to unblock any waiting threads
 -             try:
 -                 while not client_to_server_queue.empty():
 -                     client_to_server_queue.get_nowait()
 -             except queue.Empty:
 -                 pass
 - 
 -             client_to_server_queue.put(None)
 -             server_to_client_queue.put(None)
 
 
  |