You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

extensible.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import enum
  2. import importlib.util
  3. import json
  4. import logging
  5. import os
  6. from pathlib import Path
  7. from typing import Any, Optional
  8. from pydantic import BaseModel
  9. from core.helper.position_helper import sort_to_dict_by_position_map
# Logger for this module, named after the module itself via __name__.
logger = logging.getLogger(__name__)
class ExtensionModule(enum.Enum):
    """Identifiers for the kinds of extensible module (see ``Extensible.module``).

    Each member's value is its plain string identifier.
    """

    MODERATION = "moderation"
    EXTERNAL_DATA_TOOL = "external_data_tool"
# Metadata for one extension found by Extensible.scan_extensions.
# NOTE: documented with comments rather than a class docstring, because pydantic
# would surface a docstring as the model's schema description.
class ModuleExtension(BaseModel):
    # Implementation class; scan_extensions stores the discovered subclass here.
    extension_class: Optional[Any] = None
    # Extension name; scan_extensions uses the extension's sub-directory name.
    name: str
    # "label" value from the extension's schema.json; scan_extensions leaves it None
    # for builtin extensions, which are not required to ship a schema.json.
    label: Optional[dict] = None
    # "form_schema" value from the extension's schema.json; None for builtins (as label).
    form_schema: Optional[list] = None
    # Whether the extension directory contains a "__builtin__" marker. Defaults to True.
    builtin: bool = True
    # Sort position; scan_extensions reads it from the "__builtin__" marker (0 otherwise).
    position: Optional[int] = None
  21. class Extensible:
  22. module: ExtensionModule
  23. name: str
  24. tenant_id: str
  25. config: Optional[dict] = None
  26. def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
  27. self.tenant_id = tenant_id
  28. self.config = config
  29. @classmethod
  30. def scan_extensions(cls):
  31. extensions = []
  32. position_map: dict[str, int] = {}
  33. # Get the package name from the module path
  34. package_name = ".".join(cls.__module__.split(".")[:-1])
  35. try:
  36. # Get package directory path
  37. package_spec = importlib.util.find_spec(package_name)
  38. if not package_spec or not package_spec.origin:
  39. raise ImportError(f"Could not find package {package_name}")
  40. package_dir = os.path.dirname(package_spec.origin)
  41. # Traverse subdirectories
  42. for subdir_name in os.listdir(package_dir):
  43. if subdir_name.startswith("__"):
  44. continue
  45. subdir_path = os.path.join(package_dir, subdir_name)
  46. if not os.path.isdir(subdir_path):
  47. continue
  48. extension_name = subdir_name
  49. file_names = os.listdir(subdir_path)
  50. # Check for extension module file
  51. if (extension_name + ".py") not in file_names:
  52. logger.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
  53. continue
  54. # Check for builtin flag and position
  55. builtin = False
  56. position = 0
  57. if "__builtin__" in file_names:
  58. builtin = True
  59. builtin_file_path = os.path.join(subdir_path, "__builtin__")
  60. if os.path.exists(builtin_file_path):
  61. position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
  62. position_map[extension_name] = position
  63. # Import the extension module
  64. module_name = f"{package_name}.{extension_name}.{extension_name}"
  65. spec = importlib.util.find_spec(module_name)
  66. if not spec or not spec.loader:
  67. raise ImportError(f"Failed to load module {module_name}")
  68. mod = importlib.util.module_from_spec(spec)
  69. spec.loader.exec_module(mod)
  70. # Find extension class
  71. extension_class = None
  72. for obj in vars(mod).values():
  73. if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
  74. extension_class = obj
  75. break
  76. if not extension_class:
  77. logger.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
  78. continue
  79. # Load schema if not builtin
  80. json_data: dict[str, Any] = {}
  81. if not builtin:
  82. json_path = os.path.join(subdir_path, "schema.json")
  83. if not os.path.exists(json_path):
  84. logger.warning("Missing schema.json file in %s, Skip.", subdir_path)
  85. continue
  86. with open(json_path, encoding="utf-8") as f:
  87. json_data = json.load(f)
  88. # Create extension
  89. extensions.append(
  90. ModuleExtension(
  91. extension_class=extension_class,
  92. name=extension_name,
  93. label=json_data.get("label"),
  94. form_schema=json_data.get("form_schema"),
  95. builtin=builtin,
  96. position=position,
  97. )
  98. )
  99. except Exception:
  100. logger.exception("Error scanning extensions")
  101. raise
  102. # Sort extensions by position
  103. sorted_extensions = sort_to_dict_by_position_map(
  104. position_map=position_map, data=extensions, name_func=lambda x: x.name
  105. )
  106. return sorted_extensions