|
|
|
@@ -453,7 +453,7 @@ class BuiltinToolManageService: |
|
|
|
check if oauth system client exists |
|
|
|
""" |
|
|
|
tool_provider = ToolProviderID(provider_name) |
|
|
|
with Session(db.engine).no_autoflush as session: |
|
|
|
with Session(db.engine, autoflush=False) as session: |
|
|
|
system_client: ToolOAuthSystemClient | None = ( |
|
|
|
session.query(ToolOAuthSystemClient) |
|
|
|
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) |
|
|
|
@@ -467,7 +467,7 @@ class BuiltinToolManageService: |
|
|
|
check if oauth custom client is enabled |
|
|
|
""" |
|
|
|
tool_provider = ToolProviderID(provider) |
|
|
|
with Session(db.engine).no_autoflush as session: |
|
|
|
with Session(db.engine, autoflush=False) as session: |
|
|
|
user_client: ToolOAuthTenantClient | None = ( |
|
|
|
session.query(ToolOAuthTenantClient) |
|
|
|
.filter_by( |
|
|
|
@@ -492,7 +492,7 @@ class BuiltinToolManageService: |
|
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], |
|
|
|
cache=NoOpProviderCredentialCache(), |
|
|
|
) |
|
|
|
with Session(db.engine).no_autoflush as session: |
|
|
|
with Session(db.engine, autoflush=False) as session: |
|
|
|
user_client: ToolOAuthTenantClient | None = ( |
|
|
|
session.query(ToolOAuthTenantClient) |
|
|
|
.filter_by( |
|
|
|
@@ -546,54 +546,53 @@ class BuiltinToolManageService: |
|
|
|
# get all builtin providers |
|
|
|
provider_controllers = ToolManager.list_builtin_providers(tenant_id) |
|
|
|
|
|
|
|
with db.session.no_autoflush: |
|
|
|
# get all user added providers |
|
|
|
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) |
|
|
|
# get all user added providers |
|
|
|
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) |
|
|
|
|
|
|
|
# rewrite db_providers |
|
|
|
for db_provider in db_providers: |
|
|
|
db_provider.provider = str(ToolProviderID(db_provider.provider)) |
|
|
|
# rewrite db_providers |
|
|
|
for db_provider in db_providers: |
|
|
|
db_provider.provider = str(ToolProviderID(db_provider.provider)) |
|
|
|
|
|
|
|
# find provider |
|
|
|
def find_provider(provider): |
|
|
|
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) |
|
|
|
# find provider |
|
|
|
def find_provider(provider): |
|
|
|
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) |
|
|
|
|
|
|
|
result: list[ToolProviderApiEntity] = [] |
|
|
|
result: list[ToolProviderApiEntity] = [] |
|
|
|
|
|
|
|
for provider_controller in provider_controllers: |
|
|
|
try: |
|
|
|
# handle include, exclude |
|
|
|
if is_filtered( |
|
|
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore |
|
|
|
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore |
|
|
|
data=provider_controller, |
|
|
|
name_func=lambda x: x.identity.name, |
|
|
|
): |
|
|
|
continue |
|
|
|
for provider_controller in provider_controllers: |
|
|
|
try: |
|
|
|
# handle include, exclude |
|
|
|
if is_filtered( |
|
|
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore |
|
|
|
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore |
|
|
|
data=provider_controller, |
|
|
|
name_func=lambda x: x.identity.name, |
|
|
|
): |
|
|
|
continue |
|
|
|
|
|
|
|
# convert provider controller to user provider |
|
|
|
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( |
|
|
|
provider_controller=provider_controller, |
|
|
|
db_provider=find_provider(provider_controller.entity.identity.name), |
|
|
|
decrypt_credentials=True, |
|
|
|
) |
|
|
|
|
|
|
|
# convert provider controller to user provider |
|
|
|
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( |
|
|
|
provider_controller=provider_controller, |
|
|
|
db_provider=find_provider(provider_controller.entity.identity.name), |
|
|
|
decrypt_credentials=True, |
|
|
|
) |
|
|
|
# add icon |
|
|
|
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) |
|
|
|
|
|
|
|
# add icon |
|
|
|
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) |
|
|
|
|
|
|
|
tools = provider_controller.get_tools() |
|
|
|
for tool in tools or []: |
|
|
|
user_builtin_provider.tools.append( |
|
|
|
ToolTransformService.convert_tool_entity_to_api_entity( |
|
|
|
tenant_id=tenant_id, |
|
|
|
tool=tool, |
|
|
|
labels=ToolLabelManager.get_tool_labels(provider_controller), |
|
|
|
) |
|
|
|
tools = provider_controller.get_tools() |
|
|
|
for tool in tools or []: |
|
|
|
user_builtin_provider.tools.append( |
|
|
|
ToolTransformService.convert_tool_entity_to_api_entity( |
|
|
|
tenant_id=tenant_id, |
|
|
|
tool=tool, |
|
|
|
labels=ToolLabelManager.get_tool_labels(provider_controller), |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
result.append(user_builtin_provider) |
|
|
|
except Exception as e: |
|
|
|
raise e |
|
|
|
result.append(user_builtin_provider) |
|
|
|
except Exception as e: |
|
|
|
raise e |
|
|
|
|
|
|
|
return BuiltinToolProviderSort.sort(result) |
|
|
|
|
|
|
|
@@ -604,7 +603,7 @@ class BuiltinToolManageService: |
|
|
|
1.if the default provider exists, return the default provider |
|
|
|
2.if the default provider does not exist, return the oldest provider |
|
|
|
""" |
|
|
|
with Session(db.engine) as session: |
|
|
|
with Session(db.engine, autoflush=False) as session: |
|
|
|
try: |
|
|
|
full_provider_name = provider_name |
|
|
|
provider_id_entity = ToolProviderID(provider_name) |