Browse Source

generalize helper for loading module from source (#2862)

tags/0.5.11
Bowen Liang 1 year ago
parent
commit
08b727833e
No account linked to committer's email address

+ 4
- 12
api/core/extension/extensible.py View File

import enum import enum
import importlib.util
import json import json
import logging import logging
import os import os


from pydantic import BaseModel from pydantic import BaseModel


from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import sort_to_dict_by_position_map from core.utils.position_helper import sort_to_dict_by_position_map






# Dynamic loading {subdir_name}.py file and find the subclass of Extensible # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py') py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break

if not extension_class:
try:
extension_class = load_single_subclass_from_source(extension_name, py_path, cls)
except Exception:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue continue



+ 5
- 12
api/core/model_runtime/model_providers/__base/model_provider.py View File

import importlib
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod


from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source




class ModelProvider(ABC): class ModelProvider(ABC):


# Dynamic loading {model_type_name}.py file and find the subclass of AIModel # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
spec = importlib.util.spec_from_file_location(f"{parent_module}.{model_type_name}.{model_type_name}", model_type_py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break

mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel)), None)
if not model_class: if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}') raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')



+ 5
- 10
api/core/model_runtime/model_providers/model_provider_factory.py View File

import importlib
import logging import logging
import os import os
from typing import Optional from typing import Optional
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py') py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
spec = importlib.util.spec_from_file_location(f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

model_provider_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, ModelProvider) and obj != ModelProvider:
model_provider_class = obj
break
model_provider_class = load_single_subclass_from_source(
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
script_path=py_path,
parent_type=ModelProvider)


if not model_provider_class: if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")

+ 6
- 11
api/core/tools/provider/builtin_tool_provider.py View File

import importlib
from abc import abstractmethod from abc import abstractmethod
from os import listdir, path from os import listdir, path
from typing import Any from typing import Any
from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.utils.module_import_helper import load_single_subclass_from_source




class BuiltinToolProviderController(ToolProviderController): class BuiltinToolProviderController(ToolProviderController):
tool_name = tool_file.split(".")[0] tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader) tool = load(f.read(), FullLoader)
# get tool class, import the module # get tool class, import the module
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# get all the classes in the module
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
]
assistant_tool_class = classes[0]
assistant_tool_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'builtin', provider, 'tools', f'{tool_name}.py'),
parent_type=BuiltinTool)
tools.append(assistant_tool_class(**tool)) tools.append(assistant_tool_class(**tool))


self.tools = tools self.tools = tools

+ 11
- 32
api/core/tools/tool_manager.py View File

import importlib
import json import json
import logging import logging
import mimetypes import mimetypes
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
) )
from core.tools.utils.encoder import serialize_base_model_dict from core.tools.utils.encoder import serialize_base_model_dict
from core.utils.module_import_helper import load_single_subclass_from_source
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider




if provider_entity is None: if provider_entity is None:
# fetch the provider from .provider.builtin # fetch the provider from .provider.builtin
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# get all the classes in the module
classes = [ x for _, x in vars(mod).items()
if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
provider_entity = classes[0]()
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
parent_type=ToolProviderController)
provider_entity = provider_class()


return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages) return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
if provider.startswith('__'): if provider.startswith('__'):
continue continue


py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# load all classes
classes = [
obj for name, obj in vars(mod).items()
if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
# init provider # init provider
provider_class = classes[0]
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class()) builtin_providers.append(provider_class())


# cache the builtin providers # cache the builtin providers

+ 62
- 0
api/core/utils/module_import_helper.py View File

import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr


def import_module_from_source(
module_name: str,
py_file_path: AnyStr,
use_lazy_loader: bool = False
) -> ModuleType:
"""
Importing a module from the source file directly
"""
try:
existed_spec = importlib.util.find_spec(module_name)
if existed_spec:
spec = existed_spec
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
module = importlib.util.module_from_spec(spec)
if not existed_spec:
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
except Exception as e:
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
raise e


def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]:
"""
Get all the subclasses of the parent type from the module
"""
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)]
return classes


def load_single_subclass_from_source(
module_name: str,
script_path: AnyStr,
parent_type: type,
use_lazy_loader: bool = False,
) -> type:
"""
Load a single subclass from the source
"""
module = import_module_from_source(module_name, script_path, use_lazy_loader)
subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses):
case 1:
return subclasses[0]
case 0:
raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}')
case _:
raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}')

+ 7
- 0
api/tests/integration_tests/utils/child_class.py View File

from tests.integration_tests.utils.parent_class import ParentClass


class ChildClass(ParentClass):
def __init__(self, name: str):
super().__init__(name)
self.name = name

+ 7
- 0
api/tests/integration_tests/utils/lazy_load_class.py View File

from tests.integration_tests.utils.parent_class import ParentClass


class LazyLoadChildClass(ParentClass):
def __init__(self, name: str):
super().__init__(name)
self.name = name

+ 6
- 0
api/tests/integration_tests/utils/parent_class.py View File

class ParentClass:
def __init__(self, name):
self.name = name

def get_name(self):
return self.name

+ 32
- 0
api/tests/integration_tests/utils/test_module_import_helper.py View File

import os

from core.utils.module_import_helper import load_single_subclass_from_source, import_module_from_source
from tests.integration_tests.utils.parent_class import ParentClass


def test_loading_subclass_from_source():
current_path = os.getcwd()
module = load_single_subclass_from_source(
module_name='ChildClass',
script_path=os.path.join(current_path, 'child_class.py'),
parent_type=ParentClass)
assert module and module.__name__ == 'ChildClass'


def test_load_import_module_from_source():
current_path = os.getcwd()
module = import_module_from_source(
module_name='ChildClass',
py_file_path=os.path.join(current_path, 'child_class.py'))
assert module and module.__name__ == 'ChildClass'


def test_lazy_loading_subclass_from_source():
current_path = os.getcwd()
clz = load_single_subclass_from_source(
module_name='LazyLoadChildClass',
script_path=os.path.join(current_path, 'lazy_load_class.py'),
parent_type=ParentClass,
use_lazy_loader=True)
instance = clz('dify')
assert instance.get_name() == 'dify'

Loading…
Cancel
Save