You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

auth_flow.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import base64
  2. import hashlib
  3. import json
  4. import os
  5. import secrets
  6. import urllib.parse
  7. from typing import Optional
  8. from urllib.parse import urljoin
  9. import requests
  10. from pydantic import BaseModel, ValidationError
  11. from core.mcp.auth.auth_provider import OAuthClientProvider
  12. from core.mcp.types import (
  13. OAuthClientInformation,
  14. OAuthClientInformationFull,
  15. OAuthClientMetadata,
  16. OAuthMetadata,
  17. OAuthTokens,
  18. )
  19. from extensions.ext_redis import redis_client
  20. LATEST_PROTOCOL_VERSION = "1.0"
  21. OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
  22. OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
  23. class OAuthCallbackState(BaseModel):
  24. provider_id: str
  25. tenant_id: str
  26. server_url: str
  27. metadata: OAuthMetadata | None = None
  28. client_information: OAuthClientInformation
  29. code_verifier: str
  30. redirect_uri: str
  31. def generate_pkce_challenge() -> tuple[str, str]:
  32. """Generate PKCE challenge and verifier."""
  33. code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
  34. code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
  35. code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
  36. code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
  37. code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
  38. return code_verifier, code_challenge
  39. def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
  40. """Create a secure state parameter by storing state data in Redis and returning a random state key."""
  41. # Generate a secure random state key
  42. state_key = secrets.token_urlsafe(32)
  43. # Store the state data in Redis with expiration
  44. redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
  45. redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
  46. return state_key
  47. def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
  48. """Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
  49. redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
  50. # Get state data from Redis
  51. state_data = redis_client.get(redis_key)
  52. if not state_data:
  53. raise ValueError("State parameter has expired or does not exist")
  54. # Delete the state data from Redis immediately after retrieval to prevent reuse
  55. redis_client.delete(redis_key)
  56. try:
  57. # Parse and validate the state data
  58. oauth_state = OAuthCallbackState.model_validate_json(state_data)
  59. return oauth_state
  60. except ValidationError as e:
  61. raise ValueError(f"Invalid state parameter: {str(e)}")
  62. def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
  63. """Handle the callback from the OAuth provider."""
  64. # Retrieve state data from Redis (state is automatically deleted after retrieval)
  65. full_state_data = _retrieve_redis_state(state_key)
  66. tokens = exchange_authorization(
  67. full_state_data.server_url,
  68. full_state_data.metadata,
  69. full_state_data.client_information,
  70. authorization_code,
  71. full_state_data.code_verifier,
  72. full_state_data.redirect_uri,
  73. )
  74. provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
  75. provider.save_tokens(tokens)
  76. return full_state_data
  77. def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
  78. """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
  79. url = urljoin(server_url, "/.well-known/oauth-authorization-server")
  80. try:
  81. headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
  82. response = requests.get(url, headers=headers)
  83. if response.status_code == 404:
  84. return None
  85. if not response.ok:
  86. raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
  87. return OAuthMetadata.model_validate(response.json())
  88. except requests.RequestException as e:
  89. if isinstance(e, requests.ConnectionError):
  90. response = requests.get(url)
  91. if response.status_code == 404:
  92. return None
  93. if not response.ok:
  94. raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
  95. return OAuthMetadata.model_validate(response.json())
  96. raise
  97. def start_authorization(
  98. server_url: str,
  99. metadata: Optional[OAuthMetadata],
  100. client_information: OAuthClientInformation,
  101. redirect_url: str,
  102. provider_id: str,
  103. tenant_id: str,
  104. ) -> tuple[str, str]:
  105. """Begins the authorization flow with secure Redis state storage."""
  106. response_type = "code"
  107. code_challenge_method = "S256"
  108. if metadata:
  109. authorization_url = metadata.authorization_endpoint
  110. if response_type not in metadata.response_types_supported:
  111. raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
  112. if (
  113. not metadata.code_challenge_methods_supported
  114. or code_challenge_method not in metadata.code_challenge_methods_supported
  115. ):
  116. raise ValueError(
  117. f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
  118. )
  119. else:
  120. authorization_url = urljoin(server_url, "/authorize")
  121. code_verifier, code_challenge = generate_pkce_challenge()
  122. # Prepare state data with all necessary information
  123. state_data = OAuthCallbackState(
  124. provider_id=provider_id,
  125. tenant_id=tenant_id,
  126. server_url=server_url,
  127. metadata=metadata,
  128. client_information=client_information,
  129. code_verifier=code_verifier,
  130. redirect_uri=redirect_url,
  131. )
  132. # Store state data in Redis and generate secure state key
  133. state_key = _create_secure_redis_state(state_data)
  134. params = {
  135. "response_type": response_type,
  136. "client_id": client_information.client_id,
  137. "code_challenge": code_challenge,
  138. "code_challenge_method": code_challenge_method,
  139. "redirect_uri": redirect_url,
  140. "state": state_key,
  141. }
  142. authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
  143. return authorization_url, code_verifier
  144. def exchange_authorization(
  145. server_url: str,
  146. metadata: Optional[OAuthMetadata],
  147. client_information: OAuthClientInformation,
  148. authorization_code: str,
  149. code_verifier: str,
  150. redirect_uri: str,
  151. ) -> OAuthTokens:
  152. """Exchanges an authorization code for an access token."""
  153. grant_type = "authorization_code"
  154. if metadata:
  155. token_url = metadata.token_endpoint
  156. if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
  157. raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
  158. else:
  159. token_url = urljoin(server_url, "/token")
  160. params = {
  161. "grant_type": grant_type,
  162. "client_id": client_information.client_id,
  163. "code": authorization_code,
  164. "code_verifier": code_verifier,
  165. "redirect_uri": redirect_uri,
  166. }
  167. if client_information.client_secret:
  168. params["client_secret"] = client_information.client_secret
  169. response = requests.post(token_url, data=params)
  170. if not response.ok:
  171. raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
  172. return OAuthTokens.model_validate(response.json())
  173. def refresh_authorization(
  174. server_url: str,
  175. metadata: Optional[OAuthMetadata],
  176. client_information: OAuthClientInformation,
  177. refresh_token: str,
  178. ) -> OAuthTokens:
  179. """Exchange a refresh token for an updated access token."""
  180. grant_type = "refresh_token"
  181. if metadata:
  182. token_url = metadata.token_endpoint
  183. if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
  184. raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
  185. else:
  186. token_url = urljoin(server_url, "/token")
  187. params = {
  188. "grant_type": grant_type,
  189. "client_id": client_information.client_id,
  190. "refresh_token": refresh_token,
  191. }
  192. if client_information.client_secret:
  193. params["client_secret"] = client_information.client_secret
  194. response = requests.post(token_url, data=params)
  195. if not response.ok:
  196. raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
  197. return OAuthTokens.parse_obj(response.json())
  198. def register_client(
  199. server_url: str,
  200. metadata: Optional[OAuthMetadata],
  201. client_metadata: OAuthClientMetadata,
  202. ) -> OAuthClientInformationFull:
  203. """Performs OAuth 2.0 Dynamic Client Registration."""
  204. if metadata:
  205. if not metadata.registration_endpoint:
  206. raise ValueError("Incompatible auth server: does not support dynamic client registration")
  207. registration_url = metadata.registration_endpoint
  208. else:
  209. registration_url = urljoin(server_url, "/register")
  210. response = requests.post(
  211. registration_url,
  212. json=client_metadata.model_dump(),
  213. headers={"Content-Type": "application/json"},
  214. )
  215. if not response.ok:
  216. response.raise_for_status()
  217. return OAuthClientInformationFull.model_validate(response.json())
  218. def auth(
  219. provider: OAuthClientProvider,
  220. server_url: str,
  221. authorization_code: Optional[str] = None,
  222. state_param: Optional[str] = None,
  223. for_list: bool = False,
  224. ) -> dict[str, str]:
  225. """Orchestrates the full auth flow with a server using secure Redis state storage."""
  226. metadata = discover_oauth_metadata(server_url)
  227. # Handle client registration if needed
  228. client_information = provider.client_information()
  229. if not client_information:
  230. if authorization_code is not None:
  231. raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
  232. try:
  233. full_information = register_client(server_url, metadata, provider.client_metadata)
  234. except requests.RequestException as e:
  235. raise ValueError(f"Could not register OAuth client: {e}")
  236. provider.save_client_information(full_information)
  237. client_information = full_information
  238. # Exchange authorization code for tokens
  239. if authorization_code is not None:
  240. if not state_param:
  241. raise ValueError("State parameter is required when exchanging authorization code")
  242. try:
  243. # Retrieve state data from Redis using state key
  244. full_state_data = _retrieve_redis_state(state_param)
  245. code_verifier = full_state_data.code_verifier
  246. redirect_uri = full_state_data.redirect_uri
  247. if not code_verifier or not redirect_uri:
  248. raise ValueError("Missing code_verifier or redirect_uri in state data")
  249. except (json.JSONDecodeError, ValueError) as e:
  250. raise ValueError(f"Invalid state parameter: {e}")
  251. tokens = exchange_authorization(
  252. server_url,
  253. metadata,
  254. client_information,
  255. authorization_code,
  256. code_verifier,
  257. redirect_uri,
  258. )
  259. provider.save_tokens(tokens)
  260. return {"result": "success"}
  261. provider_tokens = provider.tokens()
  262. # Handle token refresh or new authorization
  263. if provider_tokens and provider_tokens.refresh_token:
  264. try:
  265. new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
  266. provider.save_tokens(new_tokens)
  267. return {"result": "success"}
  268. except Exception as e:
  269. raise ValueError(f"Could not refresh OAuth tokens: {e}")
  270. # Start new authorization flow
  271. authorization_url, code_verifier = start_authorization(
  272. server_url,
  273. metadata,
  274. client_information,
  275. provider.redirect_url,
  276. provider.mcp_provider.id,
  277. provider.mcp_provider.tenant_id,
  278. )
  279. provider.save_code_verifier(code_verifier)
  280. return {"authorization_url": authorization_url}