| @@ -13,18 +13,10 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> | |||
| :param file_name: the YAML file name, default to '_position.yaml' | |||
| :return: a dict with name as key and index as value | |||
| """ | |||
| position_file_name = os.path.join(folder_path, file_name) | |||
| if not position_file_name or not os.path.exists(position_file_name): | |||
| return {} | |||
| positions = load_yaml_file(position_file_name, ignore_error=True) | |||
| position_map = {} | |||
| index = 0 | |||
| for _, name in enumerate(positions): | |||
| if name and isinstance(name, str): | |||
| position_map[name.strip()] = index | |||
| index += 1 | |||
| return position_map | |||
| position_file_path = os.path.join(folder_path, file_name) | |||
| yaml_content = load_yaml_file(file_path=position_file_path, default_value=[]) | |||
| positions = [item.strip() for item in yaml_content if item and isinstance(item, str) and item.strip()] | |||
| return {name: index for index, name in enumerate(positions)} | |||
| def sort_by_position_map( | |||
| @@ -162,7 +162,7 @@ class AIModel(ABC): | |||
| # traverse all model_schema_yaml_paths | |||
| for model_schema_yaml_path in model_schema_yaml_paths: | |||
| # read yaml data from yaml file | |||
| yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True) | |||
| yaml_data = load_yaml_file(model_schema_yaml_path) | |||
| new_parameter_rules = [] | |||
| for parameter_rule in yaml_data.get('parameter_rules', []): | |||
| @@ -44,7 +44,7 @@ class ModelProvider(ABC): | |||
| # read provider schema from yaml file | |||
| yaml_path = os.path.join(current_path, f'{provider_name}.yaml') | |||
| yaml_data = load_yaml_file(yaml_path, ignore_error=True) | |||
| yaml_data = load_yaml_file(yaml_path) | |||
| try: | |||
| # yaml_data to entity | |||
| @@ -27,7 +27,7 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| provider = self.__class__.__module__.split('.')[-1] | |||
| yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') | |||
| try: | |||
| provider_yaml = load_yaml_file(yaml_path) | |||
| provider_yaml = load_yaml_file(yaml_path, ignore_error=False) | |||
| except Exception as e: | |||
| raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') | |||
| @@ -58,7 +58,7 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| for tool_file in tool_files: | |||
| # get tool name | |||
| tool_name = tool_file.split(".")[0] | |||
| tool = load_yaml_file(path.join(tool_path, tool_file)) | |||
| tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False) | |||
| # get tool class, import the module | |||
| assistant_tool_class = load_single_subclass_from_source( | |||
| @@ -1,35 +1,31 @@ | |||
| import logging | |||
| import os | |||
| from typing import Any | |||
| import yaml | |||
| from yaml import YAMLError | |||
| logger = logging.getLogger(__name__) | |||
| def load_yaml_file(file_path: str, ignore_error: bool = False) -> dict: | |||
| def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any: | |||
| """ | |||
| Safe loading a YAML file to a dict | |||
| Safe loading a YAML file | |||
| :param file_path: the path of the YAML file | |||
| :param ignore_error: | |||
| if True, return empty dict if error occurs and the error will be logged in warning level | |||
| if True, return default_value if error occurs and the error will be logged in warning level | |||
| if False, raise error if error occurs | |||
| :return: a dict of the YAML content | |||
| :param default_value: the value returned when errors ignored | |||
| :return: an object of the YAML content | |||
| """ | |||
| try: | |||
| if not file_path or not os.path.exists(file_path): | |||
| raise FileNotFoundError(f'Failed to load YAML file {file_path}: file not found') | |||
| with open(file_path, encoding='utf-8') as file: | |||
| with open(file_path, encoding='utf-8') as yaml_file: | |||
| try: | |||
| return yaml.safe_load(file) | |||
| return yaml.safe_load(yaml_file) | |||
| except Exception as e: | |||
| raise YAMLError(f'Failed to load YAML file {file_path}: {e}') | |||
| except FileNotFoundError as e: | |||
| logger.debug(f'Failed to load YAML file {file_path}: {e}') | |||
| return {} | |||
| except Exception as e: | |||
| if ignore_error: | |||
| logger.warning(f'Failed to load YAML file {file_path}: {e}') | |||
| return {} | |||
| return default_value | |||
| else: | |||
| raise e | |||
| @@ -21,6 +21,20 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: | |||
| return str(tmp_path) | |||
| @pytest.fixture | |||
| def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: | |||
| monkeypatch.chdir(tmp_path) | |||
| tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent( | |||
| """\ | |||
| # - commented1 | |||
| # - commented2 | |||
| - | |||
| - | |||
| """)) | |||
| return str(tmp_path) | |||
| def test_position_helper(prepare_example_positions_yaml): | |||
| position_map = get_position_map( | |||
| folder_path=prepare_example_positions_yaml, | |||
| @@ -32,3 +46,10 @@ def test_position_helper(prepare_example_positions_yaml): | |||
| 'third': 2, | |||
| 'forth': 3, | |||
| } | |||
| def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml): | |||
| position_map = get_position_map( | |||
| folder_path=prepare_empty_commented_positions_yaml, | |||
| file_name='example_positions_all_commented.yaml') | |||
| assert position_map == {} | |||
| @@ -53,6 +53,9 @@ def test_load_yaml_non_existing_file(): | |||
| assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {} | |||
| assert load_yaml_file(file_path='') == {} | |||
| with pytest.raises(FileNotFoundError): | |||
| load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False) | |||
| def test_load_valid_yaml_file(prepare_example_yaml_file): | |||
| yaml_data = load_yaml_file(file_path=prepare_example_yaml_file) | |||
| @@ -68,7 +71,7 @@ def test_load_valid_yaml_file(prepare_example_yaml_file): | |||
| def test_load_invalid_yaml_file(prepare_invalid_yaml_file): | |||
| # yaml syntax error | |||
| with pytest.raises(YAMLError): | |||
| load_yaml_file(file_path=prepare_invalid_yaml_file) | |||
| load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=False) | |||
| # ignore error | |||
| assert load_yaml_file(file_path=prepare_invalid_yaml_file, ignore_error=True) == {} | |||
| assert load_yaml_file(file_path=prepare_invalid_yaml_file) == {} | |||