您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import enum
  2. import importlib.util
  3. import json
  4. import logging
  5. import os
  6. from pathlib import Path
  7. from typing import Any, Optional
  8. from pydantic import BaseModel
  9. from core.helper.position_helper import sort_to_dict_by_position_map
  10. class ExtensionModule(enum.Enum):
  11. MODERATION = "moderation"
  12. EXTERNAL_DATA_TOOL = "external_data_tool"
  13. class ModuleExtension(BaseModel):
  14. extension_class: Any = None
  15. name: str
  16. label: Optional[dict] = None
  17. form_schema: Optional[list] = None
  18. builtin: bool = True
  19. position: Optional[int] = None
  20. class Extensible:
  21. module: ExtensionModule
  22. name: str
  23. tenant_id: str
  24. config: Optional[dict] = None
  25. def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
  26. self.tenant_id = tenant_id
  27. self.config = config
  28. @classmethod
  29. def scan_extensions(cls):
  30. extensions = []
  31. position_map: dict[str, int] = {}
  32. # Get the package name from the module path
  33. package_name = ".".join(cls.__module__.split(".")[:-1])
  34. try:
  35. # Get package directory path
  36. package_spec = importlib.util.find_spec(package_name)
  37. if not package_spec or not package_spec.origin:
  38. raise ImportError(f"Could not find package {package_name}")
  39. package_dir = os.path.dirname(package_spec.origin)
  40. # Traverse subdirectories
  41. for subdir_name in os.listdir(package_dir):
  42. if subdir_name.startswith("__"):
  43. continue
  44. subdir_path = os.path.join(package_dir, subdir_name)
  45. if not os.path.isdir(subdir_path):
  46. continue
  47. extension_name = subdir_name
  48. file_names = os.listdir(subdir_path)
  49. # Check for extension module file
  50. if (extension_name + ".py") not in file_names:
  51. logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
  52. continue
  53. # Check for builtin flag and position
  54. builtin = False
  55. position = 0
  56. if "__builtin__" in file_names:
  57. builtin = True
  58. builtin_file_path = os.path.join(subdir_path, "__builtin__")
  59. if os.path.exists(builtin_file_path):
  60. position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
  61. position_map[extension_name] = position
  62. # Import the extension module
  63. module_name = f"{package_name}.{extension_name}.{extension_name}"
  64. spec = importlib.util.find_spec(module_name)
  65. if not spec or not spec.loader:
  66. raise ImportError(f"Failed to load module {module_name}")
  67. mod = importlib.util.module_from_spec(spec)
  68. spec.loader.exec_module(mod)
  69. # Find extension class
  70. extension_class = None
  71. for name, obj in vars(mod).items():
  72. if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
  73. extension_class = obj
  74. break
  75. if not extension_class:
  76. logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
  77. continue
  78. # Load schema if not builtin
  79. json_data: dict[str, Any] = {}
  80. if not builtin:
  81. json_path = os.path.join(subdir_path, "schema.json")
  82. if not os.path.exists(json_path):
  83. logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
  84. continue
  85. with open(json_path, encoding="utf-8") as f:
  86. json_data = json.load(f)
  87. # Create extension
  88. extensions.append(
  89. ModuleExtension(
  90. extension_class=extension_class,
  91. name=extension_name,
  92. label=json_data.get("label"),
  93. form_schema=json_data.get("form_schema"),
  94. builtin=builtin,
  95. position=position,
  96. )
  97. )
  98. except Exception as e:
  99. logging.exception("Error scanning extensions")
  100. raise
  101. # Sort extensions by position
  102. sorted_extensions = sort_to_dict_by_position_map(
  103. position_map=position_map, data=extensions, name_func=lambda x: x.name
  104. )
  105. return sorted_extensions