You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

registry.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import json
  2. import logging
  3. import threading
  4. from collections.abc import Mapping, MutableMapping
  5. from pathlib import Path
  6. from typing import Any, ClassVar, Optional
  7. class SchemaRegistry:
  8. """Schema registry manages JSON schemas with version support"""
  9. logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
  10. _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
  11. _lock: ClassVar[threading.Lock] = threading.Lock()
  12. def __init__(self, base_dir: str):
  13. self.base_dir = Path(base_dir)
  14. self.versions: MutableMapping[str, MutableMapping[str, Any]] = {}
  15. self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
  16. @classmethod
  17. def default_registry(cls) -> "SchemaRegistry":
  18. """Returns the default schema registry for builtin schemas (thread-safe singleton)"""
  19. if cls._default_instance is None:
  20. with cls._lock:
  21. # Double-checked locking pattern
  22. if cls._default_instance is None:
  23. current_dir = Path(__file__).parent
  24. schema_dir = current_dir / "builtin" / "schemas"
  25. registry = cls(str(schema_dir))
  26. registry.load_all_versions()
  27. cls._default_instance = registry
  28. return cls._default_instance
  29. def load_all_versions(self) -> None:
  30. """Scans the schema directory and loads all versions"""
  31. if not self.base_dir.exists():
  32. return
  33. for entry in self.base_dir.iterdir():
  34. if not entry.is_dir():
  35. continue
  36. version = entry.name
  37. if not version.startswith("v"):
  38. continue
  39. self._load_version_dir(version, entry)
  40. def _load_version_dir(self, version: str, version_dir: Path) -> None:
  41. """Loads all schemas in a version directory"""
  42. if not version_dir.exists():
  43. return
  44. if version not in self.versions:
  45. self.versions[version] = {}
  46. for entry in version_dir.iterdir():
  47. if entry.suffix != ".json":
  48. continue
  49. schema_name = entry.stem
  50. self._load_schema(version, schema_name, entry)
  51. def _load_schema(self, version: str, schema_name: str, schema_path: Path) -> None:
  52. """Loads a single schema file"""
  53. try:
  54. with open(schema_path, encoding="utf-8") as f:
  55. schema = json.load(f)
  56. # Store the schema
  57. self.versions[version][schema_name] = schema
  58. # Extract and store metadata
  59. uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
  60. metadata = {
  61. "version": version,
  62. "title": schema.get("title", ""),
  63. "description": schema.get("description", ""),
  64. "deprecated": schema.get("deprecated", False),
  65. }
  66. self.metadata[uri] = metadata
  67. except (OSError, json.JSONDecodeError) as e:
  68. self.logger.warning("Failed to load schema %s/%s: %s", version, schema_name, e)
  69. def get_schema(self, uri: str) -> Any | None:
  70. """Retrieves a schema by URI with version support"""
  71. version, schema_name = self._parse_uri(uri)
  72. if not version or not schema_name:
  73. return None
  74. version_schemas = self.versions.get(version)
  75. if not version_schemas:
  76. return None
  77. return version_schemas.get(schema_name)
  78. def _parse_uri(self, uri: str) -> tuple[str, str]:
  79. """Parses a schema URI to extract version and schema name"""
  80. from core.schemas.resolver import parse_dify_schema_uri
  81. return parse_dify_schema_uri(uri)
  82. def list_versions(self) -> list[str]:
  83. """Returns all available versions"""
  84. return sorted(self.versions.keys())
  85. def list_schemas(self, version: str) -> list[str]:
  86. """Returns all schemas in a specific version"""
  87. version_schemas = self.versions.get(version)
  88. if not version_schemas:
  89. return []
  90. return sorted(version_schemas.keys())
  91. def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]:
  92. """Returns all schemas for a version in the API format"""
  93. version_schemas = self.versions.get(version, {})
  94. result: list[Mapping[str, Any]] = []
  95. for schema_name, schema in version_schemas.items():
  96. result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema})
  97. return result