| raise Forbidden("no oauth available client config found for this tool provider") | raise Forbidden("no oauth available client config found for this tool provider") | ||||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" | redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" | ||||
| credentials = oauth_handler.get_credentials( | |||||
| credentials_response = oauth_handler.get_credentials( | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| user_id=user_id, | user_id=user_id, | ||||
| plugin_id=plugin_id, | plugin_id=plugin_id, | ||||
| redirect_uri=redirect_uri, | redirect_uri=redirect_uri, | ||||
| system_credentials=oauth_client_params, | system_credentials=oauth_client_params, | ||||
| request=request, | request=request, | ||||
| ).credentials | |||||
| ) | |||||
| credentials = credentials_response.credentials | |||||
| expires_at = credentials_response.expires_at | |||||
| if not credentials: | if not credentials: | ||||
| raise Exception("the plugin credentials failed") | raise Exception("the plugin credentials failed") | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| provider=provider, | provider=provider, | ||||
| credentials=dict(credentials), | credentials=dict(credentials), | ||||
| expires_at=expires_at, | |||||
| api_type=CredentialType.OAUTH2, | api_type=CredentialType.OAUTH2, | ||||
| ) | ) | ||||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") | return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") |
| class PluginOAuthCredentialsResponse(BaseModel): | class PluginOAuthCredentialsResponse(BaseModel): | ||||
| metadata: Mapping[str, Any] = Field( | |||||
| default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc." | |||||
| ) | |||||
| expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.") | |||||
| credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") | credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") | ||||
| except Exception as e: | except Exception as e: | ||||
| raise ValueError(f"Error getting credentials: {e}") | raise ValueError(f"Error getting credentials: {e}") | ||||
| def refresh_credentials( | |||||
| self, | |||||
| tenant_id: str, | |||||
| user_id: str, | |||||
| plugin_id: str, | |||||
| provider: str, | |||||
| redirect_uri: str, | |||||
| system_credentials: Mapping[str, Any], | |||||
| credentials: Mapping[str, Any], | |||||
| ) -> PluginOAuthCredentialsResponse: | |||||
| try: | |||||
| response = self._request_with_plugin_daemon_response_stream( | |||||
| "POST", | |||||
| f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials", | |||||
| PluginOAuthCredentialsResponse, | |||||
| data={ | |||||
| "user_id": user_id, | |||||
| "data": { | |||||
| "provider": provider, | |||||
| "redirect_uri": redirect_uri, | |||||
| "system_credentials": system_credentials, | |||||
| "credentials": credentials, | |||||
| }, | |||||
| }, | |||||
| headers={ | |||||
| "X-Plugin-ID": plugin_id, | |||||
| "Content-Type": "application/json", | |||||
| }, | |||||
| ) | |||||
| for resp in response: | |||||
| return resp | |||||
| raise ValueError("No response received from plugin daemon for refresh credentials request.") | |||||
| except Exception as e: | |||||
| raise ValueError(f"Error refreshing credentials: {e}") | |||||
| def _convert_request_to_raw_data(self, request: Request) -> bytes: | def _convert_request_to_raw_data(self, request: Request) -> bytes: | ||||
| """ | """ | ||||
| Convert a Request object to raw HTTP data. | Convert a Request object to raw HTTP data. |
| import json | import json | ||||
| import logging | import logging | ||||
| import mimetypes | import mimetypes | ||||
| from collections.abc import Generator | |||||
| import time | |||||
| from collections.abc import Generator, Mapping | |||||
| from os import listdir, path | from os import listdir, path | ||||
| from threading import Lock | from threading import Lock | ||||
| from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast | from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast | ||||
| from pydantic import TypeAdapter | |||||
| from yarl import URL | from yarl import URL | ||||
| import contexts | import contexts | ||||
| from core.helper.provider_cache import ToolProviderCredentialsCache | from core.helper.provider_cache import ToolProviderCredentialsCache | ||||
| from core.plugin.entities.plugin import ToolProviderID | from core.plugin.entities.plugin import ToolProviderID | ||||
| from core.plugin.impl.oauth import OAuthHandler | |||||
| from core.plugin.impl.tool import PluginToolManager | from core.plugin.impl.tool import PluginToolManager | ||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| from core.tools.__base.tool_runtime import ToolRuntime | from core.tools.__base.tool_runtime import ToolRuntime | ||||
| tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id | tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id | ||||
| ), | ), | ||||
| ) | ) | ||||
| # decrypt the credentials | |||||
| decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) | |||||
| # check if the credentials is expired | |||||
| if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): | |||||
| # TODO: circular import | |||||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||||
| # refresh the credentials | |||||
| tool_provider = ToolProviderID(provider_id) | |||||
| provider_name = tool_provider.provider_name | |||||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" | |||||
| system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) | |||||
| oauth_handler = OAuthHandler() | |||||
| # refresh the credentials | |||||
| refreshed_credentials = oauth_handler.refresh_credentials( | |||||
| tenant_id=tenant_id, | |||||
| user_id=builtin_provider.user_id, | |||||
| plugin_id=tool_provider.plugin_id, | |||||
| provider=provider_name, | |||||
| redirect_uri=redirect_uri, | |||||
| system_credentials=system_credentials or {}, | |||||
| credentials=decrypted_credentials, | |||||
| ) | |||||
| # update the credentials | |||||
| builtin_provider.encrypted_credentials = ( | |||||
| TypeAdapter(dict[str, Any]) | |||||
| .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials))) | |||||
| .decode("utf-8") | |||||
| ) | |||||
| builtin_provider.expires_at = refreshed_credentials.expires_at | |||||
| db.session.commit() | |||||
| decrypted_credentials = refreshed_credentials.credentials | |||||
| return cast( | return cast( | ||||
| BuiltinTool, | BuiltinTool, | ||||
| builtin_tool.fork_tool_runtime( | builtin_tool.fork_tool_runtime( | ||||
| runtime=ToolRuntime( | runtime=ToolRuntime( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| credentials=encrypter.decrypt(builtin_provider.credentials), | |||||
| credentials=dict(decrypted_credentials), | |||||
| credential_type=CredentialType.of(builtin_provider.credential_type), | credential_type=CredentialType.of(builtin_provider.credential_type), | ||||
| runtime_parameters={}, | runtime_parameters={}, | ||||
| invoke_from=invoke_from, | invoke_from=invoke_from, |
| """oauth_refresh_token | |||||
| Revision ID: 375fe79ead14 | |||||
| Revises: 1a83934ad6d1 | |||||
| Create Date: 2025-07-22 00:19:45.599636 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '375fe79ead14' | |||||
| down_revision = '1a83934ad6d1' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False)) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: | |||||
| batch_op.drop_column('expires_at') | |||||
| # ### end Alembic commands ### |
| credential_type: Mapped[str] = mapped_column( | credential_type: Mapped[str] = mapped_column( | ||||
| db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") | db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") | ||||
| ) | ) | ||||
| expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) | |||||
| @property | @property | ||||
| def credentials(self) -> dict: | def credentials(self) -> dict: |
| class BuiltinToolManageService: | class BuiltinToolManageService: | ||||
| __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 | __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 | ||||
| __DEFAULT_EXPIRES_AT__ = 2147483647 | |||||
| @staticmethod | @staticmethod | ||||
| def delete_custom_oauth_client_params(tenant_id: str, provider: str): | def delete_custom_oauth_client_params(tenant_id: str, provider: str): | ||||
| tenant_id: str, | tenant_id: str, | ||||
| provider: str, | provider: str, | ||||
| credentials: dict, | credentials: dict, | ||||
| expires_at: int = -1, | |||||
| name: str | None = None, | name: str | None = None, | ||||
| ): | ): | ||||
| """ | """ | ||||
| encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), | encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), | ||||
| credential_type=api_type.value, | credential_type=api_type.value, | ||||
| name=name, | name=name, | ||||
| expires_at=expires_at | |||||
| if expires_at is not None | |||||
| else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__, | |||||
| ) | ) | ||||
| session.add(db_provider) | session.add(db_provider) |