| @@ -3,8 +3,6 @@ import os | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional | |||
| import yaml | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | |||
| from core.model_runtime.entities.model_entities import ( | |||
| @@ -18,6 +16,7 @@ from core.model_runtime.entities.model_entities import ( | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| from core.utils.position_helper import get_position_map, sort_by_position_map | |||
| @@ -154,8 +153,7 @@ class AIModel(ABC): | |||
| # traverse all model_schema_yaml_paths | |||
| for model_schema_yaml_path in model_schema_yaml_paths: | |||
| # read yaml data from yaml file | |||
| with open(model_schema_yaml_path, encoding='utf-8') as f: | |||
| yaml_data = yaml.safe_load(f) | |||
| yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True) | |||
| new_parameter_rules = [] | |||
| for parameter_rule in yaml_data.get('parameter_rules', []): | |||
| @@ -1,11 +1,10 @@ | |||
| import os | |||
| from abc import ABC, abstractmethod | |||
| import yaml | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, ModelType | |||
| from core.model_runtime.entities.provider_entities import ProviderEntity | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source | |||
| @@ -44,10 +43,7 @@ class ModelProvider(ABC): | |||
| # read provider schema from yaml file | |||
| yaml_path = os.path.join(current_path, f'{provider_name}.yaml') | |||
| yaml_data = {} | |||
| if os.path.exists(yaml_path): | |||
| with open(yaml_path, encoding='utf-8') as f: | |||
| yaml_data = yaml.safe_load(f) | |||
| yaml_data = load_yaml_file(yaml_path, ignore_error=True) | |||
| try: | |||
| # yaml_data to entity | |||
| @@ -2,8 +2,6 @@ from abc import abstractmethod | |||
| from os import listdir, path | |||
| from typing import Any | |||
| from yaml import FullLoader, load | |||
| from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType | |||
| from core.tools.entities.user_entities import UserToolProviderCredentials | |||
| from core.tools.errors import ( | |||
| @@ -15,6 +13,7 @@ from core.tools.errors import ( | |||
| from core.tools.provider.tool_provider import ToolProviderController | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| from core.tools.tool.tool import Tool | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| from core.utils.module_import_helper import load_single_subclass_from_source | |||
| @@ -28,10 +27,9 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| provider = self.__class__.__module__.split('.')[-1] | |||
| yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') | |||
| try: | |||
| with open(yaml_path, 'rb') as f: | |||
| provider_yaml = load(f.read(), FullLoader) | |||
| except: | |||
| raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}') | |||
| provider_yaml = load_yaml_file(yaml_path) | |||
| except Exception as e: | |||
| raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') | |||
| if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: | |||
| # set credentials name | |||
| @@ -58,18 +56,18 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path))) | |||
| tools = [] | |||
| for tool_file in tool_files: | |||
| with open(path.join(tool_path, tool_file), encoding='utf-8') as f: | |||
| # get tool name | |||
| tool_name = tool_file.split(".")[0] | |||
| tool = load(f.read(), FullLoader) | |||
| # get tool class, import the module | |||
| assistant_tool_class = load_single_subclass_from_source( | |||
| module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', | |||
| script_path=path.join(path.dirname(path.realpath(__file__)), | |||
| 'builtin', provider, 'tools', f'{tool_name}.py'), | |||
| parent_type=BuiltinTool) | |||
| tool["identity"]["provider"] = provider | |||
| tools.append(assistant_tool_class(**tool)) | |||
| # get tool name | |||
| tool_name = tool_file.split(".")[0] | |||
| tool = load_yaml_file(path.join(tool_path, tool_file)) | |||
| # get tool class, import the module | |||
| assistant_tool_class = load_single_subclass_from_source( | |||
| module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', | |||
| script_path=path.join(path.dirname(path.realpath(__file__)), | |||
| 'builtin', provider, 'tools', f'{tool_name}.py'), | |||
| parent_type=BuiltinTool) | |||
| tool["identity"]["provider"] = provider | |||
| tools.append(assistant_tool_class(**tool)) | |||
| self.tools = tools | |||
| return tools | |||
| @@ -23,7 +23,7 @@ class ToolConfigurationManager(BaseModel): | |||
| deep copy credentials | |||
| """ | |||
| return deepcopy(credentials) | |||
| def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]: | |||
| """ | |||
| encrypt tool credentials with tenant id | |||
| @@ -39,9 +39,9 @@ class ToolConfigurationManager(BaseModel): | |||
| if field_name in credentials: | |||
| encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name]) | |||
| credentials[field_name] = encrypted | |||
| return credentials | |||
| def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| mask tool credentials | |||
| @@ -58,7 +58,7 @@ class ToolConfigurationManager(BaseModel): | |||
| if len(credentials[field_name]) > 6: | |||
| credentials[field_name] = \ | |||
| credentials[field_name][:2] + \ | |||
| '*' * (len(credentials[field_name]) - 4) +\ | |||
| '*' * (len(credentials[field_name]) - 4) + \ | |||
| credentials[field_name][-2:] | |||
| else: | |||
| credentials[field_name] = '*' * len(credentials[field_name]) | |||
| @@ -72,7 +72,7 @@ class ToolConfigurationManager(BaseModel): | |||
| return a deep copy of credentials with decrypted values | |||
| """ | |||
| cache = ToolProviderCredentialsCache( | |||
| tenant_id=self.tenant_id, | |||
| tenant_id=self.tenant_id, | |||
| identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', | |||
| cache_type=ToolProviderCredentialsCacheType.PROVIDER | |||
| ) | |||
| @@ -92,10 +92,10 @@ class ToolConfigurationManager(BaseModel): | |||
| cache.set(credentials) | |||
| return credentials | |||
| def delete_tool_credentials_cache(self): | |||
| cache = ToolProviderCredentialsCache( | |||
| tenant_id=self.tenant_id, | |||
| tenant_id=self.tenant_id, | |||
| identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', | |||
| cache_type=ToolProviderCredentialsCacheType.PROVIDER | |||
| ) | |||
| @@ -116,7 +116,7 @@ class ToolParameterConfigurationManager(BaseModel): | |||
| deep copy parameters | |||
| """ | |||
| return deepcopy(parameters) | |||
| def _merge_parameters(self) -> list[ToolParameter]: | |||
| """ | |||
| merge parameters | |||
| @@ -139,7 +139,7 @@ class ToolParameterConfigurationManager(BaseModel): | |||
| current_parameters.append(runtime_parameter) | |||
| return current_parameters | |||
| def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| mask tool parameters | |||
| @@ -157,13 +157,13 @@ class ToolParameterConfigurationManager(BaseModel): | |||
| if len(parameters[parameter.name]) > 6: | |||
| parameters[parameter.name] = \ | |||
| parameters[parameter.name][:2] + \ | |||
| '*' * (len(parameters[parameter.name]) - 4) +\ | |||
| '*' * (len(parameters[parameter.name]) - 4) + \ | |||
| parameters[parameter.name][-2:] | |||
| else: | |||
| parameters[parameter.name] = '*' * len(parameters[parameter.name]) | |||
| return parameters | |||
| def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| encrypt tool parameters with tenant id | |||
| @@ -180,9 +180,9 @@ class ToolParameterConfigurationManager(BaseModel): | |||
| if parameter.name in parameters: | |||
| encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) | |||
| parameters[parameter.name] = encrypted | |||
| return parameters | |||
| def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | |||
| """ | |||
| decrypt tool parameters with tenant id | |||
| @@ -190,7 +190,7 @@ class ToolParameterConfigurationManager(BaseModel): | |||
| return a deep copy of parameters with decrypted values | |||
| """ | |||
| cache = ToolParameterCache( | |||
| tenant_id=self.tenant_id, | |||
| tenant_id=self.tenant_id, | |||
| provider=f'{self.provider_type}.{self.provider_name}', | |||
| tool_name=self.tool_runtime.identity.name, | |||
| cache_type=ToolParameterCacheType.PARAMETER, | |||
| @@ -212,15 +212,15 @@ class ToolParameterConfigurationManager(BaseModel): | |||
| parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) | |||
| except: | |||
| pass | |||
| if has_secret_input: | |||
| cache.set(parameters) | |||
| return parameters | |||
| def delete_tool_parameters_cache(self): | |||
| cache = ToolParameterCache( | |||
| tenant_id=self.tenant_id, | |||
| tenant_id=self.tenant_id, | |||
| provider=f'{self.provider_type}.{self.provider_name}', | |||
| tool_name=self.tool_runtime.identity.name, | |||
| cache_type=ToolParameterCacheType.PARAMETER, | |||
| @@ -0,0 +1,34 @@ | |||
| import logging | |||
| import os | |||
| import yaml | |||
| from yaml import YAMLError | |||
| def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict: | |||
| """ | |||
| Safe loading a YAML file to a dict | |||
| :param file_path: the path of the YAML file | |||
| :param ignore_error: | |||
| if True, return empty dict if error occurs and the error will be logged in warning level | |||
| if False, raise error if error occurs | |||
| :return: a dict of the YAML content | |||
| """ | |||
| try: | |||
| if not file_path or not os.path.exists(file_path): | |||
| raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found') | |||
| with open(file_path, encoding='utf-8') as file: | |||
| try: | |||
| return yaml.safe_load(file) | |||
| except Exception as e: | |||
| raise YAMLError(f'Failed to load YAML file {file_path}: {e}') | |||
| except FileNotFoundError as e: | |||
| logging.debug(f'Failed to load YAML file {file_path}: {e}') | |||
| return {} | |||
| except Exception as e: | |||
| if ignore_error: | |||
| logging.warning(f'Failed to load YAML file {file_path}: {e}') | |||
| return {} | |||
| else: | |||
| raise e | |||
| @@ -1,10 +1,9 @@ | |||
| import logging | |||
| import os | |||
| from collections import OrderedDict | |||
| from collections.abc import Callable | |||
| from typing import Any, AnyStr | |||
| import yaml | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| def get_position_map( | |||
| @@ -17,21 +16,15 @@ def get_position_map( | |||
| :param file_name: the YAML file name, default to '_position.yaml' | |||
| :return: a dict with name as key and index as value | |||
| """ | |||
| try: | |||
| position_file_name = os.path.join(folder_path, file_name) | |||
| if not os.path.exists(position_file_name): | |||
| return {} | |||
| with open(position_file_name, encoding='utf-8') as f: | |||
| positions = yaml.safe_load(f) | |||
| position_map = {} | |||
| for index, name in enumerate(positions): | |||
| if name and isinstance(name, str): | |||
| position_map[name.strip()] = index | |||
| return position_map | |||
| except: | |||
| logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.') | |||
| return {} | |||
| position_file_name = os.path.join(folder_path, file_name) | |||
| positions = load_yaml_file(position_file_name, ignore_error=True) | |||
| position_map = {} | |||
| index = 0 | |||
| for _, name in enumerate(positions): | |||
| if name and isinstance(name, str): | |||
| position_map[name.strip()] = index | |||
| index += 1 | |||
| return position_map | |||
| def sort_by_position_map( | |||
| @@ -14,6 +14,7 @@ select = [ | |||
| "I", # isort rules | |||
| "UP", # pyupgrade rules | |||
| "RUF019", # unnecessary-key-check | |||
| "S506", # unsafe-yaml-load | |||
| ] | |||
| ignore = [ | |||
| "F403", # undefined-local-with-import-star | |||
| @@ -0,0 +1,34 @@ | |||
| from textwrap import dedent | |||
| import pytest | |||
| from core.utils.position_helper import get_position_map | |||
| @pytest.fixture | |||
| def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: | |||
| monkeypatch.chdir(tmp_path) | |||
| tmp_path.joinpath("example_positions.yaml").write_text(dedent( | |||
| """\ | |||
| - first | |||
| - second | |||
| # - commented | |||
| - third | |||
| - 9999999999999 | |||
| - forth | |||
| """)) | |||
| return str(tmp_path) | |||
| def test_position_helper(prepare_example_positions_yaml): | |||
| position_map = get_position_map( | |||
| folder_path=prepare_example_positions_yaml, | |||
| file_name='example_positions.yaml') | |||
| assert len(position_map) == 4 | |||
| assert position_map == { | |||
| 'first': 0, | |||
| 'second': 1, | |||
| 'third': 2, | |||
| 'forth': 3, | |||
| } | |||
| @@ -0,0 +1,74 @@ | |||
| from textwrap import dedent | |||
| import pytest | |||
| from yaml import YAMLError | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| EXAMPLE_YAML_FILE = 'example_yaml.yaml' | |||
| INVALID_YAML_FILE = 'invalid_yaml.yaml' | |||
| NON_EXISTING_YAML_FILE = 'non_existing_file.yaml' | |||
| @pytest.fixture | |||
| def prepare_example_yaml_file(tmp_path, monkeypatch) -> str: | |||
| monkeypatch.chdir(tmp_path) | |||
| file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE) | |||
| file_path.write_text(dedent( | |||
| """\ | |||
| address: | |||
| city: Example City | |||
| country: Example Country | |||
| age: 30 | |||
| gender: male | |||
| languages: | |||
| - Python | |||
| - Java | |||
| - C++ | |||
| empty_key: | |||
| """)) | |||
| return str(file_path) | |||
| @pytest.fixture | |||
| def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str: | |||
| monkeypatch.chdir(tmp_path) | |||
| file_path = tmp_path.joinpath(INVALID_YAML_FILE) | |||
| file_path.write_text(dedent( | |||
| """\ | |||
| address: | |||
| city: Example City | |||
| country: Example Country | |||
| age: 30 | |||
| gender: male | |||
| languages: | |||
| - Python | |||
| - Java | |||
| - C++ | |||
| """)) | |||
| return str(file_path) | |||
| def test_load_yaml_non_existing_file(): | |||
| assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} | |||
| assert load_yaml_file(file_path='') == {} | |||
| def test_load_valid_yaml_file(prepare_example_yaml_file): | |||
| yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) | |||
| assert len(yaml_data) > 0 | |||
| assert yaml_data['age'] == 30 | |||
| assert yaml_data['gender'] == 'male' | |||
| assert yaml_data['address']['city'] == 'Example City' | |||
| assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'} | |||
| assert yaml_data.get('empty_key') is None | |||
| assert yaml_data.get('non_existed_key') is None | |||
| def test_load_invalid_yaml_file(prepare_invalid_yaml_file): | |||
| # yaml syntax error | |||
| with pytest.raises(YAMLError): | |||
| load_yaml_file(file_path=prepare_invalid_yaml_file) | |||
| # ignore error | |||
| assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {} | |||