您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. import logging
  2. from collections import defaultdict
  3. from collections.abc import Mapping, Sequence
  4. from typing import Protocol, cast, final
  5. from core.workflow.enums import NodeExecutionType, NodeState, NodeType
  6. from core.workflow.nodes.base.node import Node
  7. from libs.typing import is_str, is_str_dict
  8. from .edge import Edge
  # Module-scoped logger, named after this module.
  logger: logging.Logger = logging.getLogger(name=__name__)
  class NodeFactory(Protocol):
      """
      Structural interface for objects that turn a raw node config into a Node.

      Graph only relies on this protocol when instantiating nodes. Different
      node-creation strategies can therefore be plugged in without tying Graph
      to a specific node mapping, while type checking still applies.
      """

      def create_node(self, node_config: dict[str, object]) -> Node:
          """
          Build a Node from one node configuration dictionary.

          :param node_config: node configuration dictionary, including its type and other data
          :return: the initialized Node instance
          :raises ValueError: when the node type is unknown or the configuration is invalid
          """
          ...
  @final
  class Graph:
      """
      Graph representation with nodes and edges for workflow execution.

      Nodes are keyed by node id. Edges are keyed by edge id and indexed per node
      through ``in_edges`` / ``out_edges``. When the graph is built with
      :meth:`init`, edge ids are synthetic (``edge_0``, ``edge_1``, ...).
      """

      def __init__(
          self,
          *,
          nodes: dict[str, Node] | None = None,
          edges: dict[str, Edge] | None = None,
          in_edges: dict[str, list[str]] | None = None,
          out_edges: dict[str, list[str]] | None = None,
          root_node: Node,
      ):
          """
          Initialize Graph instance.

          :param nodes: graph nodes mapping (node id: node object)
          :param edges: graph edges mapping (edge id: edge object)
          :param in_edges: incoming edges mapping (node id: list of edge ids)
          :param out_edges: outgoing edges mapping (node id: list of edge ids)
          :param root_node: root node object
          """
          # Omitted, None or empty mappings are replaced with fresh empty dicts.
          self.nodes = nodes or {}
          self.edges = edges or {}
          self.in_edges = in_edges or {}
          self.out_edges = out_edges or {}
          self.root_node = root_node

      @classmethod
      def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
          """
          Parse node configurations and build a mapping of node IDs to configs.

          Configs without a non-empty string ``id`` are ignored. When several
          configs share an id, the last one wins.

          :param node_configs: list of node configuration dictionaries
          :return: mapping of node ID to node config
          """
          node_configs_map: dict[str, dict[str, object]] = {}
          for node_config in node_configs:
              node_id = node_config.get("id")
              if not node_id or not isinstance(node_id, str):
                  continue
              node_configs_map[node_id] = node_config
          return node_configs_map

      @classmethod
      def _find_root_node_id(
          cls,
          node_configs_map: Mapping[str, Mapping[str, object]],
          edge_configs: Sequence[Mapping[str, object]],
          root_node_id: str | None = None,
      ) -> str:
          """
          Determine the root node ID.

          An explicitly given ``root_node_id`` is validated and returned as-is.
          Otherwise the candidates are the nodes that are not the target of any
          edge. A START or DATASOURCE candidate is preferred; failing that, the
          first candidate is used.

          :param node_configs_map: mapping of node ID to node config
          :param edge_configs: list of edge configurations
          :param root_node_id: explicitly specified root node ID
          :return: determined root node ID
          :raises ValueError: if the given root id is not in the graph, or no root can be determined
          """
          if root_node_id:
              if root_node_id not in node_configs_map:
                  raise ValueError(f"Root node id {root_node_id} not found in the graph")
              return root_node_id

          # Find nodes with no incoming edges
          nodes_with_incoming: set[str] = set()
          for edge_config in edge_configs:
              target = edge_config.get("target")
              if isinstance(target, str):
                  nodes_with_incoming.add(target)

          root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]

          # Prefer START node if available
          start_node_id = None
          for nid in root_candidates:
              node_data = node_configs_map[nid].get("data")
              if not is_str_dict(node_data):
                  continue
              node_type = node_data.get("type")
              if not isinstance(node_type, str):
                  continue
              if node_type in [NodeType.START, NodeType.DATASOURCE]:
                  start_node_id = nid
                  break

          root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
          if not root_node_id:
              raise ValueError("Unable to determine root node ID")
          return root_node_id

      @classmethod
      def _build_edges(
          cls, edge_configs: list[dict[str, object]]
      ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
          """
          Build edge objects and mappings from edge configurations.

          Edges whose ``source``/``target`` are not strings, or whose
          ``sourceHandle`` (default ``"source"``) is not a string, are skipped.
          Source/target ids are NOT checked against the node set here.

          :param edge_configs: list of edge configurations
          :return: tuple of (edges dict, in_edges dict, out_edges dict)
          """
          edges: dict[str, Edge] = {}
          in_edges: dict[str, list[str]] = defaultdict(list)
          out_edges: dict[str, list[str]] = defaultdict(list)
          edge_counter = 0
          for edge_config in edge_configs:
              source = edge_config.get("source")
              target = edge_config.get("target")
              if not is_str(source) or not is_str(target):
                  continue
              # Create edge
              edge_id = f"edge_{edge_counter}"
              # NOTE: the counter advances before the sourceHandle check, so an edge
              # skipped below leaves a gap in the numbering.
              edge_counter += 1
              source_handle = edge_config.get("sourceHandle", "source")
              if not is_str(source_handle):
                  continue
              edge = Edge(
                  id=edge_id,
                  tail=source,
                  head=target,
                  source_handle=source_handle,
              )
              edges[edge_id] = edge
              out_edges[source].append(edge_id)
              in_edges[target].append(edge_id)
          # Return plain dicts rather than the defaultdicts used while building.
          return edges, dict(in_edges), dict(out_edges)

      @classmethod
      def _create_node_instances(
          cls,
          node_configs_map: dict[str, dict[str, object]],
          node_factory: "NodeFactory",
      ) -> dict[str, Node]:
          """
          Create node instances from configurations using the node factory.

          Configs the factory rejects with ``ValueError`` are logged and left out.
          Edges may therefore still reference ids that are missing from the result.

          :param node_configs_map: mapping of node ID to node config
          :param node_factory: factory for creating node instances
          :return: mapping of node ID to node instance
          """
          nodes: dict[str, Node] = {}
          for node_id, node_config in node_configs_map.items():
              try:
                  node_instance = node_factory.create_node(node_config)
              except ValueError as e:
                  logger.warning("Failed to create node instance: %s", str(e))
                  continue
              nodes[node_id] = node_instance
          return nodes

      @classmethod
      def _mark_inactive_root_branches(
          cls,
          nodes: dict[str, Node],
          edges: dict[str, Edge],
          in_edges: dict[str, list[str]],
          out_edges: dict[str, list[str]],
          active_root_id: str,
      ) -> None:
          """
          Mark nodes and edges from inactive root branches as skipped.

          Algorithm:
          1. Mark inactive root nodes as skipped
          2. For skipped nodes, mark all their outgoing edges as skipped
          3. For each edge marked as skipped, check its target node:
             - If ALL incoming edges are skipped, mark the node as skipped
             - Otherwise, leave the node state unchanged

          The walk uses an explicit worklist instead of recursion. Long chains
          therefore cannot hit Python's recursion limit. Each skipped node is
          expanded at most once, even when it is reachable along many paths.
          Edges that point at ids missing from ``nodes`` (for example nodes the
          factory rejected) are still marked skipped, but their target is ignored.

          :param nodes: mapping of node ID to node instance
          :param edges: mapping of edge ID to edge instance
          :param in_edges: mapping of node ID to incoming edge IDs
          :param out_edges: mapping of node ID to outgoing edge IDs
          :param active_root_id: ID of the active root node
          """
          # Root candidates are selected by execution type only; incoming edges are not inspected.
          top_level_roots: list[str] = [
              node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
          ]

          # If there's only one root or the active root is not a top-level root, no marking needed
          if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
              return

          # Mark inactive root nodes as skipped
          inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
          for root_id in inactive_roots:
              if root_id in nodes:
                  nodes[root_id].state = NodeState.SKIPPED

          # Propagate the skipped state downstream.
          expanded: set[str] = set()
          pending: list[str] = list(inactive_roots)
          while pending:
              node_id = pending.pop()
              if node_id in expanded:
                  continue
              node = nodes.get(node_id)
              # A node that is not (yet) skipped is not expanded. It is pushed again
              # if a later edge completes its set of skipped incoming edges.
              if node is None or node.state != NodeState.SKIPPED:
                  continue
              expanded.add(node_id)

              for edge_id in out_edges.get(node_id, []):
                  edge = edges.get(edge_id)
                  if edge is None:
                      continue
                  edge.state = NodeState.SKIPPED

                  target_node = nodes.get(edge.head)
                  if target_node is None:
                      # Edge points at a node that was never instantiated.
                      continue

                  in_edge_ids = in_edges.get(target_node.id, [])
                  # If all incoming edges are skipped, mark the node as skipped
                  if all(edges[eid].state == NodeState.SKIPPED for eid in in_edge_ids if eid in edges):
                      target_node.state = NodeState.SKIPPED
                  pending.append(target_node.id)

      @classmethod
      def init(
          cls,
          *,
          graph_config: Mapping[str, object],
          node_factory: "NodeFactory",
          root_node_id: str | None = None,
      ) -> "Graph":
          """
          Initialize graph

          :param graph_config: graph config containing nodes and edges
          :param node_factory: factory for creating node instances from config data
          :param root_node_id: root node id
          :return: graph instance
          :raises ValueError: if the config has no nodes, the root cannot be
              determined, or the root node could not be instantiated
          """
          # Parse configs; an explicit None for "edges"/"nodes" is treated as empty.
          edge_configs = cast(list[dict[str, object]], graph_config.get("edges") or [])
          node_configs = cast(list[dict[str, object]], graph_config.get("nodes") or [])

          if not node_configs:
              raise ValueError("Graph must have at least one node")

          # Drop "custom-note" entries before parsing.
          node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]

          # Parse node configurations
          node_configs_map = cls._parse_node_configs(node_configs)

          # Find root node
          root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)

          # Build edges
          edges, in_edges, out_edges = cls._build_edges(edge_configs)

          # Create node instances
          nodes = cls._create_node_instances(node_configs_map, node_factory)

          # Get root node instance; its config may exist while the factory rejected it.
          root_node = nodes.get(root_node_id)
          if root_node is None:
              raise ValueError(f"Root node {root_node_id} could not be instantiated")

          # Mark inactive root branches as skipped
          cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)

          # Create and return the graph
          return cls(
              nodes=nodes,
              edges=edges,
              in_edges=in_edges,
              out_edges=out_edges,
              root_node=root_node,
          )

      @property
      def node_ids(self) -> list[str]:
          """
          Get list of node IDs (compatibility property for existing code)

          :return: list of node IDs
          """
          return list(self.nodes.keys())

      def get_outgoing_edges(self, node_id: str) -> list[Edge]:
          """
          Get all outgoing edges from a node (V2 method)

          Edge ids with no entry in ``self.edges`` are silently dropped.

          :param node_id: node id
          :return: list of outgoing edges
          """
          edge_ids = self.out_edges.get(node_id, [])
          return [self.edges[eid] for eid in edge_ids if eid in self.edges]

      def get_incoming_edges(self, node_id: str) -> list[Edge]:
          """
          Get all incoming edges to a node (V2 method)

          Edge ids with no entry in ``self.edges`` are silently dropped.

          :param node_id: node id
          :return: list of incoming edges
          """
          edge_ids = self.in_edges.get(node_id, [])
          return [self.edges[eid] for eid in edge_ids if eid in self.edges]