| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750 |
- import json
- import logging
- import re
- from collections.abc import Mapping
- from pathlib import Path
- from typing import Any, Optional
-
- from sqlalchemy.orm import Session
-
- from configs import dify_config
- from constants import HIDDEN_VALUE, UNKNOWN_VALUE
- from core.helper.position_helper import is_filtered
- from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
- from core.plugin.entities.plugin import ToolProviderID
- from core.tools.builtin_tool.provider import BuiltinToolProviderController
- from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
- from core.tools.entities.api_entities import (
- ToolApiEntity,
- ToolProviderApiEntity,
- ToolProviderCredentialApiEntity,
- ToolProviderCredentialInfoApiEntity,
- )
- from core.tools.entities.tool_entities import CredentialType
- from core.tools.errors import ToolProviderNotFoundError
- from core.tools.plugin_tool.provider import PluginToolProviderController
- from core.tools.tool_label_manager import ToolLabelManager
- from core.tools.tool_manager import ToolManager
- from core.tools.utils.encryption import create_provider_encrypter
- from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
- from extensions.ext_database import db
- from extensions.ext_redis import redis_client
- from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
- from services.plugin.plugin_service import PluginService
- from services.tools.tools_transform_service import ToolTransformService
-
# Module-level logger; used by BuiltinToolManageService.generate_builtin_tool_provider_name
# to report failures before falling back to a default credential name.
logger = logging.getLogger(__name__)
-
-
class BuiltinToolManageService:
    """
    Tenant-scoped management of builtin tool providers (including plugin-backed ones).

    Covers credential records (create / update / delete / list / default selection),
    custom and system OAuth client parameters, and listing of providers and their tools.
    """

    # maximum number of credential records a tenant may create for a single provider
    __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
    # fallback `expires_at` stored when None is passed explicitly (2**31 - 1);
    # presumably means "does not expire" — TODO confirm against the token refresh logic
    __DEFAULT_EXPIRES_AT__ = 2147483647

    @staticmethod
    def delete_custom_oauth_client_params(tenant_id: str, provider: str):
        """
        Delete the tenant's custom OAuth client parameters for a provider.

        :param tenant_id: the id of the tenant
        :param provider: the provider identifier (parsed with ToolProviderID)
        :return: {"result": "success"}
        """
        tool_provider = ToolProviderID(provider)
        with Session(db.engine) as session:
            session.query(ToolOAuthTenantClient).filter_by(
                tenant_id=tenant_id,
                provider=tool_provider.provider_name,
                plugin_id=tool_provider.plugin_id,
            ).delete()
            session.commit()
        return {"result": "success"}

    @staticmethod
    def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
        """
        Describe the OAuth client configuration of a provider for the tenant.

        :param tenant_id: the id of the tenant
        :param provider_name: the provider identifier
        :return: a dict with the provider's OAuth client schema, whether a custom client is
            enabled, whether a system client is usable, the tenant's masked custom client
            params and the OAuth callback redirect URI
        """
        provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
        # only non-plugin builtins or verified plugins may use the system OAuth client
        verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
            tenant_id, provider.plugin_unique_identifier
        )

        is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
            tenant_id, provider_name
        )
        is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
            provider_name
        )
        return {
            "schema": provider.get_oauth_client_schema(),
            "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
            "is_system_oauth_params_exists": is_system_oauth_params_exists,
            "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
            "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
        }

    @staticmethod
    def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
        """
        List the tools of a builtin tool provider as API entities.

        :param tenant_id: the id of the tenant
        :param provider: the name of the provider
        :return: the list of tools (empty if the provider exposes none)
        """
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        tools = provider_controller.get_tools()
        if not tools:
            return []

        # labels are a property of the provider, not of each tool: fetch them once
        labels = ToolLabelManager.get_tool_labels(provider_controller)
        return [
            ToolTransformService.convert_tool_entity_to_api_entity(
                tool=tool,
                tenant_id=tenant_id,
                labels=labels,
            )
            for tool in tools
        ]

    @staticmethod
    def get_builtin_tool_provider_info(tenant_id: str, provider: str):
        """
        Return the user-facing entity of a provider the tenant has added.

        The entity is built from the tenant's default (or oldest) credential record with
        decrypted credentials; ``original_credentials`` is cleared before returning.

        :raises ValueError: if the tenant has no credential record for the provider
        """
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        # check if user has added the provider
        builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
        if builtin_provider is None:
            raise ValueError(f"you have not added provider {provider}")

        entity = ToolTransformService.builtin_provider_to_user_provider(
            provider_controller=provider_controller,
            db_provider=builtin_provider,
            decrypt_credentials=True,
        )

        entity.original_credentials = {}
        return entity

    @staticmethod
    def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
        """
        Return the credentials schema a provider declares for a credential type.

        :param provider_name: the name of the provider
        :param credential_type: credential type
        :param tenant_id: the id of the tenant
        :return: the provider's credentials schema for ``credential_type``
        """
        provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
        return provider.get_credentials_schema_by_type(credential_type)

    @staticmethod
    def update_builtin_tool_provider(
        user_id: str,
        tenant_id: str,
        provider: str,
        credential_id: str,
        credentials: dict | None = None,
        name: str | None = None,
    ):
        """
        Update the credentials and/or the display name of one credential record.

        Credentials are replaced only when the record's credential type is editable and
        ``credentials`` is non-empty. Values equal to HIDDEN_VALUE keep the previously
        stored (decrypted) value. The merged credentials are validated when the credential
        type allows it, then encrypted, and the credentials cache is invalidated.

        :param user_id: the acting user, passed to credential validation
        :param tenant_id: the id of the tenant
        :param provider: the provider identifier
        :param credential_id: the id of the BuiltinToolProvider record to update
        :param credentials: new credential values, possibly containing HIDDEN_VALUE placeholders
        :param name: new display name; must be unique per tenant and provider
        :raises ValueError: if the record does not exist or any step of the update fails
        :return: {"result": "success"}
        """
        with Session(db.engine) as session:
            db_provider = (
                session.query(BuiltinToolProvider)
                .where(
                    BuiltinToolProvider.tenant_id == tenant_id,
                    BuiltinToolProvider.id == credential_id,
                )
                .first()
            )
            if db_provider is None:
                raise ValueError(f"you have not added provider {provider}")

            try:
                credential_type = CredentialType.of(db_provider.credential_type)
                if credential_type.is_editable() and credentials:
                    provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
                    if not provider_controller.need_credentials:
                        raise ValueError(f"provider {provider} does not need credentials")

                    encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
                        tenant_id, db_provider, provider, provider_controller
                    )

                    # masked placeholders sent back by the client keep the stored secret
                    original_credentials = encrypter.decrypt(db_provider.credentials)
                    new_credentials: dict = {
                        key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                        for key, value in credentials.items()
                    }

                    if credential_type.is_validate_allowed():
                        provider_controller.validate_credentials(user_id, new_credentials)

                    db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
                    cache.delete()

                # update name if provided
                if name and name != db_provider.name:
                    name_taken = (
                        session.query(BuiltinToolProvider)
                        .filter_by(tenant_id=tenant_id, provider=provider, name=name)
                        .count()
                        > 0
                    )
                    if name_taken:
                        raise ValueError(f"the credential name '{name}' is already used")
                    db_provider.name = name

                session.commit()
            except Exception as e:
                session.rollback()
                raise ValueError(str(e)) from e
        return {"result": "success"}

    @staticmethod
    def add_builtin_tool_provider(
        user_id: str,
        api_type: CredentialType,
        tenant_id: str,
        provider: str,
        credentials: dict,
        expires_at: int = -1,
        name: str | None = None,
    ):
        """
        Create a new credential record for a builtin tool provider.

        A Redis lock keyed by (tenant, provider) is held while the per-provider record
        limit and name uniqueness are checked and the record is inserted.

        :param user_id: the acting user, stored on the record and passed to validation
        :param api_type: the credential type of the new record
        :param tenant_id: the id of the tenant
        :param provider: the provider identifier
        :param credentials: plaintext credentials; stored encrypted
        :param expires_at: expiry stored on the record; None falls back to __DEFAULT_EXPIRES_AT__
        :param name: display name; generated ("<type name> <n>") when empty
        :raises ValueError: on any failure (limit reached, duplicate name, invalid credentials, ...)
        :return: {"result": "success"}
        """
        with Session(db.engine) as session:
            # the try lives inside the session scope so the rollback always targets a live,
            # bound session (previously a failing Session() left `session` unbound here)
            try:
                lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
                with redis_client.lock(lock, timeout=20):
                    provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
                    if not provider_controller.need_credentials:
                        raise ValueError(f"provider {provider} does not need credentials")

                    provider_count = (
                        session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
                    )
                    if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
                        raise ValueError(f"you have reached the maximum number of providers for {provider}")

                    if CredentialType.of(api_type).is_validate_allowed():
                        provider_controller.validate_credentials(user_id, credentials)

                    if not name:
                        name = BuiltinToolManageService.generate_builtin_tool_provider_name(
                            session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
                        )
                    elif (
                        session.query(BuiltinToolProvider)
                        .filter_by(tenant_id=tenant_id, provider=provider, name=name)
                        .count()
                        > 0
                    ):
                        raise ValueError(f"the credential name '{name}' is already used")

                    # a brand-new record has nothing cached yet, so a no-op cache is enough
                    encrypter, _ = create_provider_encrypter(
                        tenant_id=tenant_id,
                        config=[
                            x.to_basic_provider_config()
                            for x in provider_controller.get_credentials_schema_by_type(api_type)
                        ],
                        cache=NoOpProviderCredentialCache(),
                    )

                    db_provider = BuiltinToolProvider(
                        tenant_id=tenant_id,
                        user_id=user_id,
                        provider=provider,
                        encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
                        credential_type=api_type.value,
                        name=name,
                        expires_at=expires_at
                        if expires_at is not None
                        else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
                    )

                    session.add(db_provider)
                    session.commit()
            except Exception as e:
                session.rollback()
                raise ValueError(str(e)) from e
        return {"result": "success"}

    @staticmethod
    def create_tool_encrypter(
        tenant_id: str,
        db_provider: BuiltinToolProvider,
        provider: str,
        provider_controller: BuiltinToolProviderController,
    ):
        """
        Build the encrypter and credentials cache for one credential record.

        The encrypter uses the controller's credentials schema for the record's credential
        type; the cache is keyed by (tenant, provider, credential id).

        :return: a tuple ``(encrypter, cache)``
        """
        encrypter, cache = create_provider_encrypter(
            tenant_id=tenant_id,
            config=[
                x.to_basic_provider_config()
                for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
            ],
            cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
        )
        return encrypter, cache

    @staticmethod
    def generate_builtin_tool_provider_name(
        session: Session, tenant_id: str, provider: str, credential_type: CredentialType
    ) -> str:
        """
        Generate the next default display name for a new credential record.

        Default names look like "<credential type name> <n>". The result uses one more
        than the largest ``n`` found among the tenant's records of this provider and
        credential type, starting at 1. On any error the failure is logged and
        "<credential type name> 1" is returned.
        """
        try:
            db_providers = (
                session.query(BuiltinToolProvider)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=provider,
                    credential_type=credential_type.value,
                )
                .order_by(BuiltinToolProvider.created_at.desc())
                .all()
            )

            default_pattern = f"{credential_type.get_name()}"
            # matches "{default_pattern} {number}" and captures the number
            pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
            numbers = []
            for db_provider in db_providers:
                if db_provider.name:
                    match = re.match(pattern, db_provider.name.strip())
                    if match:
                        numbers.append(int(match.group(1)))

            # no default-style names yet -> start at 1
            return f"{default_pattern} {max(numbers, default=0) + 1}"
        except Exception as e:
            logger.warning("Error generating next provider name for %s: %s", provider, str(e))
            # fallback
            return f"{credential_type.get_name()} 1"

    @staticmethod
    def get_builtin_tool_provider_credentials(
        tenant_id: str, provider_name: str
    ) -> list[ToolProviderCredentialApiEntity]:
        """
        List the tenant's credential records for a provider with masked credentials.

        Records are ordered default-first, then oldest-first; the first one is reported
        as the default.
        """
        with db.session.no_autoflush:
            providers = (
                db.session.query(BuiltinToolProvider)
                .filter_by(tenant_id=tenant_id, provider=provider_name)
                .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                .all()
            )

            if not providers:
                return []

            default_provider = providers[0]
            # NOTE(review): this mutates an instance attached to db.session; if the request
            # later commits db.session the flag may be persisted — confirm this is intended
            default_provider.is_default = True
            provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)

            credentials: list[ToolProviderCredentialApiEntity] = []
            # one encrypter per credential type is enough: the schema depends only on the type
            encrypters = {}
            for provider in providers:
                credential_type = provider.credential_type
                if credential_type not in encrypters:
                    encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
                        tenant_id, provider, provider.provider, provider_controller
                    )[0]
                encrypter = encrypters[credential_type]
                decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
                credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
                    provider=provider,
                    credentials=decrypt_credential,
                )
                credentials.append(credential_entity)
            return credentials

    @staticmethod
    def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
        """
        Return the supported credential types, the custom-OAuth flag and the tenant's
        masked credential records for a provider.
        """
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        supported_credential_types = provider_controller.get_supported_credential_types()
        credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
        return ToolProviderCredentialInfoApiEntity(
            supported_credential_types=supported_credential_types,
            is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
            credentials=credentials,
        )

    @staticmethod
    def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
        """
        Delete one of the tenant's credential records and invalidate its credentials cache.

        :param tenant_id: the id of the tenant
        :param provider: the provider identifier (part of the cache key)
        :param credential_id: the id of the BuiltinToolProvider record to delete
        :raises ValueError: if the record does not exist in the tenant
        :return: {"result": "success"}
        """
        with Session(db.engine) as session:
            db_provider = (
                session.query(BuiltinToolProvider)
                .where(
                    BuiltinToolProvider.tenant_id == tenant_id,
                    BuiltinToolProvider.id == credential_id,
                )
                .first()
            )

            if db_provider is None:
                raise ValueError(f"you have not added provider {provider}")

            # The cache key only needs (tenant, provider, credential id): build it before the
            # delete so the deleted instance is not used afterwards, and without loading a
            # provider controller (whose failure would otherwise surface after the delete).
            cache = ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id)

            session.delete(db_provider)
            session.commit()

        cache.delete()
        return {"result": "success"}

    @staticmethod
    def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
        """
        Mark one credential record as the tenant's default for a provider.

        The target lookup is scoped to ``tenant_id`` so a credential id from another tenant
        cannot be modified. Existing defaults for the provider are cleared tenant-wide,
        because the readers (``get_builtin_provider``,
        ``get_builtin_tool_provider_credentials``) pick the default tenant-wide.

        :param tenant_id: the id of the tenant
        :param user_id: the acting user; kept for interface compatibility
        :param provider: the provider identifier as stored on the credential records
        :param id: the id of the credential record to make default
        :raises ValueError: if no such credential record exists in the tenant
        :return: {"result": "success"}
        """
        with Session(db.engine) as session:
            target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
            if target_provider is None:
                raise ValueError("provider not found")

            # clear the previous default(s) of this provider in the tenant
            session.query(BuiltinToolProvider).filter_by(
                tenant_id=tenant_id, provider=provider, is_default=True
            ).update({"is_default": False})

            target_provider.is_default = True
            session.commit()
        return {"result": "success"}

    @staticmethod
    def is_oauth_system_client_exists(provider_name: str) -> bool:
        """
        Return True if a system (instance-wide) OAuth client is configured for the provider.
        """
        tool_provider = ToolProviderID(provider_name)
        # read-only lookup: autoflush disabled, and the context manager closes the session
        with Session(db.engine, autoflush=False) as session:
            system_client: ToolOAuthSystemClient | None = (
                session.query(ToolOAuthSystemClient)
                .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
                .first()
            )
            return system_client is not None

    @staticmethod
    def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
        """
        Return True if the tenant has an enabled custom OAuth client for the provider.
        """
        tool_provider = ToolProviderID(provider)
        # read-only lookup: autoflush disabled, and the context manager closes the session
        with Session(db.engine, autoflush=False) as session:
            user_client: ToolOAuthTenantClient | None = (
                session.query(ToolOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=tool_provider.provider_name,
                    plugin_id=tool_provider.plugin_id,
                    enabled=True,
                )
                .first()
            )
            return user_client is not None and user_client.enabled

    @staticmethod
    def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
        """
        Resolve the OAuth client parameters to use for a provider.

        Resolution order:
        1. the tenant's enabled custom OAuth client (decrypted);
        2. otherwise, for non-plugin providers or verified plugins, the system OAuth client
           (decrypted);
        3. otherwise None.

        :raises ValueError: if the system OAuth client parameters cannot be decrypted
        """
        tool_provider = ToolProviderID(provider)
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        encrypter, _ = create_provider_encrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
            cache=NoOpProviderCredentialCache(),
        )
        # read-only lookups: autoflush disabled, and the context manager closes the session
        with Session(db.engine, autoflush=False) as session:
            user_client: ToolOAuthTenantClient | None = (
                session.query(ToolOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=tool_provider.provider_name,
                    plugin_id=tool_provider.plugin_id,
                    enabled=True,
                )
                .first()
            )
            if user_client:
                return encrypter.decrypt(user_client.oauth_params)

            # only verified provider can use official oauth client
            is_verified = not isinstance(
                provider_controller, PluginToolProviderController
            ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
            if not is_verified:
                return None

            system_client: ToolOAuthSystemClient | None = (
                session.query(ToolOAuthSystemClient)
                .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
                .first()
            )
            if system_client is None:
                return None

            try:
                return decrypt_system_oauth_params(system_client.encrypted_oauth_params)
            except Exception as e:
                raise ValueError(f"Error decrypting system oauth params: {e}") from e

    @staticmethod
    def get_builtin_tool_provider_icon(provider: str):
        """
        Return ``(icon_bytes, mime_type)`` for a hardcoded builtin provider's icon file.
        """
        icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
        icon_bytes = Path(icon_path).read_bytes()

        return icon_bytes, mime_type

    @staticmethod
    def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
        """
        List all builtin providers visible to the tenant, with their tools.

        Providers are filtered by the POSITION_TOOL_INCLUDES/EXCLUDES settings, paired
        with the tenant's default credential record (if any) and sorted with
        BuiltinToolProviderSort.
        """
        provider_controllers = ToolManager.list_builtin_providers(tenant_id)

        with db.session.no_autoflush:
            db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)

            # Normalize the stored provider identifiers through ToolProviderID so they can be
            # compared with the controllers' identity names, and index them by that name.
            # setdefault keeps the first record per name (same as a first-match search).
            db_provider_by_name: dict[str, BuiltinToolProvider] = {}
            for db_provider in db_providers:
                db_provider.provider = str(ToolProviderID(db_provider.provider))
                db_provider_by_name.setdefault(db_provider.provider, db_provider)

            result: list[ToolProviderApiEntity] = []
            for provider_controller in provider_controllers:
                # handle include, exclude
                if is_filtered(
                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
                    data=provider_controller,
                    name_func=lambda x: x.identity.name,
                ):
                    continue

                user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
                    provider_controller=provider_controller,
                    db_provider=db_provider_by_name.get(provider_controller.entity.identity.name),
                    decrypt_credentials=True,
                )

                # add icon
                ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)

                tools = provider_controller.get_tools()
                if tools:
                    # labels belong to the provider: fetch once per provider, not per tool
                    labels = ToolLabelManager.get_tool_labels(provider_controller)
                    user_builtin_provider.tools.extend(
                        ToolTransformService.convert_tool_entity_to_api_entity(
                            tenant_id=tenant_id,
                            tool=tool,
                            labels=labels,
                        )
                        for tool in tools
                    )

                result.append(user_builtin_provider)

        return BuiltinToolProviderSort.sort(result)

    @staticmethod
    def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
        """
        Fetch the tenant's credential record for a builtin provider from the database.

        1. if a default record exists, return it;
        2. otherwise return the oldest record;
        3. return None if the tenant has no record.

        For "langgenius" providers, records stored under the short provider name are also
        matched. The returned record's ``provider`` is normalized with
        ``ToolProviderID(...).to_string()``; the session is not committed, so this only
        affects the returned object. If parsing the identifier fails, the lookup falls back
        to an exact match on the given name (old providers without organization).
        """
        with Session(db.engine) as session:
            try:
                full_provider_name = provider_name
                provider_id_entity = ToolProviderID(provider_name)
                provider_name = provider_id_entity.provider_name

                if provider_id_entity.organization != "langgenius":
                    provider_filter = BuiltinToolProvider.provider == full_provider_name
                else:
                    provider_filter = (BuiltinToolProvider.provider == provider_name) | (
                        BuiltinToolProvider.provider == full_provider_name
                    )

                provider = (
                    session.query(BuiltinToolProvider)
                    .where(BuiltinToolProvider.tenant_id == tenant_id, provider_filter)
                    .order_by(
                        BuiltinToolProvider.is_default.desc(),  # default=True first
                        BuiltinToolProvider.created_at.asc(),  # oldest first
                    )
                    .first()
                )

                if provider is None:
                    return None

                provider.provider = ToolProviderID(provider.provider).to_string()
                return provider
            except Exception:
                # it's an old provider without organization
                return (
                    session.query(BuiltinToolProvider)
                    .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
                    .order_by(
                        BuiltinToolProvider.is_default.desc(),  # default=True first
                        BuiltinToolProvider.created_at.asc(),  # oldest first
                    )
                    .first()
                )

    @staticmethod
    def save_custom_oauth_client_params(
        tenant_id: str,
        provider: str,
        client_params: Optional[dict] = None,
        enable_oauth_custom_client: Optional[bool] = None,
    ):
        """
        Create or update the tenant's custom OAuth client for a provider.

        Does nothing when both ``client_params`` and ``enable_oauth_custom_client`` are None.
        In ``client_params``, values equal to HIDDEN_VALUE keep the stored value.

        :raises ToolProviderNotFoundError: if the provider cannot be resolved
        :raises ValueError: if the provider is not a builtin or plugin provider
        :return: {"result": "success"}
        """
        if client_params is None and enable_oauth_custom_client is None:
            return {"result": "success"}

        tool_provider = ToolProviderID(provider)
        provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
        if not provider_controller:
            raise ToolProviderNotFoundError(f"Provider {provider} not found")

        if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
            raise ValueError(f"Provider {provider} is not a builtin or plugin provider")

        with Session(db.engine) as session:
            custom_client_params = (
                session.query(ToolOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=tool_provider.provider_name,
                )
                .first()
            )

            # if the record does not exist, create a basic record
            if custom_client_params is None:
                custom_client_params = ToolOAuthTenantClient(
                    tenant_id=tenant_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=tool_provider.provider_name,
                )
                session.add(custom_client_params)

            if client_params is not None:
                encrypter, _ = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                    cache=NoOpProviderCredentialCache(),
                )
                original_params = encrypter.decrypt(custom_client_params.oauth_params)
                new_params: dict = {
                    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                    for key, value in client_params.items()
                }
                custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))

            if enable_oauth_custom_client is not None:
                custom_client_params.enabled = enable_oauth_custom_client

            session.commit()
        return {"result": "success"}

    @staticmethod
    def get_custom_oauth_client_params(tenant_id: str, provider: str):
        """
        Return the tenant's custom OAuth client params for a provider, masked.

        :return: the masked params, or {} if the tenant has no custom client record
        :raises ToolProviderNotFoundError: if the provider cannot be resolved
        :raises ValueError: if the provider is not a builtin or plugin provider
        """
        with Session(db.engine) as session:
            tool_provider = ToolProviderID(provider)
            custom_oauth_client_params: ToolOAuthTenantClient | None = (
                session.query(ToolOAuthTenantClient)
                .filter_by(
                    tenant_id=tenant_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=tool_provider.provider_name,
                )
                .first()
            )
            if custom_oauth_client_params is None:
                return {}

            provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
            if not provider_controller:
                raise ToolProviderNotFoundError(f"Provider {provider} not found")

            # same accepted types as save_custom_oauth_client_params
            if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
                raise ValueError(f"Provider {provider} is not a builtin or plugin provider")

            encrypter, _ = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
                cache=NoOpProviderCredentialCache(),
            )

            return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|