Co-authored-by: GareArc <chen4851@purude.edu>tags/0.7.2
| # Celery beat configuration | # Celery beat configuration | ||||
| CELERY_BEAT_SCHEDULER_TIME=1 | |||||
| CELERY_BEAT_SCHEDULER_TIME=1 | |||||
| # Position configuration | |||||
| POSITION_TOOL_PINS= | |||||
| POSITION_TOOL_INCLUDES= | |||||
| POSITION_TOOL_EXCLUDES= | |||||
| POSITION_PROVIDER_PINS= | |||||
| POSITION_PROVIDER_INCLUDES= | |||||
| POSITION_PROVIDER_EXCLUDES= |
| default=False, | default=False, | ||||
| ) | ) | ||||
| class WorkspaceConfig(BaseSettings): | class WorkspaceConfig(BaseSettings): | ||||
| """ | """ | ||||
| Workspace configs | Workspace configs | ||||
| ) | ) | ||||
| class PositionConfig(BaseSettings): | |||||
| POSITION_PROVIDER_PINS: str = Field( | |||||
| description='The heads of model providers', | |||||
| default='', | |||||
| ) | |||||
| POSITION_PROVIDER_INCLUDES: str = Field( | |||||
| description='The included model providers', | |||||
| default='', | |||||
| ) | |||||
| POSITION_PROVIDER_EXCLUDES: str = Field( | |||||
| description='The excluded model providers', | |||||
| default='', | |||||
| ) | |||||
| POSITION_TOOL_PINS: str = Field( | |||||
| description='The heads of tools', | |||||
| default='', | |||||
| ) | |||||
| POSITION_TOOL_INCLUDES: str = Field( | |||||
| description='The included tools', | |||||
| default='', | |||||
| ) | |||||
| POSITION_TOOL_EXCLUDES: str = Field( | |||||
| description='The excluded tools', | |||||
| default='', | |||||
| ) | |||||
| @computed_field | |||||
| def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: | |||||
| return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != ''] | |||||
| @computed_field | |||||
| def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: | |||||
| return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''} | |||||
| @computed_field | |||||
| def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: | |||||
| return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''} | |||||
| @computed_field | |||||
| def POSITION_TOOL_PINS_LIST(self) -> list[str]: | |||||
| return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != ''] | |||||
| @computed_field | |||||
| def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: | |||||
| return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''} | |||||
| @computed_field | |||||
| def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: | |||||
| return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''} | |||||
| class FeatureConfig( | class FeatureConfig( | ||||
| # place the configs in alphabet order | # place the configs in alphabet order | ||||
| AppExecutionConfig, | AppExecutionConfig, | ||||
| UpdateConfig, | UpdateConfig, | ||||
| WorkflowConfig, | WorkflowConfig, | ||||
| WorkspaceConfig, | WorkspaceConfig, | ||||
| PositionConfig, | |||||
| # hosted services config | # hosted services config | ||||
| HostedServiceConfig, | HostedServiceConfig, |
| from collections.abc import Callable | from collections.abc import Callable | ||||
| from typing import Any | from typing import Any | ||||
| from configs import dify_config | |||||
| from core.tools.utils.yaml_utils import load_yaml_file | from core.tools.utils.yaml_utils import load_yaml_file | ||||
| return {name: index for index, name in enumerate(positions)} | return {name: index for index, name in enumerate(positions)} | ||||
| def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: | |||||
| """ | |||||
| Get the mapping for tools from name to index from a YAML file. | |||||
| :param folder_path: | |||||
| :param file_name: the YAML file name, default to '_position.yaml' | |||||
| :return: a dict with name as key and index as value | |||||
| """ | |||||
| position_map = get_position_map(folder_path, file_name=file_name) | |||||
| return pin_position_map( | |||||
| position_map, | |||||
| pin_list=dify_config.POSITION_TOOL_PINS_LIST, | |||||
| ) | |||||
| def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: | |||||
| """ | |||||
| Get the mapping for providers from name to index from a YAML file. | |||||
| :param folder_path: | |||||
| :param file_name: the YAML file name, default to '_position.yaml' | |||||
| :return: a dict with name as key and index as value | |||||
| """ | |||||
| position_map = get_position_map(folder_path, file_name=file_name) | |||||
| return pin_position_map( | |||||
| position_map, | |||||
| pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, | |||||
| ) | |||||
| def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: | |||||
| """ | |||||
| Pin the items in the pin list to the beginning of the position map. | |||||
| Overall logic: exclude > include > pin | |||||
| :param position_map: the position map to be sorted and filtered | |||||
| :param pin_list: the list of pins to be put at the beginning | |||||
| :return: the sorted position map | |||||
| """ | |||||
| positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) | |||||
| # Add pins to position map | |||||
| position_map = {name: idx for idx, name in enumerate(pin_list)} | |||||
| # Add remaining positions to position map | |||||
| start_idx = len(position_map) | |||||
| for name in positions: | |||||
| if name not in position_map: | |||||
| position_map[name] = start_idx | |||||
| start_idx += 1 | |||||
| return position_map | |||||
| def is_filtered( | |||||
| include_set: set[str], | |||||
| exclude_set: set[str], | |||||
| data: Any, | |||||
| name_func: Callable[[Any], str], | |||||
| ) -> bool: | |||||
| """ | |||||
| Chcek if the object should be filtered out. | |||||
| Overall logic: exclude > include > pin | |||||
| :param include_set: the set of names to be included | |||||
| :param exclude_set: the set of names to be excluded | |||||
| :param name_func: the function to get the name of the object | |||||
| :param data: the data to be filtered | |||||
| :return: True if the object should be filtered out, False otherwise | |||||
| """ | |||||
| if not data: | |||||
| return False | |||||
| if not include_set and not exclude_set: | |||||
| return False | |||||
| name = name_func(data) | |||||
| if name in exclude_set: # exclude_set is prioritized | |||||
| return True | |||||
| if include_set and name not in include_set: # filter out only if include_set is not empty | |||||
| return True | |||||
| return False | |||||
| def sort_by_position_map( | def sort_by_position_map( | ||||
| position_map: dict[str, int], | position_map: dict[str, int], | ||||
| data: list[Any], | data: list[Any], |
| return ModelInstance(provider_model_bundle, model) | return ModelInstance(provider_model_bundle, model) | ||||
| def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: | |||||
| """ | |||||
| Return first provider and the first model in the provider | |||||
| :param tenant_id: tenant id | |||||
| :param model_type: model type | |||||
| :return: provider name, model name | |||||
| """ | |||||
| return self._provider_manager.get_first_provider_first_model(tenant_id, model_type) | |||||
| def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: | def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: | ||||
| """ | """ | ||||
| Get default model instance | Get default model instance | ||||
| config.id | config.id | ||||
| ) | ) | ||||
| res = redis_client.exists(cooldown_cache_key) | res = redis_client.exists(cooldown_cache_key) | ||||
| res = cast(bool, res) | res = cast(bool, res) | ||||
| return res | return res |
| os.path.join(provider_model_type_path, model_schema_yaml) | os.path.join(provider_model_type_path, model_schema_yaml) | ||||
| for model_schema_yaml in os.listdir(provider_model_type_path) | for model_schema_yaml in os.listdir(provider_model_type_path) | ||||
| if not model_schema_yaml.startswith('__') | if not model_schema_yaml.startswith('__') | ||||
| and not model_schema_yaml.startswith('_') | |||||
| and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) | |||||
| and model_schema_yaml.endswith('.yaml') | |||||
| and not model_schema_yaml.startswith('_') | |||||
| and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) | |||||
| and model_schema_yaml.endswith('.yaml') | |||||
| ] | ] | ||||
| # get _position.yaml file path | # get _position.yaml file path |
| from pydantic import BaseModel, ConfigDict | from pydantic import BaseModel, ConfigDict | ||||
| from core.helper.module_import_helper import load_single_subclass_from_source | from core.helper.module_import_helper import load_single_subclass_from_source | ||||
| from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map | |||||
| from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map | |||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity | from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity | ||||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||||
| ] | ] | ||||
| # get _position.yaml file path | # get _position.yaml file path | ||||
| position_map = get_position_map(model_providers_path) | |||||
| position_map = get_provider_position_map(model_providers_path) | |||||
| # traverse all model_provider_dir_paths | # traverse all model_provider_dir_paths | ||||
| model_providers: list[ModelProviderExtension] = [] | model_providers: list[ModelProviderExtension] = [] |
| from sqlalchemy.exc import IntegrityError | from sqlalchemy.exc import IntegrityError | ||||
| from configs import dify_config | |||||
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | ||||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | ||||
| from core.entities.provider_entities import ( | from core.entities.provider_entities import ( | ||||
| ) | ) | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | ||||
| from core.helper.position_helper import is_filtered | |||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.entities.provider_entities import ( | |||||
| CredentialFormSchema, | |||||
| FormType, | |||||
| ProviderEntity, | |||||
| ) | |||||
| from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity | |||||
| from core.model_runtime.model_providers import model_provider_factory | from core.model_runtime.model_providers import model_provider_factory | ||||
| from extensions import ext_hosting_provider | from extensions import ext_hosting_provider | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| """ | """ | ||||
| ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. | ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. | ||||
| """ | """ | ||||
| def __init__(self) -> None: | def __init__(self) -> None: | ||||
| self.decoding_rsa_key = None | self.decoding_rsa_key = None | ||||
| self.decoding_cipher_rsa = None | self.decoding_cipher_rsa = None | ||||
| # Construct ProviderConfiguration objects for each provider | # Construct ProviderConfiguration objects for each provider | ||||
| for provider_entity in provider_entities: | for provider_entity in provider_entities: | ||||
| # handle include, exclude | |||||
| if is_filtered( | |||||
| include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, | |||||
| exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, | |||||
| data=provider_entity, | |||||
| name_func=lambda x: x.provider, | |||||
| ): | |||||
| continue | |||||
| provider_name = provider_entity.provider | provider_name = provider_entity.provider | ||||
| provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) | provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) | ||||
| provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) | provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) | ||||
| ) | ) | ||||
| ) | ) | ||||
| def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: | |||||
| """ | |||||
| Get names of first model and its provider | |||||
| :param tenant_id: workspace id | |||||
| :param model_type: model type | |||||
| :return: provider name, model name | |||||
| """ | |||||
| provider_configurations = self.get_configurations(tenant_id) | |||||
| # get available models from provider_configurations | |||||
| all_models = provider_configurations.get_models( | |||||
| model_type=model_type, | |||||
| only_active=False | |||||
| ) | |||||
| return all_models[0].provider.provider, all_models[0].model | |||||
| def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ | def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ | ||||
| -> TenantDefaultModel: | -> TenantDefaultModel: | ||||
| """ | """ |
| import os.path | import os.path | ||||
| from core.helper.position_helper import get_position_map, sort_by_position_map | |||||
| from core.helper.position_helper import get_tool_position_map, sort_by_position_map | |||||
| from core.tools.entities.api_entities import UserToolProvider | from core.tools.entities.api_entities import UserToolProvider | ||||
| @classmethod | @classmethod | ||||
| def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: | def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: | ||||
| if not cls._position: | if not cls._position: | ||||
| cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) | |||||
| cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) | |||||
| def name_func(provider: UserToolProvider) -> str: | def name_func(provider: UserToolProvider) -> str: | ||||
| return provider.name | return provider.name | ||||
| sorted_providers = sort_by_position_map(cls._position, providers, name_func) | sorted_providers = sort_by_position_map(cls._position, providers, name_func) | ||||
| return sorted_providers | |||||
| return sorted_providers |
| from core.agent.entities import AgentToolEntity | from core.agent.entities import AgentToolEntity | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.helper.module_import_helper import load_single_subclass_from_source | from core.helper.module_import_helper import load_single_subclass_from_source | ||||
| from core.helper.position_helper import is_filtered | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral | from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral | ||||
| from core.tools.entities.common_entities import I18nObject | from core.tools.entities.common_entities import I18nObject | ||||
| from core.tools.entities.tool_entities import ( | |||||
| ApiProviderAuthType, | |||||
| ToolInvokeFrom, | |||||
| ToolParameter, | |||||
| ) | |||||
| from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter | |||||
| from core.tools.errors import ToolProviderNotFoundError | from core.tools.errors import ToolProviderNotFoundError | ||||
| from core.tools.provider.api_tool_provider import ApiToolProviderController | from core.tools.provider.api_tool_provider import ApiToolProviderController | ||||
| from core.tools.provider.builtin._positions import BuiltinToolProviderSort | from core.tools.provider.builtin._positions import BuiltinToolProviderSort | ||||
| from core.tools.tool.builtin_tool import BuiltinTool | from core.tools.tool.builtin_tool import BuiltinTool | ||||
| from core.tools.tool.tool import Tool | from core.tools.tool.tool import Tool | ||||
| from core.tools.tool_label_manager import ToolLabelManager | from core.tools.tool_label_manager import ToolLabelManager | ||||
| from core.tools.utils.configuration import ( | |||||
| ToolConfigurationManager, | |||||
| ToolParameterConfigurationManager, | |||||
| ) | |||||
| from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager | |||||
| from core.tools.utils.tool_parameter_converter import ToolParameterConverter | from core.tools.utils.tool_parameter_converter import ToolParameterConverter | ||||
| from core.workflow.nodes.tool.entities import ToolEntity | from core.workflow.nodes.tool.entities import ToolEntity | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class ToolManager: | class ToolManager: | ||||
| _builtin_provider_lock = Lock() | _builtin_provider_lock = Lock() | ||||
| _builtin_providers = {} | _builtin_providers = {} | ||||
| tenant_id: str, | tenant_id: str, | ||||
| invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, | invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, | ||||
| tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ | tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ | ||||
| -> Union[BuiltinTool, ApiTool]: | |||||
| -> Union[BuiltinTool, ApiTool]: | |||||
| """ | """ | ||||
| get the tool runtime | get the tool runtime | ||||
| provider_class = load_single_subclass_from_source( | provider_class = load_single_subclass_from_source( | ||||
| module_name=f'core.tools.provider.builtin.{provider}.{provider}', | module_name=f'core.tools.provider.builtin.{provider}.{provider}', | ||||
| script_path=path.join(path.dirname(path.realpath(__file__)), | script_path=path.join(path.dirname(path.realpath(__file__)), | ||||
| 'provider', 'builtin', provider, f'{provider}.py'), | |||||
| 'provider', 'builtin', provider, f'{provider}.py'), | |||||
| parent_type=BuiltinToolProviderController) | parent_type=BuiltinToolProviderController) | ||||
| provider: BuiltinToolProviderController = provider_class() | provider: BuiltinToolProviderController = provider_class() | ||||
| cls._builtin_providers[provider.identity.name] = provider | cls._builtin_providers[provider.identity.name] = provider | ||||
| # append builtin providers | # append builtin providers | ||||
| for provider in builtin_providers: | for provider in builtin_providers: | ||||
| # handle include, exclude | |||||
| if is_filtered( | |||||
| include_set=dify_config.POSITION_TOOL_INCLUDES_SET, | |||||
| exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, | |||||
| data=provider, | |||||
| name_func=lambda x: x.identity.name | |||||
| ): | |||||
| continue | |||||
| user_provider = ToolTransformService.builtin_provider_to_user_provider( | user_provider = ToolTransformService.builtin_provider_to_user_provider( | ||||
| provider_controller=provider, | provider_controller=provider, | ||||
| db_provider=find_db_builtin_provider(provider.identity.name), | db_provider=find_db_builtin_provider(provider.identity.name), | ||||
| @classmethod | @classmethod | ||||
| def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ | def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ | ||||
| ApiToolProviderController, dict[str, Any]]: | |||||
| ApiToolProviderController, dict[str, Any]]: | |||||
| """ | """ | ||||
| get the api provider | get the api provider | ||||
| else: | else: | ||||
| raise ValueError(f"provider type {provider_type} not found") | raise ValueError(f"provider type {provider_type} not found") | ||||
| ToolManager.load_builtin_providers_cache() | ToolManager.load_builtin_providers_cache() |
| 'completion_params': {} | 'completion_params': {} | ||||
| } | } | ||||
| else: | else: | ||||
| provider, model = model_manager.get_default_provider_model_name( | |||||
| tenant_id=account.current_tenant_id, | |||||
| model_type=ModelType.LLM | |||||
| ) | |||||
| default_model_config['model']['provider'] = provider | |||||
| default_model_config['model']['name'] = model | |||||
| default_model_dict = default_model_config['model'] | default_model_dict = default_model_config['model'] | ||||
| default_model_config['model'] = json.dumps(default_model_dict) | default_model_config['model'] = json.dumps(default_model_dict) | ||||
| """ | """ | ||||
| Modified App class | Modified App class | ||||
| """ | """ | ||||
| def __init__(self, app): | def __init__(self, app): | ||||
| self.__dict__.update(app.__dict__) | self.__dict__.update(app.__dict__) | ||||
| @property | @property | ||||
| def app_model_config(self): | def app_model_config(self): | ||||
| return model_config | return model_config | ||||
| app = ModifiedApp(app) | app = ModifiedApp(app) | ||||
| return app | return app |
| """ | """ | ||||
| Model Provider Service | Model Provider Service | ||||
| """ | """ | ||||
| def __init__(self) -> None: | def __init__(self) -> None: | ||||
| self.provider_manager = ProviderManager() | self.provider_manager = ProviderManager() | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model_type=model_type_enum | model_type=model_type_enum | ||||
| ) | ) | ||||
| return DefaultModelResponse( | |||||
| model=result.model, | |||||
| model_type=result.model_type, | |||||
| provider=SimpleProviderEntityResponse( | |||||
| provider=result.provider.provider, | |||||
| label=result.provider.label, | |||||
| icon_small=result.provider.icon_small, | |||||
| icon_large=result.provider.icon_large, | |||||
| supported_model_types=result.provider.supported_model_types | |||||
| ) | |||||
| ) if result else None | |||||
| try: | |||||
| return DefaultModelResponse( | |||||
| model=result.model, | |||||
| model_type=result.model_type, | |||||
| provider=SimpleProviderEntityResponse( | |||||
| provider=result.provider.provider, | |||||
| label=result.provider.label, | |||||
| icon_small=result.provider.icon_small, | |||||
| icon_large=result.provider.icon_large, | |||||
| supported_model_types=result.provider.supported_model_types | |||||
| ) | |||||
| ) if result else None | |||||
| except Exception as e: | |||||
| logger.info(f"get_default_model_of_model_type error: {e}") | |||||
| return None | |||||
| def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: | def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: | ||||
| """ | """ |
| import json | import json | ||||
| import logging | import logging | ||||
| from configs import dify_config | |||||
| from core.helper.position_helper import is_filtered | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.tools.entities.api_entities import UserTool, UserToolProvider | from core.tools.entities.api_entities import UserTool, UserToolProvider | ||||
| from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError | from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError | ||||
| result = [] | result = [] | ||||
| for tool in tools: | for tool in tools: | ||||
| result.append(ToolTransformService.tool_to_user_tool( | result.append(ToolTransformService.tool_to_user_tool( | ||||
| tool=tool, | |||||
| credentials=credentials, | |||||
| tool=tool, | |||||
| credentials=credentials, | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| labels=ToolLabelManager.get_tool_labels(provider_controller) | labels=ToolLabelManager.get_tool_labels(provider_controller) | ||||
| )) | )) | ||||
| return result | return result | ||||
| @staticmethod | @staticmethod | ||||
| def list_builtin_provider_credentials_schema( | def list_builtin_provider_credentials_schema( | ||||
| provider_name | provider_name | ||||
| BuiltinToolProvider.provider == provider_name, | BuiltinToolProvider.provider == provider_name, | ||||
| ).first() | ).first() | ||||
| try: | |||||
| try: | |||||
| # get provider | # get provider | ||||
| provider_controller = ToolManager.get_builtin_provider(provider_name) | provider_controller = ToolManager.get_builtin_provider(provider_name) | ||||
| if not provider_controller.need_credentials: | if not provider_controller.need_credentials: | ||||
| # delete cache | # delete cache | ||||
| tool_configuration.delete_tool_credentials_cache() | tool_configuration.delete_tool_credentials_cache() | ||||
| return { 'result': 'success' } | |||||
| return {'result': 'success'} | |||||
| @staticmethod | @staticmethod | ||||
| def get_builtin_tool_provider_credentials( | def get_builtin_tool_provider_credentials( | ||||
| user_id: str, tenant_id: str, provider: str | user_id: str, tenant_id: str, provider: str | ||||
| if provider is None: | if provider is None: | ||||
| return {} | return {} | ||||
| provider_controller = ToolManager.get_builtin_provider(provider.provider) | provider_controller = ToolManager.get_builtin_provider(provider.provider) | ||||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | ||||
| credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | ||||
| if provider is None: | if provider is None: | ||||
| raise ValueError(f'you have not added provider {provider_name}') | raise ValueError(f'you have not added provider {provider_name}') | ||||
| db.session.delete(provider) | db.session.delete(provider) | ||||
| db.session.commit() | db.session.commit() | ||||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | ||||
| tool_configuration.delete_tool_credentials_cache() | tool_configuration.delete_tool_credentials_cache() | ||||
| return { 'result': 'success' } | |||||
| return {'result': 'success'} | |||||
| @staticmethod | @staticmethod | ||||
| def get_builtin_tool_provider_icon( | def get_builtin_tool_provider_icon( | ||||
| provider: str | provider: str | ||||
| icon_bytes = f.read() | icon_bytes = f.read() | ||||
| return icon_bytes, mime_type | return icon_bytes, mime_type | ||||
| @staticmethod | @staticmethod | ||||
| def list_builtin_tools( | def list_builtin_tools( | ||||
| user_id: str, tenant_id: str | user_id: str, tenant_id: str | ||||
| for provider_controller in provider_controllers: | for provider_controller in provider_controllers: | ||||
| try: | try: | ||||
| # handle include, exclude | |||||
| if is_filtered( | |||||
| include_set=dify_config.POSITION_TOOL_INCLUDES_SET, | |||||
| exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, | |||||
| data=provider_controller, | |||||
| name_func=lambda x: x.identity.name | |||||
| ): | |||||
| continue | |||||
| # convert provider controller to user provider | # convert provider controller to user provider | ||||
| user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( | user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( | ||||
| provider_controller=provider_controller, | provider_controller=provider_controller, | ||||
| raise e | raise e | ||||
| return BuiltinToolProviderSort.sort(result) | return BuiltinToolProviderSort.sort(result) | ||||
| import pytest | import pytest | ||||
| from core.helper.position_helper import get_position_map | |||||
| from core.helper.position_helper import get_position_map, is_filtered, pin_position_map, sort_by_position_map | |||||
| @pytest.fixture | @pytest.fixture | ||||
| - second | - second | ||||
| # - commented | # - commented | ||||
| - third | - third | ||||
| - 9999999999999 | - 9999999999999 | ||||
| - forth | - forth | ||||
| """)) | """)) | ||||
| """\ | """\ | ||||
| # - commented1 | # - commented1 | ||||
| # - commented2 | # - commented2 | ||||
| - | |||||
| - | |||||
| - | |||||
| - | |||||
| """)) | """)) | ||||
| return str(tmp_path) | return str(tmp_path) | ||||
| folder_path=prepare_empty_commented_positions_yaml, | folder_path=prepare_empty_commented_positions_yaml, | ||||
| file_name='example_positions_all_commented.yaml') | file_name='example_positions_all_commented.yaml') | ||||
| assert position_map == {} | assert position_map == {} | ||||
| def test_excluded_position_data(prepare_example_positions_yaml): | |||||
| position_map = get_position_map( | |||||
| folder_path=prepare_example_positions_yaml, | |||||
| file_name='example_positions.yaml' | |||||
| ) | |||||
| pin_list = ['forth', 'first'] | |||||
| include_set = set() | |||||
| exclude_set = {'9999999999999'} | |||||
| position_map = pin_position_map( | |||||
| original_position_map=position_map, | |||||
| pin_list=pin_list | |||||
| ) | |||||
| data = [ | |||||
| "forth", | |||||
| "first", | |||||
| "second", | |||||
| "third", | |||||
| "9999999999999", | |||||
| "extra1", | |||||
| "extra2", | |||||
| ] | |||||
| # filter out the data | |||||
| data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] | |||||
| # sort data by position map | |||||
| sorted_data = sort_by_position_map( | |||||
| position_map=position_map, | |||||
| data=data, | |||||
| name_func=lambda x: x, | |||||
| ) | |||||
| # assert the result in the correct order | |||||
| assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2'] | |||||
| def test_included_position_data(prepare_example_positions_yaml): | |||||
| position_map = get_position_map( | |||||
| folder_path=prepare_example_positions_yaml, | |||||
| file_name='example_positions.yaml' | |||||
| ) | |||||
| pin_list = ['forth', 'first'] | |||||
| include_set = {'forth', 'first'} | |||||
| exclude_set = {} | |||||
| position_map = pin_position_map( | |||||
| original_position_map=position_map, | |||||
| pin_list=pin_list | |||||
| ) | |||||
| data = [ | |||||
| "forth", | |||||
| "first", | |||||
| "second", | |||||
| "third", | |||||
| "9999999999999", | |||||
| "extra1", | |||||
| "extra2", | |||||
| ] | |||||
| # filter out the data | |||||
| data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] | |||||
| # sort data by position map | |||||
| sorted_data = sort_by_position_map( | |||||
| position_map=position_map, | |||||
| data=data, | |||||
| name_func=lambda x: x, | |||||
| ) | |||||
| # assert the result in the correct order | |||||
| assert sorted_data == ['forth', 'first'] |
| # ------------------------------ | # ------------------------------ | ||||
| EXPOSE_NGINX_PORT=80 | EXPOSE_NGINX_PORT=80 | ||||
| EXPOSE_NGINX_SSL_PORT=443 | EXPOSE_NGINX_SSL_PORT=443 | ||||
| # ---------------------------------------------------------------------------- | |||||
| # ModelProvider & Tool Position Configuration | |||||
| # Used to specify the model providers and tools that can be used in the app. | |||||
| # ---------------------------------------------------------------------------- | |||||
| # Pin, include, and exclude tools | |||||
| # Use comma-separated values with no spaces between items. | |||||
| # Example: POSITION_TOOL_PINS=bing,google | |||||
| POSITION_TOOL_PINS= | |||||
| POSITION_TOOL_INCLUDES= | |||||
| POSITION_TOOL_EXCLUDES= | |||||
| # Pin, include, and exclude model providers | |||||
| # Use comma-separated values with no spaces between items. | |||||
| # Example: POSITION_PROVIDER_PINS=openai,openllm | |||||
| POSITION_PROVIDER_PINS= | |||||
| POSITION_PROVIDER_INCLUDES= | |||||
| POSITION_PROVIDER_EXCLUDES= |