Browse Source

generalize position helper for parsing _position.yaml and sorting objects by name (#2803)

tags/0.5.10
Bowen Liang 1 year ago
parent
commit
8b15b742ad
No account linked to committer's email address

+ 8
- 6
api/core/extension/extensible.py View File

import json import json
import logging import logging
import os import os
from collections import OrderedDict
from typing import Any, Optional from typing import Any, Optional


from pydantic import BaseModel from pydantic import BaseModel


from core.utils.position_helper import sort_to_dict_by_position_map



class ExtensionModule(enum.Enum): class ExtensionModule(enum.Enum):
MODERATION = 'moderation' MODERATION = 'moderation'


@classmethod @classmethod
def scan_extensions(cls): def scan_extensions(cls):
extensions = {}
extensions: list[ModuleExtension] = []
position_map = {}


# get the path of the current class # get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f: with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip()) position = int(f.read().strip())
position_map[extension_name] = position


if (extension_name + '.py') not in file_names: if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
with open(json_path, encoding='utf-8') as f: with open(json_path, encoding='utf-8') as f:
json_data = json.load(f) json_data = json.load(f)


extensions[extension_name] = ModuleExtension(
extensions.append(ModuleExtension(
extension_class=extension_class, extension_class=extension_class,
name=extension_name, name=extension_name,
label=json_data.get('label'), label=json_data.get('label'),
form_schema=json_data.get('form_schema'), form_schema=json_data.get('form_schema'),
builtin=builtin, builtin=builtin,
position=position position=position
)
))


sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)


return sorted_extensions return sorted_extensions

+ 3
- 11
api/core/model_runtime/model_providers/__base/ai_model.py View File

) )
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map




class AIModel(ABC): class AIModel(ABC):
] ]


# get _position.yaml file path # get _position.yaml file path
position_file_path = os.path.join(provider_model_type_path, '_position.yaml')

# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
position_map = get_position_map(provider_model_type_path)


# traverse all model_schema_yaml_paths # traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths: for model_schema_yaml_path in model_schema_yaml_paths:
model_schemas.append(model_schema) model_schemas.append(model_schema)


# resort model schemas by position # resort model schemas by position
if position_map:
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)


# cache model schemas # cache model schemas
self.model_schemas = model_schemas self.model_schemas = model_schemas

+ 6
- 16
api/core/model_runtime/model_providers/model_provider_factory.py View File

import importlib import importlib
import logging import logging
import os import os
from collections import OrderedDict
from typing import Optional from typing import Optional


import yaml
from pydantic import BaseModel from pydantic import BaseModel


from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


if self.model_provider_extensions: if self.model_provider_extensions:
return self.model_provider_extensions return self.model_provider_extensions


model_providers = {}


# get the path of current classes # get the path of current classes
current_path = os.path.abspath(__file__) current_path = os.path.abspath(__file__)
] ]


# get _position.yaml file path # get _position.yaml file path
position_file_path = os.path.join(model_providers_path, '_position.yaml')

# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
position_map = get_position_map(model_providers_path)


# traverse all model_provider_dir_paths # traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []
for model_provider_dir_path in model_provider_dir_paths: for model_provider_dir_path in model_provider_dir_paths:
# get model_provider dir name # get model_provider dir name
model_provider_name = os.path.basename(model_provider_dir_path) model_provider_name = os.path.basename(model_provider_dir_path)
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue continue


model_providers[model_provider_name] = ModelProviderExtension(
model_providers.append(ModelProviderExtension(
name=model_provider_name, name=model_provider_name,
provider_instance=model_provider_class(), provider_instance=model_provider_class(),
position=position_map.get(model_provider_name) position=position_map.get(model_provider_name)
)
))


sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)


self.model_provider_extensions = sorted_extensions self.model_provider_extensions = sorted_extensions



+ 8
- 13
api/core/tools/provider/builtin/_positions.py View File

import os.path import os.path


