| import json | import json | ||||
| import logging | import logging | ||||
| import os | import os | ||||
| from collections import OrderedDict | |||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.utils.position_helper import sort_to_dict_by_position_map | |||||
| class ExtensionModule(enum.Enum): | class ExtensionModule(enum.Enum): | ||||
| MODERATION = 'moderation' | MODERATION = 'moderation' | ||||
| @classmethod | @classmethod | ||||
| def scan_extensions(cls): | def scan_extensions(cls): | ||||
| extensions = {} | |||||
| extensions: list[ModuleExtension] = [] | |||||
| position_map = {} | |||||
| # get the path of the current class | # get the path of the current class | ||||
| current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') | current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') | ||||
| if os.path.exists(builtin_file_path): | if os.path.exists(builtin_file_path): | ||||
| with open(builtin_file_path, encoding='utf-8') as f: | with open(builtin_file_path, encoding='utf-8') as f: | ||||
| position = int(f.read().strip()) | position = int(f.read().strip()) | ||||
| position_map[extension_name] = position | |||||
| if (extension_name + '.py') not in file_names: | if (extension_name + '.py') not in file_names: | ||||
| logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") | logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") | ||||
| with open(json_path, encoding='utf-8') as f: | with open(json_path, encoding='utf-8') as f: | ||||
| json_data = json.load(f) | json_data = json.load(f) | ||||
| extensions[extension_name] = ModuleExtension( | |||||
| extensions.append(ModuleExtension( | |||||
| extension_class=extension_class, | extension_class=extension_class, | ||||
| name=extension_name, | name=extension_name, | ||||
| label=json_data.get('label'), | label=json_data.get('label'), | ||||
| form_schema=json_data.get('form_schema'), | form_schema=json_data.get('form_schema'), | ||||
| builtin=builtin, | builtin=builtin, | ||||
| position=position | position=position | ||||
| ) | |||||
| )) | |||||
| sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position)) | |||||
| sorted_extensions = OrderedDict(sorted_items) | |||||
| sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name) | |||||
| return sorted_extensions | return sorted_extensions |
| ) | ) | ||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | ||||
| from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | ||||
| from core.utils.position_helper import get_position_map, sort_by_position_map | |||||
| class AIModel(ABC): | class AIModel(ABC): | ||||
| ] | ] | ||||
| # get _position.yaml file path | # get _position.yaml file path | ||||
| position_file_path = os.path.join(provider_model_type_path, '_position.yaml') | |||||
| # read _position.yaml file | |||||
| position_map = {} | |||||
| if os.path.exists(position_file_path): | |||||
| with open(position_file_path, encoding='utf-8') as f: | |||||
| positions = yaml.safe_load(f) | |||||
| # convert list to dict with key as model provider name, value as index | |||||
| position_map = {position: index for index, position in enumerate(positions)} | |||||
| position_map = get_position_map(provider_model_type_path) | |||||
| # traverse all model_schema_yaml_paths | # traverse all model_schema_yaml_paths | ||||
| for model_schema_yaml_path in model_schema_yaml_paths: | for model_schema_yaml_path in model_schema_yaml_paths: | ||||
| model_schemas.append(model_schema) | model_schemas.append(model_schema) | ||||
| # resort model schemas by position | # resort model schemas by position | ||||
| if position_map: | |||||
| model_schemas.sort(key=lambda x: position_map.get(x.model, 999)) | |||||
| model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model) | |||||
| # cache model schemas | # cache model schemas | ||||
| self.model_schemas = model_schemas | self.model_schemas = model_schemas |
| import importlib | import importlib | ||||
| import logging | import logging | ||||
| import os | import os | ||||
| from collections import OrderedDict | |||||
| from typing import Optional | from typing import Optional | ||||
| import yaml | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||||
| from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator | from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator | ||||
| from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator | from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator | ||||
| from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| if self.model_provider_extensions: | if self.model_provider_extensions: | ||||
| return self.model_provider_extensions | return self.model_provider_extensions | ||||
| model_providers = {} | |||||
| # get the path of current classes | # get the path of current classes | ||||
| current_path = os.path.abspath(__file__) | current_path = os.path.abspath(__file__) | ||||
| ] | ] | ||||
| # get _position.yaml file path | # get _position.yaml file path | ||||
| position_file_path = os.path.join(model_providers_path, '_position.yaml') | |||||
| # read _position.yaml file | |||||
| position_map = {} | |||||
| if os.path.exists(position_file_path): | |||||
| with open(position_file_path, encoding='utf-8') as f: | |||||
| positions = yaml.safe_load(f) | |||||
| # convert list to dict with key as model provider name, value as index | |||||
| position_map = {position: index for index, position in enumerate(positions)} | |||||
| position_map = get_position_map(model_providers_path) | |||||
| # traverse all model_provider_dir_paths | # traverse all model_provider_dir_paths | ||||
| model_providers: list[ModelProviderExtension] = [] | |||||
| for model_provider_dir_path in model_provider_dir_paths: | for model_provider_dir_path in model_provider_dir_paths: | ||||
| # get model_provider dir name | # get model_provider dir name | ||||
| model_provider_name = os.path.basename(model_provider_dir_path) | model_provider_name = os.path.basename(model_provider_dir_path) | ||||
| logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") | logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") | ||||
| continue | continue | ||||
| model_providers[model_provider_name] = ModelProviderExtension( | |||||
| model_providers.append(ModelProviderExtension( | |||||
| name=model_provider_name, | name=model_provider_name, | ||||
| provider_instance=model_provider_class(), | provider_instance=model_provider_class(), | ||||
| position=position_map.get(model_provider_name) | position=position_map.get(model_provider_name) | ||||
| ) | |||||
| )) | |||||
| sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position)) | |||||
| sorted_extensions = OrderedDict(sorted_items) | |||||
| sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) | |||||
| self.model_provider_extensions = sorted_extensions | self.model_provider_extensions = sorted_extensions | ||||
| import os.path | import os.path | ||||
| from yaml import FullLoader, load | |||||
| from core.tools.entities.user_entities import UserToolProvider | from core.tools.entities.user_entities import UserToolProvider | ||||
| from core.utils.position_helper import get_position_map, sort_by_position_map | |||||
| class BuiltinToolProviderSort: | class BuiltinToolProviderSort: | ||||
| @classmethod | @classmethod | ||||
| def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: | def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: | ||||
| if not cls._position: | if not cls._position: | ||||
| tmp_position = {} | |||||
| file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') | |||||
| with open(file_path) as f: | |||||
| for pos, val in enumerate(load(f, Loader=FullLoader)): | |||||
| tmp_position[val] = pos | |||||
| cls._position = tmp_position | |||||
| cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) | |||||
| def sort_compare(provider: UserToolProvider) -> int: | |||||
| def name_func(provider: UserToolProvider) -> str: | |||||
| if provider.type == UserToolProvider.ProviderType.MODEL: | if provider.type == UserToolProvider.ProviderType.MODEL: | ||||
| return cls._position.get(f'model.{provider.name}', 10000) | |||||
| return cls._position.get(provider.name, 10000) | |||||
| sorted_providers = sorted(providers, key=sort_compare) | |||||
| return f'model.{provider.name}' | |||||
| else: | |||||
| return provider.name | |||||
| sorted_providers = sort_by_position_map(cls._position, providers, name_func) | |||||
| return sorted_providers | return sorted_providers |
| import logging | |||||
| import os | |||||
| from collections import OrderedDict | |||||
| from collections.abc import Callable | |||||
| from typing import Any, AnyStr | |||||
| import yaml | |||||
| def get_position_map( | |||||
| folder_path: AnyStr, | |||||
| file_name: str = '_position.yaml', | |||||
| ) -> dict[str, int]: | |||||
| """ | |||||
| Get the mapping from name to index from a YAML file | |||||
| :param folder_path: | |||||
| :param file_name: the YAML file name, default to '_position.yaml' | |||||
| :return: a dict with name as key and index as value | |||||
| """ | |||||
| try: | |||||
| position_file_name = os.path.join(folder_path, file_name) | |||||
| if not os.path.exists(position_file_name): | |||||
| return {} | |||||
| with open(position_file_name, encoding='utf-8') as f: | |||||
| positions = yaml.safe_load(f) | |||||
| position_map = {} | |||||
| for index, name in enumerate(positions): | |||||
| if name and isinstance(name, str): | |||||
| position_map[name.strip()] = index | |||||
| return position_map | |||||
| except: | |||||
| logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.') | |||||
| return {} | |||||
| def sort_by_position_map( | |||||
| position_map: dict[str, int], | |||||
| data: list[Any], | |||||
| name_func: Callable[[Any], str], | |||||
| ) -> list[Any]: | |||||
| """ | |||||
| Sort the objects by the position map. | |||||
| If the name of the object is not in the position map, it will be put at the end. | |||||
| :param position_map: the map holding positions in the form of {name: index} | |||||
| :param name_func: the function to get the name of the object | |||||
| :param data: the data to be sorted | |||||
| :return: the sorted objects | |||||
| """ | |||||
| if not position_map or not data: | |||||
| return data | |||||
| return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) | |||||
| def sort_to_dict_by_position_map( | |||||
| position_map: dict[str, int], | |||||
| data: list[Any], | |||||
| name_func: Callable[[Any], str], | |||||
| ) -> OrderedDict[str, Any]: | |||||
| """ | |||||
| Sort the objects into a ordered dict by the position map. | |||||
| If the name of the object is not in the position map, it will be put at the end. | |||||
| :param position_map: the map holding positions in the form of {name: index} | |||||
| :param name_func: the function to get the name of the object | |||||
| :param data: the data to be sorted | |||||
| :return: an OrderedDict with the sorted pairs of name and object | |||||
| """ | |||||
| sorted_items = sort_by_position_map(position_map, data, name_func) | |||||
| return OrderedDict([(name_func(item), item) for item in sorted_items]) |