| @@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource): | |||
| raise Forbidden("no oauth available client config found for this tool provider") | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" | |||
| credentials = oauth_handler.get_credentials( | |||
| credentials_response = oauth_handler.get_credentials( | |||
| tenant_id=tenant_id, | |||
| user_id=user_id, | |||
| plugin_id=plugin_id, | |||
| @@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource): | |||
| redirect_uri=redirect_uri, | |||
| system_credentials=oauth_client_params, | |||
| request=request, | |||
| ).credentials | |||
| ) | |||
| credentials = credentials_response.credentials | |||
| expires_at = credentials_response.expires_at | |||
| if not credentials: | |||
| raise Exception("the plugin credentials failed") | |||
| @@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource): | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| credentials=dict(credentials), | |||
| expires_at=expires_at, | |||
| api_type=CredentialType.OAUTH2, | |||
| ) | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") | |||
| @@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel): | |||
| class PluginOAuthCredentialsResponse(BaseModel): | |||
| metadata: Mapping[str, Any] = Field( | |||
| default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc." | |||
| ) | |||
| expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.") | |||
| credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") | |||
| @@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient): | |||
| except Exception as e: | |||
| raise ValueError(f"Error getting credentials: {e}") | |||
| def refresh_credentials( | |||
| self, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| plugin_id: str, | |||
| provider: str, | |||
| redirect_uri: str, | |||
| system_credentials: Mapping[str, Any], | |||
| credentials: Mapping[str, Any], | |||
| ) -> PluginOAuthCredentialsResponse: | |||
| try: | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials", | |||
| PluginOAuthCredentialsResponse, | |||
| data={ | |||
| "user_id": user_id, | |||
| "data": { | |||
| "provider": provider, | |||
| "redirect_uri": redirect_uri, | |||
| "system_credentials": system_credentials, | |||
| "credentials": credentials, | |||
| }, | |||
| }, | |||
| headers={ | |||
| "X-Plugin-ID": plugin_id, | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise ValueError("No response received from plugin daemon for refresh credentials request.") | |||
| except Exception as e: | |||
| raise ValueError(f"Error refreshing credentials: {e}") | |||
| def _convert_request_to_raw_data(self, request: Request) -> bytes: | |||
| """ | |||
| Convert a Request object to raw HTTP data. | |||
| @@ -1,16 +1,19 @@ | |||
| import json | |||
| import logging | |||
| import mimetypes | |||
| from collections.abc import Generator | |||
| import time | |||
| from collections.abc import Generator, Mapping | |||
| from os import listdir, path | |||
| from threading import Lock | |||
| from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast | |||
| from pydantic import TypeAdapter | |||
| from yarl import URL | |||
| import contexts | |||
| from core.helper.provider_cache import ToolProviderCredentialsCache | |||
| from core.plugin.entities.plugin import ToolProviderID | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from core.plugin.impl.tool import PluginToolManager | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| @@ -244,12 +247,47 @@ class ToolManager: | |||
| tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id | |||
| ), | |||
| ) | |||
| # decrypt the credentials | |||
| decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) | |||
| # check if the credentials is expired | |||
| if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): | |||
| # TODO: circular import | |||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||
| # refresh the credentials | |||
| tool_provider = ToolProviderID(provider_id) | |||
| provider_name = tool_provider.provider_name | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" | |||
| system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) | |||
| oauth_handler = OAuthHandler() | |||
| # refresh the credentials | |||
| refreshed_credentials = oauth_handler.refresh_credentials( | |||
| tenant_id=tenant_id, | |||
| user_id=builtin_provider.user_id, | |||
| plugin_id=tool_provider.plugin_id, | |||
| provider=provider_name, | |||
| redirect_uri=redirect_uri, | |||
| system_credentials=system_credentials or {}, | |||
| credentials=decrypted_credentials, | |||
| ) | |||
| # update the credentials | |||
| builtin_provider.encrypted_credentials = ( | |||
| TypeAdapter(dict[str, Any]) | |||
| .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials))) | |||
| .decode("utf-8") | |||
| ) | |||
| builtin_provider.expires_at = refreshed_credentials.expires_at | |||
| db.session.commit() | |||
| decrypted_credentials = refreshed_credentials.credentials | |||
| return cast( | |||
| BuiltinTool, | |||
| builtin_tool.fork_tool_runtime( | |||
| runtime=ToolRuntime( | |||
| tenant_id=tenant_id, | |||
| credentials=encrypter.decrypt(builtin_provider.credentials), | |||
| credentials=dict(decrypted_credentials), | |||
| credential_type=CredentialType.of(builtin_provider.credential_type), | |||
| runtime_parameters={}, | |||
| invoke_from=invoke_from, | |||
| @@ -0,0 +1,34 @@ | |||
| """oauth_refresh_token | |||
| Revision ID: 375fe79ead14 | |||
| Revises: 1a83934ad6d1 | |||
| Create Date: 2025-07-22 00:19:45.599636 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '375fe79ead14' | |||
| down_revision = '1a83934ad6d1' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: | |||
| batch_op.drop_column('expires_at') | |||
| # ### end Alembic commands ### | |||
| @@ -93,6 +93,7 @@ class BuiltinToolProvider(Base): | |||
| credential_type: Mapped[str] = mapped_column( | |||
| db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") | |||
| ) | |||
| expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) | |||
| @property | |||
| def credentials(self) -> dict: | |||
| @@ -38,6 +38,7 @@ logger = logging.getLogger(__name__) | |||
| class BuiltinToolManageService: | |||
| __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 | |||
| __DEFAULT_EXPIRES_AT__ = 2147483647 | |||
| @staticmethod | |||
| def delete_custom_oauth_client_params(tenant_id: str, provider: str): | |||
| @@ -212,6 +213,7 @@ class BuiltinToolManageService: | |||
| tenant_id: str, | |||
| provider: str, | |||
| credentials: dict, | |||
| expires_at: int = -1, | |||
| name: str | None = None, | |||
| ): | |||
| """ | |||
| @@ -269,6 +271,9 @@ class BuiltinToolManageService: | |||
| encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), | |||
| credential_type=api_type.value, | |||
| name=name, | |||
| expires_at=expires_at | |||
| if expires_at is not None | |||
| else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__, | |||
| ) | |||
| session.add(db_provider) | |||