|
|
|
@@ -204,47 +204,6 @@ class Graph(BaseModel): |
|
|
|
|
|
|
|
return graph |
|
|
|
|
|
|
|
def add_extra_edge( |
|
|
|
self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Add extra edge to the graph |
|
|
|
|
|
|
|
:param source_node_id: source node id |
|
|
|
:param target_node_id: target node id |
|
|
|
:param run_condition: run condition |
|
|
|
""" |
|
|
|
if source_node_id not in self.node_ids or target_node_id not in self.node_ids: |
|
|
|
return |
|
|
|
|
|
|
|
if source_node_id not in self.edge_mapping: |
|
|
|
self.edge_mapping[source_node_id] = [] |
|
|
|
|
|
|
|
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: |
|
|
|
return |
|
|
|
|
|
|
|
graph_edge = GraphEdge( |
|
|
|
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition |
|
|
|
) |
|
|
|
|
|
|
|
self.edge_mapping[source_node_id].append(graph_edge) |
|
|
|
|
|
|
|
def get_leaf_node_ids(self) -> list[str]: |
|
|
|
""" |
|
|
|
Get leaf node ids of the graph |
|
|
|
|
|
|
|
:return: leaf node ids |
|
|
|
""" |
|
|
|
leaf_node_ids = [] |
|
|
|
for node_id in self.node_ids: |
|
|
|
if node_id not in self.edge_mapping or ( |
|
|
|
len(self.edge_mapping[node_id]) == 1 |
|
|
|
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id |
|
|
|
): |
|
|
|
leaf_node_ids.append(node_id) |
|
|
|
|
|
|
|
return leaf_node_ids |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def _recursively_add_node_ids( |
|
|
|
cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str |
|
|
|
@@ -681,11 +640,8 @@ class Graph(BaseModel): |
|
|
|
if start_node_id not in reverse_edge_mapping: |
|
|
|
return False |
|
|
|
|
|
|
|
all_routes_node_ids = set() |
|
|
|
parallel_start_node_ids: dict[str, list[str]] = {} |
|
|
|
for branch_node_id, node_ids in routes_node_ids.items(): |
|
|
|
all_routes_node_ids.update(node_ids) |
|
|
|
|
|
|
|
for branch_node_id in routes_node_ids: |
|
|
|
if branch_node_id in reverse_edge_mapping: |
|
|
|
for graph_edge in reverse_edge_mapping[branch_node_id]: |
|
|
|
if graph_edge.source_node_id not in parallel_start_node_ids: |
|
|
|
@@ -693,8 +649,9 @@ class Graph(BaseModel): |
|
|
|
|
|
|
|
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) |
|
|
|
|
|
|
|
expected_branch_set = set(routes_node_ids.keys()) |
|
|
|
for _, branch_node_ids in parallel_start_node_ids.items(): |
|
|
|
if set(branch_node_ids) == set(routes_node_ids.keys()): |
|
|
|
if set(branch_node_ids) == expected_branch_set: |
|
|
|
return True |
|
|
|
|
|
|
|
return False |