You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

auth_provider.py 3.3KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from typing import Optional
  2. from configs import dify_config
  3. from core.mcp.types import (
  4. OAuthClientInformation,
  5. OAuthClientInformationFull,
  6. OAuthClientMetadata,
  7. OAuthTokens,
  8. )
  9. from models.tools import MCPToolProvider
  10. from services.tools.mcp_tools_manage_service import MCPToolManageService
  class OAuthClientProvider:
      """Store and retrieve OAuth client state for one MCP tool provider.

      Reads come from the provider's ``decrypted_credentials``; writes go
      through ``MCPToolManageService.update_mcp_provider_credentials``.
      """

      mcp_provider: MCPToolProvider

      def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
          """Resolve the MCP provider this instance operates on.

          Args:
              provider_id: Passed to the provider-id lookup when ``for_list``
                  is truthy, otherwise to the server-identifier lookup.
              tenant_id: Tenant passed to the lookup alongside ``provider_id``.
              for_list: Selects which ``MCPToolManageService`` lookup is used.
          """
          lookup = (
              MCPToolManageService.get_mcp_provider_by_provider_id
              if for_list
              else MCPToolManageService.get_mcp_provider_by_server_identifier
          )
          self.mcp_provider = lookup(provider_id, tenant_id)

      @property
      def redirect_url(self) -> str:
          """Callback URL the user agent is sent back to once authorization completes."""
          callback_path = "/console/api/mcp/oauth/callback"
          return dify_config.CONSOLE_API_URL + callback_path

      @property
      def client_metadata(self) -> OAuthClientMetadata:
          """Describe this OAuth client, using ``redirect_url`` as its only redirect URI."""
          callback = self.redirect_url
          return OAuthClientMetadata(
              redirect_uris=[callback],
              token_endpoint_auth_method="none",
              grant_types=["authorization_code", "refresh_token"],
              response_types=["code"],
              client_name="Dify",
              client_uri="https://github.com/langgenius/dify",
          )

      def client_information(self) -> Optional[OAuthClientInformation]:
          """Return the stored client information, or ``None`` if it is absent or empty."""
          stored = self.mcp_provider.decrypted_credentials.get("client_information", {})
          return OAuthClientInformation.model_validate(stored) if stored else None

      def save_client_information(self, client_information: OAuthClientInformationFull):
          """Persist the dumped client information under the ``client_information`` key."""
          payload = {"client_information": client_information.model_dump()}
          MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, payload)

      def tokens(self) -> Optional[OAuthTokens]:
          """Build ``OAuthTokens`` from the stored credentials.

          Returns:
              ``None`` when the provider has no credentials at all; otherwise
              an ``OAuthTokens`` whose missing fields fall back to an empty
              access/refresh token, the ``"Bearer"`` token type and a
              3600 ``expires_in``.
          """
          stored = self.mcp_provider.decrypted_credentials
          if not stored:
              return None

          # A missing or falsy ``expires_in`` value falls back to 3600.
          lifetime = int(stored.get("expires_in", "3600") or 3600)
          return OAuthTokens(
              access_token=stored.get("access_token", ""),
              token_type=stored.get("token_type", "Bearer"),
              expires_in=lifetime,
              refresh_token=stored.get("refresh_token", ""),
          )

      def save_tokens(self, tokens: OAuthTokens):
          """Persist the dumped token fields and flag the update with ``authed=True``."""
          MCPToolManageService.update_mcp_provider_credentials(
              self.mcp_provider,
              tokens.model_dump(),
              authed=True,
          )

      def save_code_verifier(self, code_verifier: str):
          """Persist the PKCE code verifier under the ``code_verifier`` key."""
          payload = {"code_verifier": code_verifier}
          MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, payload)

      def code_verifier(self) -> str:
          """Return the stored PKCE code verifier as a string (empty when unset)."""
          stored = self.mcp_provider.decrypted_credentials.get("code_verifier", "")
          return str(stored)