You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

auth_flow.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. import base64
  2. import hashlib
  3. import json
  4. import os
  5. import secrets
  6. import urllib.parse
  7. from typing import Optional
  8. from urllib.parse import urljoin, urlparse
  9. import httpx
  10. from pydantic import BaseModel, ValidationError
  11. from core.mcp.auth.auth_provider import OAuthClientProvider
  12. from core.mcp.types import (
  13. OAuthClientInformation,
  14. OAuthClientInformationFull,
  15. OAuthClientMetadata,
  16. OAuthMetadata,
  17. OAuthTokens,
  18. )
  19. from extensions.ext_redis import redis_client
  20. LATEST_PROTOCOL_VERSION = "1.0"
  21. OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
  22. OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
    """Data kept in Redis between starting an authorization and handling its callback.

    An instance is serialized to JSON when the authorization URL is built and is
    validated back from JSON when the OAuth provider redirects with the state key.
    """

    # Provider id handed to OAuthClientProvider when the obtained tokens are saved.
    provider_id: str
    # Tenant id handed to OAuthClientProvider when the obtained tokens are saved.
    tenant_id: str
    # MCP server URL the authorization was started for; reused for the token exchange.
    server_url: str
    # Authorization server metadata available at start time, or None if none was found.
    metadata: OAuthMetadata | None = None
    # OAuth client credentials used for both the authorization request and token exchange.
    client_information: OAuthClientInformation
    # PKCE verifier whose S256 challenge was sent in the authorization request.
    code_verifier: str
    # Redirect URI sent in the authorization request; must be repeated in the token exchange.
    redirect_uri: str
  31. def generate_pkce_challenge() -> tuple[str, str]:
  32. """Generate PKCE challenge and verifier."""
  33. code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
  34. code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
  35. code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
  36. code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
  37. code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
  38. return code_verifier, code_challenge
  39. def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
  40. """Create a secure state parameter by storing state data in Redis and returning a random state key."""
  41. # Generate a secure random state key
  42. state_key = secrets.token_urlsafe(32)
  43. # Store the state data in Redis with expiration
  44. redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
  45. redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
  46. return state_key
  47. def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
  48. """Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
  49. redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
  50. # Get state data from Redis
  51. state_data = redis_client.get(redis_key)
  52. if not state_data:
  53. raise ValueError("State parameter has expired or does not exist")
  54. # Delete the state data from Redis immediately after retrieval to prevent reuse
  55. redis_client.delete(redis_key)
  56. try:
  57. # Parse and validate the state data
  58. oauth_state = OAuthCallbackState.model_validate_json(state_data)
  59. return oauth_state
  60. except ValidationError as e:
  61. raise ValueError(f"Invalid state parameter: {str(e)}")
  62. def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
  63. """Handle the callback from the OAuth provider."""
  64. # Retrieve state data from Redis (state is automatically deleted after retrieval)
  65. full_state_data = _retrieve_redis_state(state_key)
  66. tokens = exchange_authorization(
  67. full_state_data.server_url,
  68. full_state_data.metadata,
  69. full_state_data.client_information,
  70. authorization_code,
  71. full_state_data.code_verifier,
  72. full_state_data.redirect_uri,
  73. )
  74. provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
  75. provider.save_tokens(tokens)
  76. return full_state_data
  77. def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
  78. """Check if the server supports OAuth 2.0 Resource Discovery."""
  79. b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True)
  80. url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
  81. if b_query:
  82. url_for_resource_discovery += f"?{b_query}"
  83. if b_fragment:
  84. url_for_resource_discovery += f"#{b_fragment}"
  85. try:
  86. headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
  87. response = httpx.get(url_for_resource_discovery, headers=headers)
  88. if 200 <= response.status_code < 300:
  89. body = response.json()
  90. if "authorization_server_url" in body:
  91. return True, body["authorization_server_url"][0]
  92. else:
  93. return False, ""
  94. return False, ""
  95. except httpx.RequestError as e:
  96. # Not support resource discovery, fall back to well-known OAuth metadata
  97. return False, ""
  98. def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
  99. """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
  100. # First check if the server supports OAuth 2.0 Resource Discovery
  101. support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
  102. if support_resource_discovery:
  103. url = oauth_discovery_url
  104. else:
  105. url = urljoin(server_url, "/.well-known/oauth-authorization-server")
  106. try:
  107. headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
  108. response = httpx.get(url, headers=headers)
  109. if response.status_code == 404:
  110. return None
  111. if not response.is_success:
  112. raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
  113. return OAuthMetadata.model_validate(response.json())
  114. except httpx.RequestError as e:
  115. if isinstance(e, httpx.ConnectError):
  116. response = httpx.get(url)
  117. if response.status_code == 404:
  118. return None
  119. if not response.is_success:
  120. raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
  121. return OAuthMetadata.model_validate(response.json())
  122. raise
  123. def start_authorization(
  124. server_url: str,
  125. metadata: Optional[OAuthMetadata],
  126. client_information: OAuthClientInformation,
  127. redirect_url: str,
  128. provider_id: str,
  129. tenant_id: str,
  130. ) -> tuple[str, str]:
  131. """Begins the authorization flow with secure Redis state storage."""
  132. response_type = "code"
  133. code_challenge_method = "S256"
  134. if metadata:
  135. authorization_url = metadata.authorization_endpoint
  136. if response_type not in metadata.response_types_supported:
  137. raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
  138. if (
  139. not metadata.code_challenge_methods_supported
  140. or code_challenge_method not in metadata.code_challenge_methods_supported
  141. ):
  142. raise ValueError(
  143. f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
  144. )
  145. else:
  146. authorization_url = urljoin(server_url, "/authorize")
  147. code_verifier, code_challenge = generate_pkce_challenge()
  148. # Prepare state data with all necessary information
  149. state_data = OAuthCallbackState(
  150. provider_id=provider_id,
  151. tenant_id=tenant_id,
  152. server_url=server_url,
  153. metadata=metadata,
  154. client_information=client_information,
  155. code_verifier=code_verifier,
  156. redirect_uri=redirect_url,
  157. )
  158. # Store state data in Redis and generate secure state key
  159. state_key = _create_secure_redis_state(state_data)
  160. params = {
  161. "response_type": response_type,
  162. "client_id": client_information.client_id,
  163. "code_challenge": code_challenge,
  164. "code_challenge_method": code_challenge_method,
  165. "redirect_uri": redirect_url,
  166. "state": state_key,
  167. }
  168. authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
  169. return authorization_url, code_verifier
  170. def exchange_authorization(
  171. server_url: str,
  172. metadata: Optional[OAuthMetadata],
  173. client_information: OAuthClientInformation,
  174. authorization_code: str,
  175. code_verifier: str,
  176. redirect_uri: str,
  177. ) -> OAuthTokens:
  178. """Exchanges an authorization code for an access token."""
  179. grant_type = "authorization_code"
  180. if metadata:
  181. token_url = metadata.token_endpoint
  182. if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
  183. raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
  184. else:
  185. token_url = urljoin(server_url, "/token")
  186. params = {
  187. "grant_type": grant_type,
  188. "client_id": client_information.client_id,
  189. "code": authorization_code,
  190. "code_verifier": code_verifier,
  191. "redirect_uri": redirect_uri,
  192. }
  193. if client_information.client_secret:
  194. params["client_secret"] = client_information.client_secret
  195. response = httpx.post(token_url, data=params)
  196. if not response.is_success:
  197. raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
  198. return OAuthTokens.model_validate(response.json())
  199. def refresh_authorization(
  200. server_url: str,
  201. metadata: Optional[OAuthMetadata],
  202. client_information: OAuthClientInformation,
  203. refresh_token: str,
  204. ) -> OAuthTokens:
  205. """Exchange a refresh token for an updated access token."""
  206. grant_type = "refresh_token"
  207. if metadata:
  208. token_url = metadata.token_endpoint
  209. if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
  210. raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
  211. else:
  212. token_url = urljoin(server_url, "/token")
  213. params = {
  214. "grant_type": grant_type,
  215. "client_id": client_information.client_id,
  216. "refresh_token": refresh_token,
  217. }
  218. if client_information.client_secret:
  219. params["client_secret"] = client_information.client_secret
  220. response = httpx.post(token_url, data=params)
  221. if not response.is_success:
  222. raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
  223. return OAuthTokens.model_validate(response.json())
  224. def register_client(
  225. server_url: str,
  226. metadata: Optional[OAuthMetadata],
  227. client_metadata: OAuthClientMetadata,
  228. ) -> OAuthClientInformationFull:
  229. """Performs OAuth 2.0 Dynamic Client Registration."""
  230. if metadata:
  231. if not metadata.registration_endpoint:
  232. raise ValueError("Incompatible auth server: does not support dynamic client registration")
  233. registration_url = metadata.registration_endpoint
  234. else:
  235. registration_url = urljoin(server_url, "/register")
  236. response = httpx.post(
  237. registration_url,
  238. json=client_metadata.model_dump(),
  239. headers={"Content-Type": "application/json"},
  240. )
  241. if not response.is_success:
  242. response.raise_for_status()
  243. return OAuthClientInformationFull.model_validate(response.json())
  244. def auth(
  245. provider: OAuthClientProvider,
  246. server_url: str,
  247. authorization_code: Optional[str] = None,
  248. state_param: Optional[str] = None,
  249. for_list: bool = False,
  250. ) -> dict[str, str]:
  251. """Orchestrates the full auth flow with a server using secure Redis state storage."""
  252. metadata = discover_oauth_metadata(server_url)
  253. # Handle client registration if needed
  254. client_information = provider.client_information()
  255. if not client_information:
  256. if authorization_code is not None:
  257. raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
  258. try:
  259. full_information = register_client(server_url, metadata, provider.client_metadata)
  260. except httpx.RequestError as e:
  261. raise ValueError(f"Could not register OAuth client: {e}")
  262. provider.save_client_information(full_information)
  263. client_information = full_information
  264. # Exchange authorization code for tokens
  265. if authorization_code is not None:
  266. if not state_param:
  267. raise ValueError("State parameter is required when exchanging authorization code")
  268. try:
  269. # Retrieve state data from Redis using state key
  270. full_state_data = _retrieve_redis_state(state_param)
  271. code_verifier = full_state_data.code_verifier
  272. redirect_uri = full_state_data.redirect_uri
  273. if not code_verifier or not redirect_uri:
  274. raise ValueError("Missing code_verifier or redirect_uri in state data")
  275. except (json.JSONDecodeError, ValueError) as e:
  276. raise ValueError(f"Invalid state parameter: {e}")
  277. tokens = exchange_authorization(
  278. server_url,
  279. metadata,
  280. client_information,
  281. authorization_code,
  282. code_verifier,
  283. redirect_uri,
  284. )
  285. provider.save_tokens(tokens)
  286. return {"result": "success"}
  287. provider_tokens = provider.tokens()
  288. # Handle token refresh or new authorization
  289. if provider_tokens and provider_tokens.refresh_token:
  290. try:
  291. new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
  292. provider.save_tokens(new_tokens)
  293. return {"result": "success"}
  294. except Exception as e:
  295. raise ValueError(f"Could not refresh OAuth tokens: {e}")
  296. # Start new authorization flow
  297. authorization_url, code_verifier = start_authorization(
  298. server_url,
  299. metadata,
  300. client_information,
  301. provider.redirect_url,
  302. provider.mcp_provider.id,
  303. provider.mcp_provider.tenant_id,
  304. )
  305. provider.save_code_verifier(code_verifier)
  306. return {"authorization_url": authorization_url}