| from typing import Optional | from typing import Optional | ||||
| from urllib.parse import urljoin | from urllib.parse import urljoin | ||||
| import requests | |||||
| import httpx | |||||
| from pydantic import BaseModel, ValidationError | from pydantic import BaseModel, ValidationError | ||||
| from core.mcp.auth.auth_provider import OAuthClientProvider | from core.mcp.auth.auth_provider import OAuthClientProvider | ||||
| try: | try: | ||||
| headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} | headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} | ||||
| response = requests.get(url, headers=headers) | |||||
| response = httpx.get(url, headers=headers) | |||||
| if response.status_code == 404: | if response.status_code == 404: | ||||
| return None | return None | ||||
| if not response.ok: | |||||
| if not response.is_success: | |||||
| raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") | raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") | ||||
| return OAuthMetadata.model_validate(response.json()) | return OAuthMetadata.model_validate(response.json()) | ||||
| except requests.RequestException as e: | |||||
| if isinstance(e, requests.ConnectionError): | |||||
| response = requests.get(url) | |||||
| except httpx.RequestError as e: | |||||
| if isinstance(e, httpx.ConnectError): | |||||
| response = httpx.get(url) | |||||
| if response.status_code == 404: | if response.status_code == 404: | ||||
| return None | return None | ||||
| if not response.ok: | |||||
| if not response.is_success: | |||||
| raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") | raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") | ||||
| return OAuthMetadata.model_validate(response.json()) | return OAuthMetadata.model_validate(response.json()) | ||||
| raise | raise | ||||
| if client_information.client_secret: | if client_information.client_secret: | ||||
| params["client_secret"] = client_information.client_secret | params["client_secret"] = client_information.client_secret | ||||
| response = requests.post(token_url, data=params) | |||||
| if not response.ok: | |||||
| response = httpx.post(token_url, data=params) | |||||
| if not response.is_success: | |||||
| raise ValueError(f"Token exchange failed: HTTP {response.status_code}") | raise ValueError(f"Token exchange failed: HTTP {response.status_code}") | ||||
| return OAuthTokens.model_validate(response.json()) | return OAuthTokens.model_validate(response.json()) | ||||
| if client_information.client_secret: | if client_information.client_secret: | ||||
| params["client_secret"] = client_information.client_secret | params["client_secret"] = client_information.client_secret | ||||
| response = requests.post(token_url, data=params) | |||||
| if not response.ok: | |||||
| response = httpx.post(token_url, data=params) | |||||
| if not response.is_success: | |||||
| raise ValueError(f"Token refresh failed: HTTP {response.status_code}") | raise ValueError(f"Token refresh failed: HTTP {response.status_code}") | ||||
| return OAuthTokens.model_validate(response.json()) | return OAuthTokens.model_validate(response.json()) | ||||
| else: | else: | ||||
| registration_url = urljoin(server_url, "/register") | registration_url = urljoin(server_url, "/register") | ||||
| response = requests.post( | |||||
| response = httpx.post( | |||||
| registration_url, | registration_url, | ||||
| json=client_metadata.model_dump(), | json=client_metadata.model_dump(), | ||||
| headers={"Content-Type": "application/json"}, | headers={"Content-Type": "application/json"}, | ||||
| ) | ) | ||||
| if not response.ok: | |||||
| if not response.is_success: | |||||
| response.raise_for_status() | response.raise_for_status() | ||||
| return OAuthClientInformationFull.model_validate(response.json()) | return OAuthClientInformationFull.model_validate(response.json()) | ||||
| raise ValueError("Existing OAuth client information is required when exchanging an authorization code") | raise ValueError("Existing OAuth client information is required when exchanging an authorization code") | ||||
| try: | try: | ||||
| full_information = register_client(server_url, metadata, provider.client_metadata) | full_information = register_client(server_url, metadata, provider.client_metadata) | ||||
| except requests.RequestException as e: | |||||
| except httpx.RequestError as e: | |||||
| raise ValueError(f"Could not register OAuth client: {e}") | raise ValueError(f"Could not register OAuth client: {e}") | ||||
| provider.save_client_information(full_information) | provider.save_client_information(full_information) | ||||
| client_information = full_information | client_information = full_information |