You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

extensible.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import enum
  2. import importlib.util
  3. import json
  4. import logging
  5. import os
  6. from pathlib import Path
  7. from typing import Any, Optional
  8. from pydantic import BaseModel
  9. from core.helper.position_helper import sort_to_dict_by_position_map
  10. class ExtensionModule(enum.Enum):
  11. MODERATION = "moderation"
  12. EXTERNAL_DATA_TOOL = "external_data_tool"
  13. class ModuleExtension(BaseModel):
  14. extension_class: Any = None
  15. name: str
  16. label: Optional[dict] = None
  17. form_schema: Optional[list] = None
  18. builtin: bool = True
  19. position: Optional[int] = None
  20. class Extensible:
  21. module: ExtensionModule
  22. name: str
  23. tenant_id: str
  24. config: Optional[dict] = None
  25. def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
  26. self.tenant_id = tenant_id
  27. self.config = config
  28. @classmethod
  29. def scan_extensions(cls):
  30. extensions = []
  31. position_map: dict[str, int] = {}
  32. # Get the package name from the module path
  33. package_name = ".".join(cls.__module__.split(".")[:-1])
  34. try:
  35. # Get package directory path
  36. package_spec = importlib.util.find_spec(package_name)
  37. if not package_spec or not package_spec.origin:
  38. raise ImportError(f"Could not find package {package_name}")
  39. package_dir = os.path.dirname(package_spec.origin)
  40. # Traverse subdirectories
  41. for subdir_name in os.listdir(package_dir):
  42. if subdir_name.startswith("__"):
  43. continue
  44. subdir_path = os.path.join(package_dir, subdir_name)
  45. if not os.path.isdir(subdir_path):
  46. continue
  47. extension_name = subdir_name
  48. file_names = os.listdir(subdir_path)
  49. # Check for extension module file
  50. if (extension_name + ".py") not in file_names:
  51. logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
  52. continue
  53. # Check for builtin flag and position
  54. builtin = False
  55. position = 0
  56. if "__builtin__" in file_names:
  57. builtin = True
  58. builtin_file_path = os.path.join(subdir_path, "__builtin__")
  59. if os.path.exists(builtin_file_path):
  60. position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
  61. position_map[extension_name] = position
  62. # Import the extension module
  63. module_name = f"{package_name}.{extension_name}.{extension_name}"
  64. spec = importlib.util.find_spec(module_name)
  65. if not spec or not spec.loader:
  66. raise ImportError(f"Failed to load module {module_name}")
  67. mod = importlib.util.module_from_spec(spec)
  68. spec.loader.exec_module(mod)
  69. # Find extension class
  70. extension_class = None
  71. for name, obj in vars(mod).items():
  72. if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
  73. extension_class = obj
  74. break
  75. if not extension_class:
  76. logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
  77. continue
  78. # Load schema if not builtin
  79. json_data: dict[str, Any] = {}
  80. if not builtin:
  81. json_path = os.path.join(subdir_path, "schema.json")
  82. if not os.path.exists(json_path):
  83. logging.warning("Missing schema.json file in %s, Skip.", subdir_path)
  84. continue
  85. with open(json_path, encoding="utf-8") as f:
  86. json_data = json.load(f)
  87. # Create extension
  88. extensions.append(
  89. ModuleExtension(
  90. extension_class=extension_class,
  91. name=extension_name,
  92. label=json_data.get("label"),
  93. form_schema=json_data.get("form_schema"),
  94. builtin=builtin,
  95. position=position,
  96. )
  97. )
  98. except Exception as e:
  99. logging.exception("Error scanning extensions")
  100. raise
  101. # Sort extensions by position
  102. sorted_extensions = sort_to_dict_by_position_map(
  103. position_map=position_map, data=extensions, name_func=lambda x: x.name
  104. )
  105. return sorted_extensions