Browse Source

feat: oauth refresh token (#22744)

Co-authored-by: Yeuoly <admin@srmxy.cn>
tags/1.7.0
Maries 3 months ago
parent
commit
ad67094e54
No account linked to committer's email address

+ 6
- 2
api/controllers/console/workspace/tool_providers.py View File

raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")


redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=plugin_id, plugin_id=plugin_id,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_client_params, system_credentials=oauth_client_params,
request=request, request=request,
).credentials
)

credentials = credentials_response.credentials
expires_at = credentials_response.expires_at


if not credentials: if not credentials:
raise Exception("the plugin credentials failed") raise Exception("the plugin credentials failed")
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
credentials=dict(credentials), credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2, api_type=CredentialType.OAUTH2,
) )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

+ 4
- 0
api/core/plugin/entities/plugin_daemon.py View File





class PluginOAuthCredentialsResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")





+ 35
- 0
api/core/plugin/impl/oauth.py View File

except Exception as e: except Exception as e:
raise ValueError(f"Error getting credentials: {e}") raise ValueError(f"Error getting credentials: {e}")


def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")

def _convert_request_to_raw_data(self, request: Request) -> bytes: def _convert_request_to_raw_data(self, request: Request) -> bytes:
""" """
Convert a Request object to raw HTTP data. Convert a Request object to raw HTTP data.

+ 40
- 2
api/core/tools/tool_manager.py View File

import json import json
import logging import logging
import mimetypes import mimetypes
from collections.abc import Generator
import time
from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast


from pydantic import TypeAdapter
from yarl import URL from yarl import URL


import contexts import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
), ),
) )

# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials

return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials),
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type), credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,

+ 34
- 0
api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py View File

"""oauth_refresh_token

Revision ID: 375fe79ead14
Revises: 1a83934ad6d1
Create Date: 2025-07-22 00:19:45.599636

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '375fe79ead14'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###

with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_column('expires_at')

# ### end Alembic commands ###

+ 1
- 0
api/models/tools.py View File

credential_type: Mapped[str] = mapped_column( credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
) )
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))


@property @property
def credentials(self) -> dict: def credentials(self) -> dict:

+ 5
- 0
api/services/tools/builtin_tools_manage_service.py View File



class BuiltinToolManageService: class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647


@staticmethod @staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str): def delete_custom_oauth_client_params(tenant_id: str, provider: str):
tenant_id: str, tenant_id: str,
provider: str, provider: str,
credentials: dict, credentials: dict,
expires_at: int = -1,
name: str | None = None, name: str | None = None,
): ):
""" """
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value, credential_type=api_type.value,
name=name, name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
) )


session.add(db_provider) session.add(db_provider)

Loading…
Cancel
Save