浏览代码

feat: oauth refresh token (#22744)

Co-authored-by: Yeuoly <admin@srmxy.cn>
tags/1.7.0
Maries 3 个月前
父节点
当前提交
ad67094e54
没有帐户链接到提交者的电子邮件

+ 6
- 2
api/controllers/console/workspace/tool_providers.py 查看文件

@@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider")

redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
@@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
)

credentials = credentials_response.credentials
expires_at = credentials_response.expires_at

if not credentials:
raise Exception("the plugin credentials failed")
@@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

+ 4
- 0
api/core/plugin/entities/plugin_daemon.py 查看文件

@@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):


class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")



+ 35
- 0
api/core/plugin/impl/oauth.py 查看文件

@@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
except Exception as e:
raise ValueError(f"Error getting credentials: {e}")

def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")

def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
Convert a Request object to raw HTTP data.

+ 40
- 2
api/core/tools/tool_manager.py 查看文件

@@ -1,16 +1,19 @@
import json
import logging
import mimetypes
from collections.abc import Generator
import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

from pydantic import TypeAdapter
from yarl import URL

import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@@ -244,12 +247,47 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)

# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials

return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials),
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,

+ 34
- 0
api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py 查看文件

@@ -0,0 +1,34 @@
"""oauth_refresh_token

Revision ID: 375fe79ead14
Revises: 1a83934ad6d1
Create Date: 2025-07-22 00:19:45.599636

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '375fe79ead14'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###

with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_column('expires_at')

# ### end Alembic commands ###

+ 1
- 0
api/models/tools.py 查看文件

@@ -93,6 +93,7 @@ class BuiltinToolProvider(Base):
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))

@property
def credentials(self) -> dict:

+ 5
- 0
api/services/tools/builtin_tools_manage_service.py 查看文件

@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)

class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647

@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@@ -212,6 +213,7 @@ class BuiltinToolManageService:
tenant_id: str,
provider: str,
credentials: dict,
expires_at: int = -1,
name: str | None = None,
):
"""
@@ -269,6 +271,9 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
)

session.add(db_provider)

正在加载...
取消
保存