您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

registry.py 4.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
import json
import logging
import re
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Any, ClassVar, Optional
  6. class SchemaRegistry:
  7. """Schema registry manages JSON schemas with version support"""
  8. _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
  9. _lock: ClassVar[threading.Lock] = threading.Lock()
  10. def __init__(self, base_dir: str):
  11. self.base_dir = Path(base_dir)
  12. self.versions: MutableMapping[str, MutableMapping[str, Any]] = {}
  13. self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
  14. @classmethod
  15. def default_registry(cls) -> "SchemaRegistry":
  16. """Returns the default schema registry for builtin schemas (thread-safe singleton)"""
  17. if cls._default_instance is None:
  18. with cls._lock:
  19. # Double-checked locking pattern
  20. if cls._default_instance is None:
  21. current_dir = Path(__file__).parent
  22. schema_dir = current_dir / "builtin" / "schemas"
  23. registry = cls(str(schema_dir))
  24. registry.load_all_versions()
  25. cls._default_instance = registry
  26. return cls._default_instance
  27. def load_all_versions(self) -> None:
  28. """Scans the schema directory and loads all versions"""
  29. if not self.base_dir.exists():
  30. return
  31. for entry in self.base_dir.iterdir():
  32. if not entry.is_dir():
  33. continue
  34. version = entry.name
  35. if not version.startswith("v"):
  36. continue
  37. self._load_version_dir(version, entry)
  38. def _load_version_dir(self, version: str, version_dir: Path) -> None:
  39. """Loads all schemas in a version directory"""
  40. if not version_dir.exists():
  41. return
  42. if version not in self.versions:
  43. self.versions[version] = {}
  44. for entry in version_dir.iterdir():
  45. if entry.suffix != ".json":
  46. continue
  47. schema_name = entry.stem
  48. self._load_schema(version, schema_name, entry)
  49. def _load_schema(self, version: str, schema_name: str, schema_path: Path) -> None:
  50. """Loads a single schema file"""
  51. try:
  52. with open(schema_path, encoding="utf-8") as f:
  53. schema = json.load(f)
  54. # Store the schema
  55. self.versions[version][schema_name] = schema
  56. # Extract and store metadata
  57. uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
  58. metadata = {
  59. "version": version,
  60. "title": schema.get("title", ""),
  61. "description": schema.get("description", ""),
  62. "deprecated": schema.get("deprecated", False),
  63. }
  64. self.metadata[uri] = metadata
  65. except (OSError, json.JSONDecodeError) as e:
  66. print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
  67. def get_schema(self, uri: str) -> Optional[Any]:
  68. """Retrieves a schema by URI with version support"""
  69. version, schema_name = self._parse_uri(uri)
  70. if not version or not schema_name:
  71. return None
  72. version_schemas = self.versions.get(version)
  73. if not version_schemas:
  74. return None
  75. return version_schemas.get(schema_name)
  76. def _parse_uri(self, uri: str) -> tuple[str, str]:
  77. """Parses a schema URI to extract version and schema name"""
  78. import re
  79. pattern = r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$"
  80. match = re.match(pattern, uri)
  81. if not match:
  82. return "", ""
  83. version = match.group(1)
  84. schema_name = match.group(2)
  85. return version, schema_name
  86. def list_versions(self) -> list[str]:
  87. """Returns all available versions"""
  88. return sorted(self.versions.keys())
  89. def list_schemas(self, version: str) -> list[str]:
  90. """Returns all schemas in a specific version"""
  91. version_schemas = self.versions.get(version)
  92. if not version_schemas:
  93. return []
  94. return sorted(version_schemas.keys())
  95. def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]:
  96. """Returns all schemas for a version in the API format"""
  97. version_schemas = self.versions.get(version, {})
  98. result = []
  99. for schema_name, schema in version_schemas.items():
  100. result.append({
  101. "name": schema_name,
  102. "schema": schema
  103. })
  104. return result