Selaa lähdekoodia

fix: tool provider deadlock (#24532)

Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
tags/1.8.0
Maries 2 kuukautta sitten
vanhempi
commit
c06cfcbb5a
No account linked to committer's email address

+ 4
- 2
api/core/tools/tool_manager.py Näytä tiedosto



import sqlalchemy as sa import sqlalchemy as sa
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from yarl import URL from yarl import URL


import contexts import contexts
WHERE tenant_id = :tenant_id WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC ORDER BY tenant_id, provider, is_default DESC, created_at DESC
""" """
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
with Session(db.engine, autoflush=False) as session:
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()


@classmethod @classmethod
def list_providers_from_api( def list_providers_from_api(

+ 43
- 44
api/services/tools/builtin_tools_manage_service.py Näytä tiedosto

check if oauth system client exists check if oauth system client exists
""" """
tool_provider = ToolProviderID(provider_name) tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = ( system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient) session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
check if oauth custom client is enabled check if oauth custom client is enabled
""" """
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = ( user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient) session.query(ToolOAuthTenantClient)
.filter_by( .filter_by(
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )
with Session(db.engine).no_autoflush as session:
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = ( user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient) session.query(ToolOAuthTenantClient)
.filter_by( .filter_by(
# get all builtin providers # get all builtin providers
provider_controllers = ToolManager.list_builtin_providers(tenant_id) provider_controllers = ToolManager.list_builtin_providers(tenant_id)


with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# get all user added providers
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)


# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))


# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)


result: list[ToolProviderApiEntity] = []
result: list[ToolProviderApiEntity] = []


for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue

# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
)


# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)


# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)

tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
) )
)


result.append(user_builtin_provider)
except Exception as e:
raise e
result.append(user_builtin_provider)
except Exception as e:
raise e


return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)


1.if the default provider exists, return the default provider 1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider 2.if the default provider does not exist, return the oldest provider
""" """
with Session(db.engine) as session:
with Session(db.engine, autoflush=False) as session:
try: try:
full_provider_name = provider_name full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name) provider_id_entity = ToolProviderID(provider_name)

Loading…
Peruuta
Tallenna