You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

resolver.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. import logging
  2. import re
  3. import threading
  4. from collections import deque
  5. from dataclasses import dataclass
  6. from typing import Any, Optional, Union
  7. from core.schemas.registry import SchemaRegistry
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type aliases for better clarity
# SchemaType: any JSON-like value that may be handed to the resolver.
SchemaType = Union[dict[str, Any], list[Any], str, int, float, bool, None]
# SchemaDict: a schema object represented as a string-keyed dictionary.
SchemaDict = dict[str, Any]
# Pre-compiled pattern for better performance
# Matches "https://dify.ai/schemas/<version>/<name>.json"; group 1 captures the
# version (the letter "v" followed by digits) and group 2 captures everything
# between the version segment and the ".json" suffix.
_DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$")
  14. class SchemaResolutionError(Exception):
  15. """Base exception for schema resolution errors"""
  16. pass
  17. class CircularReferenceError(SchemaResolutionError):
  18. """Raised when a circular reference is detected"""
  19. def __init__(self, ref_uri: str, ref_path: list[str]):
  20. self.ref_uri = ref_uri
  21. self.ref_path = ref_path
  22. super().__init__(f"Circular reference detected: {ref_uri} in path {' -> '.join(ref_path)}")
  23. class MaxDepthExceededError(SchemaResolutionError):
  24. """Raised when maximum resolution depth is exceeded"""
  25. def __init__(self, max_depth: int):
  26. self.max_depth = max_depth
  27. super().__init__(f"Maximum resolution depth ({max_depth}) exceeded")
  28. class SchemaNotFoundError(SchemaResolutionError):
  29. """Raised when a referenced schema cannot be found"""
  30. def __init__(self, ref_uri: str):
  31. self.ref_uri = ref_uri
  32. super().__init__(f"Schema not found: {ref_uri}")
  33. @dataclass
  34. class QueueItem:
  35. """Represents an item in the BFS queue"""
  36. current: Any
  37. parent: Optional[Any]
  38. key: Optional[Union[str, int]]
  39. depth: int
  40. ref_path: set[str]
  41. class SchemaResolver:
  42. """Resolver for Dify schema references with caching and optimizations"""
  43. _cache: dict[str, SchemaDict] = {}
  44. _cache_lock = threading.Lock()
  45. def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
  46. """
  47. Initialize the schema resolver
  48. Args:
  49. registry: Schema registry to use (defaults to default registry)
  50. max_depth: Maximum depth for reference resolution
  51. """
  52. self.registry = registry or SchemaRegistry.default_registry()
  53. self.max_depth = max_depth
  54. @classmethod
  55. def clear_cache(cls) -> None:
  56. """Clear the global schema cache"""
  57. with cls._cache_lock:
  58. cls._cache.clear()
  59. def resolve(self, schema: SchemaType) -> SchemaType:
  60. """
  61. Resolve all $ref references in the schema
  62. Performance optimization: quickly checks for $ref presence before processing.
  63. Args:
  64. schema: Schema to resolve
  65. Returns:
  66. Resolved schema with all references expanded
  67. Raises:
  68. CircularReferenceError: If circular reference detected
  69. MaxDepthExceededError: If max depth exceeded
  70. SchemaNotFoundError: If referenced schema not found
  71. """
  72. if not isinstance(schema, (dict, list)):
  73. return schema
  74. # Fast path: if no Dify refs found, return original schema unchanged
  75. # This avoids expensive deepcopy and BFS traversal for schemas without refs
  76. if not _has_dify_refs(schema):
  77. return schema
  78. # Slow path: schema contains refs, perform full resolution
  79. import copy
  80. result = copy.deepcopy(schema)
  81. # Initialize BFS queue
  82. queue = deque([QueueItem(
  83. current=result,
  84. parent=None,
  85. key=None,
  86. depth=0,
  87. ref_path=set()
  88. )])
  89. while queue:
  90. item = queue.popleft()
  91. # Process the current item
  92. self._process_queue_item(queue, item)
  93. return result
  94. def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
  95. """Process a single queue item"""
  96. if isinstance(item.current, dict):
  97. self._process_dict(queue, item)
  98. elif isinstance(item.current, list):
  99. self._process_list(queue, item)
  100. def _process_dict(self, queue: deque, item: QueueItem) -> None:
  101. """Process a dictionary item"""
  102. ref_uri = item.current.get("$ref")
  103. if ref_uri and _is_dify_schema_ref(ref_uri):
  104. # Handle $ref resolution
  105. self._resolve_ref(queue, item, ref_uri)
  106. else:
  107. # Process nested items
  108. for key, value in item.current.items():
  109. if isinstance(value, (dict, list)):
  110. next_depth = item.depth + 1
  111. if next_depth >= self.max_depth:
  112. raise MaxDepthExceededError(self.max_depth)
  113. queue.append(QueueItem(
  114. current=value,
  115. parent=item.current,
  116. key=key,
  117. depth=next_depth,
  118. ref_path=item.ref_path
  119. ))
  120. def _process_list(self, queue: deque, item: QueueItem) -> None:
  121. """Process a list item"""
  122. for idx, value in enumerate(item.current):
  123. if isinstance(value, (dict, list)):
  124. next_depth = item.depth + 1
  125. if next_depth >= self.max_depth:
  126. raise MaxDepthExceededError(self.max_depth)
  127. queue.append(QueueItem(
  128. current=value,
  129. parent=item.current,
  130. key=idx,
  131. depth=next_depth,
  132. ref_path=item.ref_path
  133. ))
  134. def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
  135. """Resolve a $ref reference"""
  136. # Check for circular reference
  137. if ref_uri in item.ref_path:
  138. # Mark as circular and skip
  139. item.current["$circular_ref"] = True
  140. logger.warning("Circular reference detected: %s", ref_uri)
  141. return
  142. # Get resolved schema (from cache or registry)
  143. resolved_schema = self._get_resolved_schema(ref_uri)
  144. if not resolved_schema:
  145. logger.warning("Schema not found: %s", ref_uri)
  146. return
  147. # Update ref path
  148. new_ref_path = item.ref_path | {ref_uri}
  149. # Replace the reference with resolved schema
  150. next_depth = item.depth + 1
  151. if next_depth >= self.max_depth:
  152. raise MaxDepthExceededError(self.max_depth)
  153. if item.parent is None:
  154. # Root level replacement
  155. item.current.clear()
  156. item.current.update(resolved_schema)
  157. queue.append(QueueItem(
  158. current=item.current,
  159. parent=None,
  160. key=None,
  161. depth=next_depth,
  162. ref_path=new_ref_path
  163. ))
  164. else:
  165. # Update parent container
  166. item.parent[item.key] = resolved_schema.copy()
  167. queue.append(QueueItem(
  168. current=item.parent[item.key],
  169. parent=item.parent,
  170. key=item.key,
  171. depth=next_depth,
  172. ref_path=new_ref_path
  173. ))
  174. def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
  175. """Get resolved schema from cache or registry"""
  176. # Check cache first
  177. with self._cache_lock:
  178. if ref_uri in self._cache:
  179. return self._cache[ref_uri].copy()
  180. # Fetch from registry
  181. schema = self.registry.get_schema(ref_uri)
  182. if not schema:
  183. return None
  184. # Clean and cache
  185. cleaned = _remove_metadata_fields(schema)
  186. with self._cache_lock:
  187. self._cache[ref_uri] = cleaned
  188. return cleaned.copy()
  189. def resolve_dify_schema_refs(
  190. schema: SchemaType,
  191. registry: Optional[SchemaRegistry] = None,
  192. max_depth: int = 30
  193. ) -> SchemaType:
  194. """
  195. Resolve $ref references in Dify schema to actual schema content
  196. This is a convenience function that creates a resolver and resolves the schema.
  197. Performance optimization: quickly checks for $ref presence before processing.
  198. Args:
  199. schema: Schema object that may contain $ref references
  200. registry: Optional schema registry, defaults to default registry
  201. max_depth: Maximum depth to prevent infinite loops (default: 30)
  202. Returns:
  203. Schema with all $ref references resolved to actual content
  204. Raises:
  205. CircularReferenceError: If circular reference detected
  206. MaxDepthExceededError: If maximum depth exceeded
  207. SchemaNotFoundError: If referenced schema not found
  208. """
  209. # Fast path: if no Dify refs found, return original schema unchanged
  210. # This avoids expensive deepcopy and BFS traversal for schemas without refs
  211. if not _has_dify_refs(schema):
  212. return schema
  213. # Slow path: schema contains refs, perform full resolution
  214. resolver = SchemaResolver(registry, max_depth)
  215. return resolver.resolve(schema)
  216. def _remove_metadata_fields(schema: dict) -> dict:
  217. """
  218. Remove metadata fields from schema that shouldn't be included in resolved output
  219. Args:
  220. schema: Schema dictionary
  221. Returns:
  222. Cleaned schema without metadata fields
  223. """
  224. # Create a copy and remove metadata fields
  225. cleaned = schema.copy()
  226. metadata_fields = ["$id", "$schema", "version"]
  227. for field in metadata_fields:
  228. cleaned.pop(field, None)
  229. return cleaned
  230. def _is_dify_schema_ref(ref_uri: Any) -> bool:
  231. """
  232. Check if the reference URI is a Dify schema reference
  233. Args:
  234. ref_uri: URI to check
  235. Returns:
  236. True if it's a Dify schema reference
  237. """
  238. if not isinstance(ref_uri, str):
  239. return False
  240. # Use pre-compiled pattern for better performance
  241. return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri))
  242. def _has_dify_refs_recursive(schema: SchemaType) -> bool:
  243. """
  244. Recursively check if a schema contains any Dify $ref references
  245. This is the fallback method when string-based detection is not possible.
  246. Args:
  247. schema: Schema to check for references
  248. Returns:
  249. True if any Dify $ref is found, False otherwise
  250. """
  251. if isinstance(schema, dict):
  252. # Check if this dict has a $ref field
  253. ref_uri = schema.get("$ref")
  254. if ref_uri and _is_dify_schema_ref(ref_uri):
  255. return True
  256. # Check nested values
  257. for value in schema.values():
  258. if _has_dify_refs_recursive(value):
  259. return True
  260. elif isinstance(schema, list):
  261. # Check each item in the list
  262. for item in schema:
  263. if _has_dify_refs_recursive(item):
  264. return True
  265. # Primitive types don't contain refs
  266. return False
  267. def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
  268. """
  269. Hybrid detection: fast string scan followed by precise recursive check
  270. Performance optimization using two-phase detection:
  271. 1. Fast string scan to quickly eliminate schemas without $ref
  272. 2. Precise recursive validation only for potential candidates
  273. Args:
  274. schema: Schema to check for references
  275. Returns:
  276. True if any Dify $ref is found, False otherwise
  277. """
  278. # Phase 1: Fast string-based pre-filtering
  279. try:
  280. import json
  281. schema_str = json.dumps(schema, separators=(',', ':'))
  282. # Quick elimination: no $ref at all
  283. if '"$ref"' not in schema_str:
  284. return False
  285. # Quick elimination: no Dify schema URLs
  286. if 'https://dify.ai/schemas/' not in schema_str:
  287. return False
  288. except (TypeError, ValueError, OverflowError):
  289. # JSON serialization failed (e.g., circular references, non-serializable objects)
  290. # Fall back to recursive detection
  291. logger.debug("JSON serialization failed for schema, using recursive detection")
  292. return _has_dify_refs_recursive(schema)
  293. # Phase 2: Precise recursive validation
  294. # Only executed for schemas that passed string pre-filtering
  295. return _has_dify_refs_recursive(schema)
  296. def _has_dify_refs(schema: SchemaType) -> bool:
  297. """
  298. Check if a schema contains any Dify $ref references
  299. Uses hybrid detection for optimal performance:
  300. - Fast string scan for quick elimination
  301. - Precise recursive check for validation
  302. Args:
  303. schema: Schema to check for references
  304. Returns:
  305. True if any Dify $ref is found, False otherwise
  306. """
  307. return _has_dify_refs_hybrid(schema)
  308. def parse_dify_schema_uri(uri: str) -> tuple[str, str]:
  309. """
  310. Parse a Dify schema URI to extract version and schema name
  311. Args:
  312. uri: Schema URI to parse
  313. Returns:
  314. Tuple of (version, schema_name) or ("", "") if invalid
  315. """
  316. match = _DIFY_SCHEMA_PATTERN.match(uri)
  317. if not match:
  318. return "", ""
  319. return match.group(1), match.group(2)