| 
                        123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342 | 
                        - import base64
 - import hashlib
 - import json
 - import os
 - import secrets
 - import urllib.parse
 - from typing import Optional
 - from urllib.parse import urljoin
 - 
 - import requests
 - from pydantic import BaseModel, ValidationError
 - 
 - from core.mcp.auth.auth_provider import OAuthClientProvider
 - from core.mcp.types import (
 -     OAuthClientInformation,
 -     OAuthClientInformationFull,
 -     OAuthClientMetadata,
 -     OAuthMetadata,
 -     OAuthTokens,
 - )
 - from extensions.ext_redis import redis_client
 - 
 - LATEST_PROTOCOL_VERSION = "1.0"
 - OAUTH_STATE_EXPIRY_SECONDS = 5 * 60  # 5 minutes expiry
 - OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
 - 
 - 
 - class OAuthCallbackState(BaseModel):
 -     provider_id: str
 -     tenant_id: str
 -     server_url: str
 -     metadata: OAuthMetadata | None = None
 -     client_information: OAuthClientInformation
 -     code_verifier: str
 -     redirect_uri: str
 - 
 - 
 - def generate_pkce_challenge() -> tuple[str, str]:
 -     """Generate PKCE challenge and verifier."""
 -     code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
 -     code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
 - 
 -     code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
 -     code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
 -     code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
 - 
 -     return code_verifier, code_challenge
 - 
 - 
 - def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
 -     """Create a secure state parameter by storing state data in Redis and returning a random state key."""
 -     # Generate a secure random state key
 -     state_key = secrets.token_urlsafe(32)
 - 
 -     # Store the state data in Redis with expiration
 -     redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
 -     redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
 - 
 -     return state_key
 - 
 - 
 - def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
 -     """Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
 -     redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
 - 
 -     # Get state data from Redis
 -     state_data = redis_client.get(redis_key)
 - 
 -     if not state_data:
 -         raise ValueError("State parameter has expired or does not exist")
 - 
 -     # Delete the state data from Redis immediately after retrieval to prevent reuse
 -     redis_client.delete(redis_key)
 - 
 -     try:
 -         # Parse and validate the state data
 -         oauth_state = OAuthCallbackState.model_validate_json(state_data)
 - 
 -         return oauth_state
 -     except ValidationError as e:
 -         raise ValueError(f"Invalid state parameter: {str(e)}")
 - 
 - 
 - def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
 -     """Handle the callback from the OAuth provider."""
 -     # Retrieve state data from Redis (state is automatically deleted after retrieval)
 -     full_state_data = _retrieve_redis_state(state_key)
 - 
 -     tokens = exchange_authorization(
 -         full_state_data.server_url,
 -         full_state_data.metadata,
 -         full_state_data.client_information,
 -         authorization_code,
 -         full_state_data.code_verifier,
 -         full_state_data.redirect_uri,
 -     )
 -     provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
 -     provider.save_tokens(tokens)
 -     return full_state_data
 - 
 - 
 - def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
 -     """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
 -     url = urljoin(server_url, "/.well-known/oauth-authorization-server")
 - 
 -     try:
 -         headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
 -         response = requests.get(url, headers=headers)
 -         if response.status_code == 404:
 -             return None
 -         if not response.ok:
 -             raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
 -         return OAuthMetadata.model_validate(response.json())
 -     except requests.RequestException as e:
 -         if isinstance(e, requests.ConnectionError):
 -             response = requests.get(url)
 -             if response.status_code == 404:
 -                 return None
 -             if not response.ok:
 -                 raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
 -             return OAuthMetadata.model_validate(response.json())
 -         raise
 - 
 - 
 - def start_authorization(
 -     server_url: str,
 -     metadata: Optional[OAuthMetadata],
 -     client_information: OAuthClientInformation,
 -     redirect_url: str,
 -     provider_id: str,
 -     tenant_id: str,
 - ) -> tuple[str, str]:
 -     """Begins the authorization flow with secure Redis state storage."""
 -     response_type = "code"
 -     code_challenge_method = "S256"
 - 
 -     if metadata:
 -         authorization_url = metadata.authorization_endpoint
 -         if response_type not in metadata.response_types_supported:
 -             raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
 -         if (
 -             not metadata.code_challenge_methods_supported
 -             or code_challenge_method not in metadata.code_challenge_methods_supported
 -         ):
 -             raise ValueError(
 -                 f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
 -             )
 -     else:
 -         authorization_url = urljoin(server_url, "/authorize")
 - 
 -     code_verifier, code_challenge = generate_pkce_challenge()
 - 
 -     # Prepare state data with all necessary information
 -     state_data = OAuthCallbackState(
 -         provider_id=provider_id,
 -         tenant_id=tenant_id,
 -         server_url=server_url,
 -         metadata=metadata,
 -         client_information=client_information,
 -         code_verifier=code_verifier,
 -         redirect_uri=redirect_url,
 -     )
 - 
 -     # Store state data in Redis and generate secure state key
 -     state_key = _create_secure_redis_state(state_data)
 - 
 -     params = {
 -         "response_type": response_type,
 -         "client_id": client_information.client_id,
 -         "code_challenge": code_challenge,
 -         "code_challenge_method": code_challenge_method,
 -         "redirect_uri": redirect_url,
 -         "state": state_key,
 -     }
 - 
 -     authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
 -     return authorization_url, code_verifier
 - 
 - 
 - def exchange_authorization(
 -     server_url: str,
 -     metadata: Optional[OAuthMetadata],
 -     client_information: OAuthClientInformation,
 -     authorization_code: str,
 -     code_verifier: str,
 -     redirect_uri: str,
 - ) -> OAuthTokens:
 -     """Exchanges an authorization code for an access token."""
 -     grant_type = "authorization_code"
 - 
 -     if metadata:
 -         token_url = metadata.token_endpoint
 -         if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
 -             raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
 -     else:
 -         token_url = urljoin(server_url, "/token")
 - 
 -     params = {
 -         "grant_type": grant_type,
 -         "client_id": client_information.client_id,
 -         "code": authorization_code,
 -         "code_verifier": code_verifier,
 -         "redirect_uri": redirect_uri,
 -     }
 - 
 -     if client_information.client_secret:
 -         params["client_secret"] = client_information.client_secret
 - 
 -     response = requests.post(token_url, data=params)
 -     if not response.ok:
 -         raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
 -     return OAuthTokens.model_validate(response.json())
 - 
 - 
 - def refresh_authorization(
 -     server_url: str,
 -     metadata: Optional[OAuthMetadata],
 -     client_information: OAuthClientInformation,
 -     refresh_token: str,
 - ) -> OAuthTokens:
 -     """Exchange a refresh token for an updated access token."""
 -     grant_type = "refresh_token"
 - 
 -     if metadata:
 -         token_url = metadata.token_endpoint
 -         if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
 -             raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
 -     else:
 -         token_url = urljoin(server_url, "/token")
 - 
 -     params = {
 -         "grant_type": grant_type,
 -         "client_id": client_information.client_id,
 -         "refresh_token": refresh_token,
 -     }
 - 
 -     if client_information.client_secret:
 -         params["client_secret"] = client_information.client_secret
 - 
 -     response = requests.post(token_url, data=params)
 -     if not response.ok:
 -         raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
 -     return OAuthTokens.parse_obj(response.json())
 - 
 - 
 - def register_client(
 -     server_url: str,
 -     metadata: Optional[OAuthMetadata],
 -     client_metadata: OAuthClientMetadata,
 - ) -> OAuthClientInformationFull:
 -     """Performs OAuth 2.0 Dynamic Client Registration."""
 -     if metadata:
 -         if not metadata.registration_endpoint:
 -             raise ValueError("Incompatible auth server: does not support dynamic client registration")
 -         registration_url = metadata.registration_endpoint
 -     else:
 -         registration_url = urljoin(server_url, "/register")
 - 
 -     response = requests.post(
 -         registration_url,
 -         json=client_metadata.model_dump(),
 -         headers={"Content-Type": "application/json"},
 -     )
 -     if not response.ok:
 -         response.raise_for_status()
 -     return OAuthClientInformationFull.model_validate(response.json())
 - 
 - 
 - def auth(
 -     provider: OAuthClientProvider,
 -     server_url: str,
 -     authorization_code: Optional[str] = None,
 -     state_param: Optional[str] = None,
 -     for_list: bool = False,
 - ) -> dict[str, str]:
 -     """Orchestrates the full auth flow with a server using secure Redis state storage."""
 -     metadata = discover_oauth_metadata(server_url)
 - 
 -     # Handle client registration if needed
 -     client_information = provider.client_information()
 -     if not client_information:
 -         if authorization_code is not None:
 -             raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
 -         try:
 -             full_information = register_client(server_url, metadata, provider.client_metadata)
 -         except requests.RequestException as e:
 -             raise ValueError(f"Could not register OAuth client: {e}")
 -         provider.save_client_information(full_information)
 -         client_information = full_information
 - 
 -     # Exchange authorization code for tokens
 -     if authorization_code is not None:
 -         if not state_param:
 -             raise ValueError("State parameter is required when exchanging authorization code")
 - 
 -         try:
 -             # Retrieve state data from Redis using state key
 -             full_state_data = _retrieve_redis_state(state_param)
 - 
 -             code_verifier = full_state_data.code_verifier
 -             redirect_uri = full_state_data.redirect_uri
 - 
 -             if not code_verifier or not redirect_uri:
 -                 raise ValueError("Missing code_verifier or redirect_uri in state data")
 - 
 -         except (json.JSONDecodeError, ValueError) as e:
 -             raise ValueError(f"Invalid state parameter: {e}")
 - 
 -         tokens = exchange_authorization(
 -             server_url,
 -             metadata,
 -             client_information,
 -             authorization_code,
 -             code_verifier,
 -             redirect_uri,
 -         )
 -         provider.save_tokens(tokens)
 -         return {"result": "success"}
 - 
 -     provider_tokens = provider.tokens()
 - 
 -     # Handle token refresh or new authorization
 -     if provider_tokens and provider_tokens.refresh_token:
 -         try:
 -             new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
 -             provider.save_tokens(new_tokens)
 -             return {"result": "success"}
 -         except Exception as e:
 -             raise ValueError(f"Could not refresh OAuth tokens: {e}")
 - 
 -     # Start new authorization flow
 -     authorization_url, code_verifier = start_authorization(
 -         server_url,
 -         metadata,
 -         client_information,
 -         provider.redirect_url,
 -         provider.mcp_provider.id,
 -         provider.mcp_provider.tenant_id,
 -     )
 - 
 -     provider.save_code_verifier(code_verifier)
 -     return {"authorization_url": authorization_url}
 
 
  |