from yaml import FullLoader, load

from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map




class BuiltinToolProviderSort: class BuiltinToolProviderSort:
@classmethod @classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position: if not cls._position:
tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
cls._position = tmp_position
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))


def sort_compare(provider: UserToolProvider) -> int:
def name_func(provider: UserToolProvider) -> str:
if provider.type == UserToolProvider.ProviderType.MODEL: if provider.type == UserToolProvider.ProviderType.MODEL:
return cls._position.get(f'model.{provider.name}', 10000)
return cls._position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare)
return f'model.{provider.name}'
else:
return provider.name

sorted_providers = sort_by_position_map(cls._position, providers, name_func)


return sorted_providers return sorted_providers

+ 70
- 0
api/core/utils/position_helper.py View File

import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr

import yaml


def get_position_map(
        folder_path: AnyStr,
        file_name: str = '_position.yaml',
) -> dict[str, int]:
    """
    Get the mapping from name to index from a YAML file.

    The YAML document is expected to be a list of names. Each non-empty string
    entry is stripped and mapped to its index in the list; entries that are
    empty or not strings are ignored (their index is still consumed). If the
    same name occurs more than once, the last index wins.

    This is a best-effort loader: a missing file, an unreadable or malformed
    file, an empty document, or a document that is not a list all yield an
    empty mapping instead of raising.

    :param folder_path: the folder that contains the position file
    :param file_name: the YAML file name, default to '_position.yaml'
    :return: a dict with name as key and index as value
    """
    try:
        position_file_name = os.path.join(folder_path, file_name)
        if not os.path.exists(position_file_name):
            return {}

        with open(position_file_name, encoding='utf-8') as f:
            positions = yaml.safe_load(f)
    except Exception:
        # Only swallow ordinary errors (I/O, YAML syntax, bad path types).
        # KeyboardInterrupt / SystemExit must keep propagating, which a bare
        # `except:` would prevent.
        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.', exc_info=True)
        return {}

    if positions is None:
        # An empty YAML document parses to None: nothing to order.
        return {}

    if not isinstance(positions, list):
        # A top-level scalar or mapping is not a valid position list; iterating
        # it (e.g. the characters of a string) would build a meaningless map.
        logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
        return {}

    return {
        name.strip(): index
        for index, name in enumerate(positions)
        if name and isinstance(name, str)
    }


def sort_by_position_map(
        position_map: dict[str, int],
        data: list[Any],
        name_func: Callable[[Any], str],
) -> list[Any]:
    """
    Order objects according to a name -> index position map.

    Objects whose name is missing from the position map are ranked after all
    mapped ones. The sort is stable, so objects with equal rank keep their
    relative input order. When either the position map or the data is empty,
    the input ``data`` object itself is returned without sorting.

    :param position_map: the map holding positions in the form of {name: index}
    :param data: the data to be sorted
    :param name_func: the function to get the name of the object
    :return: the sorted objects
    """
    if position_map and data:
        # Names without a configured position rank behind every real index.
        unlisted_rank = float('inf')

        def rank_of(item: Any):
            return position_map.get(name_func(item), unlisted_rank)

        return sorted(data, key=rank_of)

    return data


def sort_to_dict_by_position_map(
        position_map: dict[str, int],
        data: list[Any],
        name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
    """
    Build an ordered name -> object mapping, ordered by the position map.

    The objects are first ordered with ``sort_by_position_map`` and then keyed
    by ``name_func``. If two objects share a name, the later one replaces the
    earlier value while the key keeps its first position.

    :param position_map: the map holding positions in the form of {name: index}
    :param data: the data to be sorted
    :param name_func: the function to get the name of the object
    :return: an OrderedDict with the sorted pairs of name and object
    """
    ordered: OrderedDict[str, Any] = OrderedDict()
    for entry in sort_by_position_map(position_map, data, name_func):
        ordered[name_func(entry)] = entry
    return ordered

Loading…
Cancel
Save