您最多选择25个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import logging
  2. from collections import defaultdict
  3. from collections.abc import Mapping
  4. from typing import Any, Optional, Protocol, cast
  5. from core.workflow.enums import NodeType
  6. from core.workflow.nodes.base.node import Node
  7. from .edge import Edge
  8. logger = logging.getLogger(__name__)
  9. class NodeFactory(Protocol):
  10. """
  11. Protocol for creating Node instances from node data dictionaries.
  12. This protocol decouples the Graph class from specific node mapping implementations,
  13. allowing for different node creation strategies while maintaining type safety.
  14. """
  15. def create_node(self, node_config: dict[str, Any]) -> Node:
  16. """
  17. Create a Node instance from node configuration data.
  18. :param node_config: node configuration dictionary containing type and other data
  19. :return: initialized Node instance
  20. :raises ValueError: if node type is unknown or configuration is invalid
  21. """
  22. ...
  23. class Graph:
  24. """Graph representation with nodes and edges for workflow execution."""
  25. def __init__(
  26. self,
  27. *,
  28. nodes: Optional[dict[str, Node]] = None,
  29. edges: Optional[dict[str, Edge]] = None,
  30. in_edges: Optional[dict[str, list[str]]] = None,
  31. out_edges: Optional[dict[str, list[str]]] = None,
  32. root_node: Node,
  33. ):
  34. """
  35. Initialize Graph instance.
  36. :param nodes: graph nodes mapping (node id: node object)
  37. :param edges: graph edges mapping (edge id: edge object)
  38. :param in_edges: incoming edges mapping (node id: list of edge ids)
  39. :param out_edges: outgoing edges mapping (node id: list of edge ids)
  40. :param root_node: root node object
  41. """
  42. self.nodes = nodes or {}
  43. self.edges = edges or {}
  44. self.in_edges = in_edges or {}
  45. self.out_edges = out_edges or {}
  46. self.root_node = root_node
  47. @classmethod
  48. def _parse_node_configs(cls, node_configs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
  49. """
  50. Parse node configurations and build a mapping of node IDs to configs.
  51. :param node_configs: list of node configuration dictionaries
  52. :return: mapping of node ID to node config
  53. """
  54. node_configs_map: dict[str, dict[str, Any]] = {}
  55. for node_config in node_configs:
  56. node_id = node_config.get("id")
  57. if not node_id:
  58. continue
  59. node_configs_map[node_id] = node_config
  60. return node_configs_map
  61. @classmethod
  62. def _find_root_node_id(
  63. cls,
  64. node_configs_map: dict[str, dict[str, Any]],
  65. edge_configs: list[dict[str, Any]],
  66. root_node_id: Optional[str] = None,
  67. ) -> str:
  68. """
  69. Find the root node ID if not specified.
  70. :param node_configs_map: mapping of node ID to node config
  71. :param edge_configs: list of edge configurations
  72. :param root_node_id: explicitly specified root node ID
  73. :return: determined root node ID
  74. """
  75. if root_node_id:
  76. if root_node_id not in node_configs_map:
  77. raise ValueError(f"Root node id {root_node_id} not found in the graph")
  78. return root_node_id
  79. # Find nodes with no incoming edges
  80. nodes_with_incoming = set()
  81. for edge_config in edge_configs:
  82. target = edge_config.get("target")
  83. if target:
  84. nodes_with_incoming.add(target)
  85. root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
  86. # Prefer START node if available
  87. start_node_id = None
  88. for nid in root_candidates:
  89. node_data = node_configs_map[nid].get("data", {})
  90. if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]:
  91. start_node_id = nid
  92. break
  93. root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
  94. if not root_node_id:
  95. raise ValueError("Unable to determine root node ID")
  96. return root_node_id
  97. @classmethod
  98. def _build_edges(
  99. cls, edge_configs: list[dict[str, Any]]
  100. ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
  101. """
  102. Build edge objects and mappings from edge configurations.
  103. :param edge_configs: list of edge configurations
  104. :return: tuple of (edges dict, in_edges dict, out_edges dict)
  105. """
  106. edges: dict[str, Edge] = {}
  107. in_edges: dict[str, list[str]] = defaultdict(list)
  108. out_edges: dict[str, list[str]] = defaultdict(list)
  109. edge_counter = 0
  110. for edge_config in edge_configs:
  111. source = edge_config.get("source")
  112. target = edge_config.get("target")
  113. if not source or not target:
  114. continue
  115. # Create edge
  116. edge_id = f"edge_{edge_counter}"
  117. edge_counter += 1
  118. source_handle = edge_config.get("sourceHandle", "source")
  119. edge = Edge(
  120. id=edge_id,
  121. tail=source,
  122. head=target,
  123. source_handle=source_handle,
  124. )
  125. edges[edge_id] = edge
  126. out_edges[source].append(edge_id)
  127. in_edges[target].append(edge_id)
  128. return edges, dict(in_edges), dict(out_edges)
  129. @classmethod
  130. def _create_node_instances(
  131. cls,
  132. node_configs_map: dict[str, dict[str, Any]],
  133. node_factory: "NodeFactory",
  134. ) -> dict[str, Node]:
  135. """
  136. Create node instances from configurations using the node factory.
  137. :param node_configs_map: mapping of node ID to node config
  138. :param node_factory: factory for creating node instances
  139. :return: mapping of node ID to node instance
  140. """
  141. nodes: dict[str, Node] = {}
  142. for node_id, node_config in node_configs_map.items():
  143. try:
  144. node_instance = node_factory.create_node(node_config)
  145. except ValueError as e:
  146. logger.warning("Failed to create node instance: %s", str(e))
  147. continue
  148. nodes[node_id] = node_instance
  149. return nodes
  150. @classmethod
  151. def init(
  152. cls,
  153. *,
  154. graph_config: Mapping[str, Any],
  155. node_factory: "NodeFactory",
  156. root_node_id: Optional[str] = None,
  157. ) -> "Graph":
  158. """
  159. Initialize graph
  160. :param graph_config: graph config containing nodes and edges
  161. :param node_factory: factory for creating node instances from config data
  162. :param root_node_id: root node id
  163. :return: graph instance
  164. """
  165. # Parse configs
  166. edge_configs = graph_config.get("edges", [])
  167. node_configs = graph_config.get("nodes", [])
  168. if not node_configs:
  169. raise ValueError("Graph must have at least one node")
  170. edge_configs = cast(list, edge_configs)
  171. node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
  172. # Parse node configurations
  173. node_configs_map = cls._parse_node_configs(node_configs)
  174. # Find root node
  175. root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
  176. # Build edges
  177. edges, in_edges, out_edges = cls._build_edges(edge_configs)
  178. # Create node instances
  179. nodes = cls._create_node_instances(node_configs_map, node_factory)
  180. # Get root node instance
  181. root_node = nodes[root_node_id]
  182. # Create and return the graph
  183. return cls(
  184. nodes=nodes,
  185. edges=edges,
  186. in_edges=in_edges,
  187. out_edges=out_edges,
  188. root_node=root_node,
  189. )
  190. @property
  191. def node_ids(self) -> list[str]:
  192. """
  193. Get list of node IDs (compatibility property for existing code)
  194. :return: list of node IDs
  195. """
  196. return list(self.nodes.keys())
  197. def get_outgoing_edges(self, node_id: str) -> list[Edge]:
  198. """
  199. Get all outgoing edges from a node (V2 method)
  200. :param node_id: node id
  201. :return: list of outgoing edges
  202. """
  203. edge_ids = self.out_edges.get(node_id, [])
  204. return [self.edges[eid] for eid in edge_ids if eid in self.edges]
  205. def get_incoming_edges(self, node_id: str) -> list[Edge]:
  206. """
  207. Get all incoming edges to a node (V2 method)
  208. :param node_id: node id
  209. :return: list of incoming edges
  210. """
  211. edge_ids = self.in_edges.get(node_id, [])
  212. return [self.edges[eid] for eid in edge_ids if eid in self.edges]