| @@ -27,7 +27,9 @@ from fields.app_fields import ( | |||
| from libs.login import login_required | |||
| from models.model import App, AppModelConfig, Site | |||
| from services.app_model_config_service import AppModelConfigService | |||
| from core.tools.utils.configuration import ToolParameterConfigurationManager | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.entities.application_entities import AgentToolEntity | |||
| def _get_app(app_id, tenant_id): | |||
| app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() | |||
| @@ -236,7 +238,39 @@ class AppApi(Resource): | |||
| def get(self, app_id): | |||
| """Get app detail""" | |||
| app_id = str(app_id) | |||
| app = _get_app(app_id, current_user.current_tenant_id) | |||
| app: App = _get_app(app_id, current_user.current_tenant_id) | |||
| # get original app model config | |||
| model_config: AppModelConfig = app.app_model_config | |||
| agent_mode = model_config.agent_mode_dict | |||
| # decrypt agent tool parameters if it's secret-input | |||
| for tool in agent_mode.get('tools') or []: | |||
| agent_tool_entity = AgentToolEntity(**tool) | |||
| # get tool | |||
| tool_runtime = ToolManager.get_agent_tool_runtime( | |||
| tenant_id=current_user.current_tenant_id, | |||
| agent_tool=agent_tool_entity, | |||
| agent_callback=None | |||
| ) | |||
| manager = ToolParameterConfigurationManager( | |||
| tenant_id=current_user.current_tenant_id, | |||
| tool_runtime=tool_runtime, | |||
| provider_name=agent_tool_entity.provider_id, | |||
| provider_type=agent_tool_entity.provider_type, | |||
| ) | |||
| # get decrypted parameters | |||
| if agent_tool_entity.tool_parameters: | |||
| parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) | |||
| masked_parameter = manager.mask_tool_parameters(parameters or {}) | |||
| else: | |||
| masked_parameter = {} | |||
| # override tool parameters | |||
| tool['tool_parameters'] = masked_parameter | |||
| # override agent mode | |||
| model_config.agent_mode = json.dumps(agent_mode) | |||
| return app | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| from flask import request | |||
| from flask_login import current_user | |||
| @@ -7,6 +8,9 @@ from controllers.console import api | |||
| from controllers.console.app import _get_app | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.entities.application_entities import AgentToolEntity | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.tools.utils.configuration import ToolParameterConfigurationManager | |||
| from events.app_event import app_model_config_was_updated | |||
| from extensions.ext_database import db | |||
| from libs.login import login_required | |||
| @@ -38,6 +42,82 @@ class ModelConfigResource(Resource): | |||
| ) | |||
| new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) | |||
| # get original app model config | |||
| original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( | |||
| AppModelConfig.id == app.app_model_config_id | |||
| ).first() | |||
| agent_mode = original_app_model_config.agent_mode_dict | |||
| # decrypt agent tool parameters if it's secret-input | |||
| parameter_map = {} | |||
| masked_parameter_map = {} | |||
| tool_map = {} | |||
| for tool in agent_mode.get('tools') or []: | |||
| agent_tool_entity = AgentToolEntity(**tool) | |||
| # get tool | |||
| tool_runtime = ToolManager.get_agent_tool_runtime( | |||
| tenant_id=current_user.current_tenant_id, | |||
| agent_tool=agent_tool_entity, | |||
| agent_callback=None | |||
| ) | |||
| manager = ToolParameterConfigurationManager( | |||
| tenant_id=current_user.current_tenant_id, | |||
| tool_runtime=tool_runtime, | |||
| provider_name=agent_tool_entity.provider_id, | |||
| provider_type=agent_tool_entity.provider_type, | |||
| ) | |||
| # get decrypted parameters | |||
| if agent_tool_entity.tool_parameters: | |||
| parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) | |||
| masked_parameter = manager.mask_tool_parameters(parameters or {}) | |||
| else: | |||
| parameters = {} | |||
| masked_parameter = {} | |||
| key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' | |||
| masked_parameter_map[key] = masked_parameter | |||
| parameter_map[key] = parameters | |||
| tool_map[key] = tool_runtime | |||
| # encrypt agent tool parameters if it's secret-input | |||
| agent_mode = new_app_model_config.agent_mode_dict | |||
| for tool in agent_mode.get('tools') or []: | |||
| agent_tool_entity = AgentToolEntity(**tool) | |||
| # get tool | |||
| key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' | |||
| if key in tool_map: | |||
| tool_runtime = tool_map[key] | |||
| else: | |||
| tool_runtime = ToolManager.get_agent_tool_runtime( | |||
| tenant_id=current_user.current_tenant_id, | |||
| agent_tool=agent_tool_entity, | |||
| agent_callback=None | |||
| ) | |||
| manager = ToolParameterConfigurationManager( | |||
| tenant_id=current_user.current_tenant_id, | |||
| tool_runtime=tool_runtime, | |||
| provider_name=agent_tool_entity.provider_id, | |||
| provider_type=agent_tool_entity.provider_type, | |||
| ) | |||
| manager.delete_tool_parameters_cache() | |||
| # override parameters if it equals to masked parameters | |||
| if agent_tool_entity.tool_parameters: | |||
| if key not in masked_parameter_map: | |||
| continue | |||
| if agent_tool_entity.tool_parameters == masked_parameter_map[key]: | |||
| agent_tool_entity.tool_parameters = parameter_map[key] | |||
| # encrypt parameters | |||
| if agent_tool_entity.tool_parameters: | |||
| tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) | |||
| # update app model config | |||
| new_app_model_config.agent_mode = json.dumps(agent_mode) | |||
| db.session.add(new_app_model_config) | |||
| db.session.flush() | |||
| @@ -154,9 +154,9 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| """ | |||
| convert tool to prompt message tool | |||
| """ | |||
| tool_entity = ToolManager.get_tool_runtime( | |||
| provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, | |||
| tenant_id=self.application_generate_entity.tenant_id, | |||
| tool_entity = ToolManager.get_agent_tool_runtime( | |||
| tenant_id=self.tenant_id, | |||
| agent_tool=tool, | |||
| agent_callback=self.agent_callback | |||
| ) | |||
| tool_entity.load_variables(self.variables_pool) | |||
| @@ -171,33 +171,11 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| } | |||
| ) | |||
| runtime_parameters = {} | |||
| parameters = tool_entity.parameters or [] | |||
| user_parameters = tool_entity.get_runtime_parameters() or [] | |||
| # override parameters | |||
| for parameter in user_parameters: | |||
| # check if parameter in tool parameters | |||
| found = False | |||
| for tool_parameter in parameters: | |||
| if tool_parameter.name == parameter.name: | |||
| found = True | |||
| break | |||
| if found: | |||
| # override parameter | |||
| tool_parameter.type = parameter.type | |||
| tool_parameter.form = parameter.form | |||
| tool_parameter.required = parameter.required | |||
| tool_parameter.default = parameter.default | |||
| tool_parameter.options = parameter.options | |||
| tool_parameter.llm_description = parameter.llm_description | |||
| else: | |||
| # add new parameter | |||
| parameters.append(parameter) | |||
| parameters = tool_entity.get_all_runtime_parameters() | |||
| for parameter in parameters: | |||
| if parameter.form != ToolParameter.ToolParameterForm.LLM: | |||
| continue | |||
| parameter_type = 'string' | |||
| enum = [] | |||
| if parameter.type == ToolParameter.ToolParameterType.STRING: | |||
| @@ -213,59 +191,16 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| else: | |||
| raise ValueError(f"parameter type {parameter.type} is not supported") | |||
| if parameter.form == ToolParameter.ToolParameterForm.FORM: | |||
| # get tool parameter from form | |||
| tool_parameter_config = tool.tool_parameters.get(parameter.name) | |||
| if not tool_parameter_config: | |||
| # get default value | |||
| tool_parameter_config = parameter.default | |||
| if not tool_parameter_config and parameter.required: | |||
| raise ValueError(f"tool parameter {parameter.name} not found in tool config") | |||
| if parameter.type == ToolParameter.ToolParameterType.SELECT: | |||
| # check if tool_parameter_config in options | |||
| options = list(map(lambda x: x.value, parameter.options)) | |||
| if tool_parameter_config not in options: | |||
| raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") | |||
| # convert tool parameter config to correct type | |||
| try: | |||
| if parameter.type == ToolParameter.ToolParameterType.NUMBER: | |||
| # check if tool parameter is integer | |||
| if isinstance(tool_parameter_config, int): | |||
| tool_parameter_config = tool_parameter_config | |||
| elif isinstance(tool_parameter_config, float): | |||
| tool_parameter_config = tool_parameter_config | |||
| elif isinstance(tool_parameter_config, str): | |||
| if '.' in tool_parameter_config: | |||
| tool_parameter_config = float(tool_parameter_config) | |||
| else: | |||
| tool_parameter_config = int(tool_parameter_config) | |||
| elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: | |||
| tool_parameter_config = bool(tool_parameter_config) | |||
| elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: | |||
| tool_parameter_config = str(tool_parameter_config) | |||
| elif parameter.type == ToolParameter.ToolParameterType: | |||
| tool_parameter_config = str(tool_parameter_config) | |||
| except Exception as e: | |||
| raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") | |||
| # save tool parameter to tool entity memory | |||
| runtime_parameters[parameter.name] = tool_parameter_config | |||
| elif parameter.form == ToolParameter.ToolParameterForm.LLM: | |||
| message_tool.parameters['properties'][parameter.name] = { | |||
| "type": parameter_type, | |||
| "description": parameter.llm_description or '', | |||
| } | |||
| if len(enum) > 0: | |||
| message_tool.parameters['properties'][parameter.name]['enum'] = enum | |||
| message_tool.parameters['properties'][parameter.name] = { | |||
| "type": parameter_type, | |||
| "description": parameter.llm_description or '', | |||
| } | |||
| if parameter.required: | |||
| message_tool.parameters['required'].append(parameter.name) | |||
| if len(enum) > 0: | |||
| message_tool.parameters['properties'][parameter.name]['enum'] = enum | |||
| tool_entity.runtime.runtime_parameters.update(runtime_parameters) | |||
| if parameter.required: | |||
| message_tool.parameters['required'].append(parameter.name) | |||
| return message_tool, tool_entity | |||
| @@ -305,6 +240,9 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| tool_runtime_parameters = tool.get_runtime_parameters() or [] | |||
| for parameter in tool_runtime_parameters: | |||
| if parameter.form != ToolParameter.ToolParameterForm.LLM: | |||
| continue | |||
| parameter_type = 'string' | |||
| enum = [] | |||
| if parameter.type == ToolParameter.ToolParameterType.STRING: | |||
| @@ -320,18 +258,17 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| else: | |||
| raise ValueError(f"parameter type {parameter.type} is not supported") | |||
| if parameter.form == ToolParameter.ToolParameterForm.LLM: | |||
| prompt_tool.parameters['properties'][parameter.name] = { | |||
| "type": parameter_type, | |||
| "description": parameter.llm_description or '', | |||
| } | |||
| if len(enum) > 0: | |||
| prompt_tool.parameters['properties'][parameter.name]['enum'] = enum | |||
| if parameter.required: | |||
| if parameter.name not in prompt_tool.parameters['required']: | |||
| prompt_tool.parameters['required'].append(parameter.name) | |||
| prompt_tool.parameters['properties'][parameter.name] = { | |||
| "type": parameter_type, | |||
| "description": parameter.llm_description or '', | |||
| } | |||
| if len(enum) > 0: | |||
| prompt_tool.parameters['properties'][parameter.name]['enum'] = enum | |||
| if parameter.required: | |||
| if parameter.name not in prompt_tool.parameters['required']: | |||
| prompt_tool.parameters['required'].append(parameter.name) | |||
| return prompt_tool | |||
| @@ -0,0 +1,54 @@ | |||
| import json | |||
| from enum import Enum | |||
| from json import JSONDecodeError | |||
| from typing import Optional | |||
| from extensions.ext_redis import redis_client | |||
| class ToolParameterCacheType(Enum): | |||
| PARAMETER = "tool_parameter" | |||
| class ToolParameterCache: | |||
| def __init__(self, | |||
| tenant_id: str, | |||
| provider: str, | |||
| tool_name: str, | |||
| cache_type: ToolParameterCacheType | |||
| ): | |||
| self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" | |||
| def get(self) -> Optional[dict]: | |||
| """ | |||
| Get cached model provider credentials. | |||
| :return: | |||
| """ | |||
| cached_tool_parameter = redis_client.get(self.cache_key) | |||
| if cached_tool_parameter: | |||
| try: | |||
| cached_tool_parameter = cached_tool_parameter.decode('utf-8') | |||
| cached_tool_parameter = json.loads(cached_tool_parameter) | |||
| except JSONDecodeError: | |||
| return None | |||
| return cached_tool_parameter | |||
| else: | |||
| return None | |||
| def set(self, parameters: dict) -> None: | |||
| """ | |||
| Cache model provider credentials. | |||
| :param credentials: provider credentials | |||
| :return: | |||
| """ | |||
| redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) | |||
| def delete(self) -> None: | |||
| """ | |||
| Delete cached model provider credentials. | |||
| :return: | |||
| """ | |||
| redis_client.delete(self.cache_key) | |||
| @@ -119,7 +119,7 @@ parameters: # Parameter list | |||
| - The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc. | |||
| - `parameters` Parameter list | |||
| - `name` Parameter name, unique, no duplication with other parameters | |||
| - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box | |||
| - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type | |||
| - `required` Required or not | |||
| - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter | |||
| - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts | |||
| @@ -119,7 +119,7 @@ parameters: # 参数列表 | |||
| - `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等 | |||
| - `parameters` 参数列表 | |||
| - `name` 参数名称,唯一,不允许和其他参数重名 | |||
| - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框 | |||
| - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型 | |||
| - `required` 是否必填 | |||
| - 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数 | |||
| - 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数 | |||
| @@ -100,6 +100,7 @@ class ToolParameter(BaseModel): | |||
| NUMBER = "number" | |||
| BOOLEAN = "boolean" | |||
| SELECT = "select" | |||
| SECRET_INPUT = "secret-input" | |||
| class ToolParameterForm(Enum): | |||
| SCHEMA = "schema" # should be set while adding tool | |||
| @@ -23,6 +23,8 @@ class AIPPTGenerateTool(BuiltinTool): | |||
| _api_base_url = URL('https://co.aippt.cn/api') | |||
| _api_token_cache = {} | |||
| _api_token_cache_lock = Lock() | |||
| _style_cache = {} | |||
| _style_cache_lock = Lock() | |||
| _task = {} | |||
| _task_type_map = { | |||
| @@ -390,20 +392,31 @@ class AIPPTGenerateTool(BuiltinTool): | |||
| ).digest() | |||
| ).decode('utf-8') | |||
| def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: | |||
| @classmethod | |||
| def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: | |||
| """ | |||
| Get styles | |||
| :param credentials: the credentials | |||
| :return: Tuple[list[dict[id, color]], list[dict[id, style]] | |||
| """ | |||
| # check cache | |||
| with cls._style_cache_lock: | |||
| # clear expired styles | |||
| now = time() | |||
| for key in list(cls._style_cache.keys()): | |||
| if cls._style_cache[key]['expire'] < now: | |||
| del cls._style_cache[key] | |||
| key = f'{credentials["aippt_access_key"]}#@#{user_id}' | |||
| if key in cls._style_cache: | |||
| return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] | |||
| headers = { | |||
| 'x-channel': '', | |||
| 'x-api-key': self.runtime.credentials['aippt_access_key'], | |||
| 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id) | |||
| 'x-api-key': credentials['aippt_access_key'], | |||
| 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) | |||
| } | |||
| response = get( | |||
| str(self._api_base_url / 'template_component' / 'suit' / 'select'), | |||
| str(cls._api_base_url / 'template_component' / 'suit' / 'select'), | |||
| headers=headers | |||
| ) | |||
| @@ -425,7 +438,26 @@ class AIPPTGenerateTool(BuiltinTool): | |||
| 'name': item.get('title'), | |||
| } for item in response.get('data', {}).get('suit_style') or []] | |||
| with cls._style_cache_lock: | |||
| cls._style_cache[key] = { | |||
| 'colors': colors, | |||
| 'styles': styles, | |||
| 'expire': now + 60 * 60 | |||
| } | |||
| return colors, styles | |||
| def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: | |||
| """ | |||
| Get styles | |||
| :param credentials: the credentials | |||
| :return: Tuple[list[dict[id, color]], list[dict[id, style]] | |||
| """ | |||
| if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): | |||
| return [], [] | |||
| return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) | |||
| def _get_suit(self, style_id: int, colour_id: int) -> int: | |||
| """ | |||
| @@ -14,7 +14,7 @@ description: | |||
| llm: A tool for sending messages to a chat group on Wecom(企业微信) . | |||
| parameters: | |||
| - name: hook_key | |||
| type: string | |||
| type: secret-input | |||
| required: true | |||
| label: | |||
| en_US: Wecom Group bot webhook key | |||
| @@ -266,6 +266,40 @@ class Tool(BaseModel, ABC): | |||
| """ | |||
| return self.parameters | |||
| def get_all_runtime_parameters(self) -> list[ToolParameter]: | |||
| """ | |||
| get all runtime parameters | |||
| :return: all runtime parameters | |||
| """ | |||
| parameters = self.parameters or [] | |||
| parameters = parameters.copy() | |||
| user_parameters = self.get_runtime_parameters() or [] | |||
| user_parameters = user_parameters.copy() | |||
| # override parameters | |||
| for parameter in user_parameters: | |||
| # check if parameter in tool parameters | |||
| found = False | |||
| for tool_parameter in parameters: | |||
| if tool_parameter.name == parameter.name: | |||
| found = True | |||
| break | |||
| if found: | |||
| # override parameter | |||
| tool_parameter.type = parameter.type | |||
| tool_parameter.form = parameter.form | |||
| tool_parameter.required = parameter.required | |||
| tool_parameter.default = parameter.default | |||
| tool_parameter.options = parameter.options | |||
| tool_parameter.llm_description = parameter.llm_description | |||
| else: | |||
| # add new parameter | |||
| parameters.append(parameter) | |||
| return parameters | |||
| def is_tool_available(self) -> bool: | |||
| """ | |||
| check if the tool is available | |||
| @@ -6,11 +6,17 @@ from os import listdir, path | |||
| from typing import Any, Union | |||
| from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler | |||
| from core.entities.application_entities import AgentToolEntity | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| from core.provider_manager import ProviderManager | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.constant import DEFAULT_PROVIDERS | |||
| from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials | |||
| from core.tools.entities.tool_entities import ( | |||
| ApiProviderAuthType, | |||
| ToolInvokeMessage, | |||
| ToolParameter, | |||
| ToolProviderCredentials, | |||
| ) | |||
| from core.tools.entities.user_entities import UserToolProvider | |||
| from core.tools.errors import ToolProviderNotFoundError | |||
| from core.tools.provider.api_tool_provider import ApiBasedToolProviderController | |||
| @@ -21,7 +27,12 @@ from core.tools.provider.model_tool_provider import ModelToolProviderController | |||
| from core.tools.provider.tool_provider import ToolProviderController | |||
| from core.tools.tool.api_tool import ApiTool | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.utils.configuration import ( | |||
| ModelToolConfigurationManager, | |||
| ToolConfigurationManager, | |||
| ToolParameterConfigurationManager, | |||
| ) | |||
| from core.tools.utils.encoder import serialize_base_model_dict | |||
| from extensions.ext_database import db | |||
| from models.tools import ApiToolProvider, BuiltinToolProvider | |||
| @@ -172,7 +183,7 @@ class ToolManager: | |||
| # decrypt the credentials | |||
| credentials = builtin_provider.credentials | |||
| controller = ToolManager.get_builtin_provider(provider_name) | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) | |||
| @@ -189,7 +200,7 @@ class ToolManager: | |||
| api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) | |||
| # decrypt the credentials | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) | |||
| return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ | |||
| @@ -214,6 +225,71 @@ class ToolManager: | |||
| else: | |||
| raise ToolProviderNotFoundError(f'provider type {provider_type} not found') | |||
| @staticmethod | |||
| def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: | |||
| """ | |||
| get the agent tool runtime | |||
| """ | |||
| tool_entity = ToolManager.get_tool_runtime( | |||
| provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, | |||
| tenant_id=tenant_id, | |||
| agent_callback=agent_callback | |||
| ) | |||
| runtime_parameters = {} | |||
| parameters = tool_entity.get_all_runtime_parameters() | |||
| for parameter in parameters: | |||
| if parameter.form == ToolParameter.ToolParameterForm.FORM: | |||
| # get tool parameter from form | |||
| tool_parameter_config = agent_tool.tool_parameters.get(parameter.name) | |||
| if not tool_parameter_config: | |||
| # get default value | |||
| tool_parameter_config = parameter.default | |||
| if not tool_parameter_config and parameter.required: | |||
| raise ValueError(f"tool parameter {parameter.name} not found in tool config") | |||
| if parameter.type == ToolParameter.ToolParameterType.SELECT: | |||
| # check if tool_parameter_config in options | |||
| options = list(map(lambda x: x.value, parameter.options)) | |||
| if tool_parameter_config not in options: | |||
| raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") | |||
| # convert tool parameter config to correct type | |||
| try: | |||
| if parameter.type == ToolParameter.ToolParameterType.NUMBER: | |||
| # check if tool parameter is integer | |||
| if isinstance(tool_parameter_config, int): | |||
| tool_parameter_config = tool_parameter_config | |||
| elif isinstance(tool_parameter_config, float): | |||
| tool_parameter_config = tool_parameter_config | |||
| elif isinstance(tool_parameter_config, str): | |||
| if '.' in tool_parameter_config: | |||
| tool_parameter_config = float(tool_parameter_config) | |||
| else: | |||
| tool_parameter_config = int(tool_parameter_config) | |||
| elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: | |||
| tool_parameter_config = bool(tool_parameter_config) | |||
| elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: | |||
| tool_parameter_config = str(tool_parameter_config) | |||
| elif parameter.type == ToolParameter.ToolParameterType: | |||
| tool_parameter_config = str(tool_parameter_config) | |||
| except Exception as e: | |||
| raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") | |||
| # save tool parameter to tool entity memory | |||
| runtime_parameters[parameter.name] = tool_parameter_config | |||
| # decrypt runtime parameters | |||
| encryption_manager = ToolParameterConfigurationManager( | |||
| tenant_id=tenant_id, | |||
| tool_runtime=tool_entity, | |||
| provider_name=agent_tool.provider_id, | |||
| provider_type=agent_tool.provider_type, | |||
| ) | |||
| runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) | |||
| tool_entity.runtime.runtime_parameters.update(runtime_parameters) | |||
| return tool_entity | |||
| @staticmethod | |||
| def get_builtin_provider_icon(provider: str) -> tuple[str, str]: | |||
| """ | |||
| @@ -396,7 +472,7 @@ class ToolManager: | |||
| controller = ToolManager.get_builtin_provider(provider_name) | |||
| # init tool configuration | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) | |||
| # decrypt the credentials and mask the credentials | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) | |||
| masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) | |||
| @@ -463,7 +539,7 @@ class ToolManager: | |||
| ) | |||
| # init tool configuration | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) | |||
| # decrypt the credentials and mask the credentials | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) | |||
| @@ -523,7 +599,7 @@ class ToolManager: | |||
| provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE | |||
| ) | |||
| # init tool configuration | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) | |||
| masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) | |||
| @@ -5,16 +5,19 @@ from pydantic import BaseModel | |||
| from yaml import FullLoader, load | |||
| from core.helper import encrypter | |||
| from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType | |||
| from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType | |||
| from core.tools.entities.tool_entities import ( | |||
| ModelToolConfiguration, | |||
| ModelToolProviderConfiguration, | |||
| ToolParameter, | |||
| ToolProviderCredentials, | |||
| ) | |||
| from core.tools.provider.tool_provider import ToolProviderController | |||
| from core.tools.tool.tool import Tool | |||
| class ToolConfiguration(BaseModel): | |||
| class ToolConfigurationManager(BaseModel): | |||
| tenant_id: str | |||
| provider_controller: ToolProviderController | |||
| @@ -101,6 +104,128 @@ class ToolConfiguration(BaseModel): | |||
| ) | |||
| cache.delete() | |||
| class ToolParameterConfigurationManager(BaseModel): | |||
| """ | |||
| Tool parameter configuration manager | |||
| """ | |||
| tenant_id: str | |||
| tool_runtime: Tool | |||
| provider_name: str | |||
| provider_type: str | |||
| def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| deep copy parameters | |||
| """ | |||
| return {key: value for key, value in parameters.items()} | |||
| def _merge_parameters(self) -> list[ToolParameter]: | |||
| """ | |||
| merge parameters | |||
| """ | |||
| # get tool parameters | |||
| tool_parameters = self.tool_runtime.parameters or [] | |||
| # get tool runtime parameters | |||
| runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] | |||
| # override parameters | |||
| current_parameters = tool_parameters.copy() | |||
| for runtime_parameter in runtime_parameters: | |||
| found = False | |||
| for index, parameter in enumerate(current_parameters): | |||
| if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: | |||
| current_parameters[index] = runtime_parameter | |||
| found = True | |||
| break | |||
| if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: | |||
| current_parameters.append(runtime_parameter) | |||
| return current_parameters | |||
| def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| mask tool parameters | |||
| return a deep copy of parameters with masked values | |||
| """ | |||
| parameters = self._deep_copy(parameters) | |||
| # override parameters | |||
| current_parameters = self._merge_parameters() | |||
| for parameter in current_parameters: | |||
| if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: | |||
| if parameter.name in parameters: | |||
| if len(parameters[parameter.name]) > 6: | |||
| parameters[parameter.name] = \ | |||
| parameters[parameter.name][:2] + \ | |||
| '*' * (len(parameters[parameter.name]) - 4) +\ | |||
| parameters[parameter.name][-2:] | |||
| else: | |||
| parameters[parameter.name] = '*' * len(parameters[parameter.name]) | |||
| return parameters | |||
| def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| encrypt tool parameters with tenant id | |||
| return a deep copy of parameters with encrypted values | |||
| """ | |||
| # override parameters | |||
| current_parameters = self._merge_parameters() | |||
| for parameter in current_parameters: | |||
| if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: | |||
| if parameter.name in parameters: | |||
| encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) | |||
| parameters[parameter.name] = encrypted | |||
| return parameters | |||
| def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| decrypt tool parameters with tenant id | |||
| return a deep copy of parameters with decrypted values | |||
| """ | |||
| cache = ToolParameterCache( | |||
| tenant_id=self.tenant_id, | |||
| provider=f'{self.provider_type}.{self.provider_name}', | |||
| tool_name=self.tool_runtime.identity.name, | |||
| cache_type=ToolParameterCacheType.PARAMETER | |||
| ) | |||
| cached_parameters = cache.get() | |||
| if cached_parameters: | |||
| return cached_parameters | |||
| # override parameters | |||
| current_parameters = self._merge_parameters() | |||
| has_secret_input = False | |||
| for parameter in current_parameters: | |||
| if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: | |||
| if parameter.name in parameters: | |||
| try: | |||
| has_secret_input = True | |||
| parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) | |||
| except: | |||
| pass | |||
| if has_secret_input: | |||
| cache.set(parameters) | |||
| return parameters | |||
| def delete_tool_parameters_cache(self): | |||
| cache = ToolParameterCache( | |||
| tenant_id=self.tenant_id, | |||
| provider=f'{self.provider_type}.{self.provider_name}', | |||
| tool_name=self.tool_runtime.identity.name, | |||
| cache_type=ToolParameterCacheType.PARAMETER | |||
| ) | |||
| cache.delete() | |||
| class ModelToolConfigurationManager: | |||
| """ | |||
| Model as tool configuration | |||
| @@ -17,7 +17,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio | |||
| from core.tools.provider.api_tool_provider import ApiBasedToolProviderController | |||
| from core.tools.provider.tool_provider import ToolProviderController | |||
| from core.tools.tool_manager import ToolManager | |||
| from core.tools.utils.configuration import ToolConfiguration | |||
| from core.tools.utils.configuration import ToolConfigurationManager | |||
| from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict | |||
| from core.tools.utils.parser import ApiBasedToolSchemaParser | |||
| from extensions.ext_database import db | |||
| @@ -77,7 +77,7 @@ class ToolManageService: | |||
| provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) | |||
| tools = provider_controller.get_tools() | |||
| tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| # check if user has added the provider | |||
| builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| @@ -279,7 +279,7 @@ class ToolManageService: | |||
| provider_controller.load_bundled_tools(tool_bundles) | |||
| # encrypt credentials | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) | |||
| db_provider.credentials_str = json.dumps(encrypted_credentials) | |||
| @@ -366,7 +366,7 @@ class ToolManageService: | |||
| provider_controller = ToolManager.get_builtin_provider(provider_name) | |||
| if not provider_controller.need_credentials: | |||
| raise ValueError(f'provider {provider_name} does not need credentials') | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| # get original credentials if exists | |||
| if provider is not None: | |||
| original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | |||
| @@ -450,7 +450,7 @@ class ToolManageService: | |||
| provider_controller.load_bundled_tools(tool_bundles) | |||
| # get original credentials if exists | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | |||
| masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) | |||
| @@ -490,7 +490,7 @@ class ToolManageService: | |||
| # delete cache | |||
| provider_controller = ToolManager.get_builtin_provider(provider_name) | |||
| tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_configuration.delete_tool_credentials_cache() | |||
| return { 'result': 'success' } | |||
| @@ -632,7 +632,7 @@ class ToolManageService: | |||
| # decrypt credentials | |||
| if db_provider.id: | |||
| tool_configuration = ToolConfiguration( | |||
| tool_configuration = ToolConfigurationManager( | |||
| tenant_id=tenant_id, | |||
| provider_controller=provider_controller | |||
| ) | |||