### What problem does this PR solve? Feat: Delete the agent and tool nodes downstream of the agent node #3221 ### Type of change - [x] New Feature (non-breaking change which adds functionality)tags/v0.20.0
| @@ -54,13 +54,13 @@ function InnerAgentNode({ | |||
| type="target" | |||
| position={Position.Top} | |||
| isConnectable={false} | |||
| id="f" | |||
| id={NodeHandleId.AgentTop} | |||
| ></Handle> | |||
| <Handle | |||
| type="source" | |||
| position={Position.Bottom} | |||
| isConnectable={false} | |||
| id="e" | |||
| id={NodeHandleId.AgentBottom} | |||
| style={{ left: 180 }} | |||
| ></Handle> | |||
| <Handle | |||
| @@ -3054,4 +3054,6 @@ export enum NodeHandleId { | |||
| Start = 'start', | |||
| End = 'end', | |||
| Tool = 'tool', | |||
| AgentTop = 'agentTop', | |||
| AgentBottom = 'agentBottom', | |||
| } | |||
| @@ -5,9 +5,9 @@ import { | |||
| } from '@/components/ui/popover'; | |||
| import { Operator } from '@/pages/agent/constant'; | |||
| import { AgentFormContext, AgentInstanceContext } from '@/pages/agent/context'; | |||
| import useGraphStore from '@/pages/agent/store'; | |||
| import { Position } from '@xyflow/react'; | |||
| import { PropsWithChildren, useCallback, useContext } from 'react'; | |||
| import { useDeleteToolNode } from '../use-delete-tool-node'; | |||
| import { useGetAgentToolNames } from '../use-get-tools'; | |||
| import { ToolCommand } from './tool-command'; | |||
| import { useUpdateAgentNodeTools } from './use-update-tools'; | |||
| @@ -17,7 +17,9 @@ export function ToolPopover({ children }: PropsWithChildren) { | |||
| const node = useContext(AgentFormContext); | |||
| const { updateNodeTools } = useUpdateAgentNodeTools(); | |||
| const { toolNames } = useGetAgentToolNames(); | |||
| const { deleteToolNode } = useDeleteToolNode(); | |||
| const deleteAgentToolNodeById = useGraphStore( | |||
| (state) => state.deleteAgentToolNodeById, | |||
| ); | |||
| const handleChange = useCallback( | |||
| (value: string[]) => { | |||
| @@ -29,11 +31,11 @@ export function ToolPopover({ children }: PropsWithChildren) { | |||
| nodeId: node?.id, | |||
| })(); | |||
| } else { | |||
| deleteToolNode(node.id); // TODO: The tool node should be derived from the agent tools data | |||
| deleteAgentToolNodeById(node.id); // TODO: The tool node should be derived from the agent tools data | |||
| } | |||
| } | |||
| }, | |||
| [addCanvasNode, deleteToolNode, node?.id, updateNodeTools], | |||
| [addCanvasNode, deleteAgentToolNodeById, node?.id, updateNodeTools], | |||
| ); | |||
| return ( | |||
| @@ -3,7 +3,6 @@ import { AgentFormContext } from '@/pages/agent/context'; | |||
| import useGraphStore from '@/pages/agent/store'; | |||
| import { get } from 'lodash'; | |||
| import { useCallback, useContext, useMemo } from 'react'; | |||
| import { useDeleteToolNode } from '../use-delete-tool-node'; | |||
| export function useGetNodeTools() { | |||
| const node = useContext(AgentFormContext); | |||
| @@ -48,7 +47,9 @@ export function useDeleteAgentNodeTools() { | |||
| const { updateNodeForm } = useGraphStore((state) => state); | |||
| const tools = useGetNodeTools(); | |||
| const node = useContext(AgentFormContext); | |||
| const { deleteToolNode } = useDeleteToolNode(); | |||
| const deleteAgentToolNodeById = useGraphStore( | |||
| (state) => state.deleteAgentToolNodeById, | |||
| ); | |||
| const deleteNodeTool = useCallback( | |||
| (value: string) => () => { | |||
| @@ -56,11 +57,11 @@ export function useDeleteAgentNodeTools() { | |||
| if (node?.id) { | |||
| updateNodeForm(node?.id, nextTools, ['tools']); | |||
| if (nextTools.length === 0) { | |||
| deleteToolNode(node?.id); | |||
| deleteAgentToolNodeById(node?.id); | |||
| } | |||
| } | |||
| }, | |||
| [deleteToolNode, node?.id, tools, updateNodeForm], | |||
| [deleteAgentToolNodeById, node?.id, tools, updateNodeForm], | |||
| ); | |||
| return { deleteNodeTool }; | |||
| @@ -1,24 +0,0 @@ | |||
| import { useCallback } from 'react'; | |||
| import { NodeHandleId } from '../../constant'; | |||
| import useGraphStore from '../../store'; | |||
| export function useDeleteToolNode() { | |||
| const { edges, deleteEdgeById, deleteNodeById } = useGraphStore( | |||
| (state) => state, | |||
| ); | |||
| const deleteToolNode = useCallback( | |||
| (agentNodeId: string) => { | |||
| const edge = edges.find( | |||
| (x) => x.source === agentNodeId && x.sourceHandle === NodeHandleId.Tool, | |||
| ); | |||
| if (edge) { | |||
| deleteEdgeById(edge.id); | |||
| deleteNodeById(edge.target); | |||
| } | |||
| }, | |||
| [deleteEdgeById, deleteNodeById, edges], | |||
| ); | |||
| return { deleteToolNode }; | |||
| } | |||
| @@ -315,7 +315,11 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) { | |||
| if (agentNode) { | |||
| // Calculate the coordinates of child nodes to prevent newly added child nodes from covering other child nodes | |||
| const allChildAgentNodeIds = edges | |||
| .filter((x) => x.source === nodeId && x.sourceHandle === 'e') | |||
| .filter( | |||
| (x) => | |||
| x.source === nodeId && | |||
| x.sourceHandle === NodeHandleId.AgentBottom, | |||
| ) | |||
| .map((x) => x.target); | |||
| const xAxises = nodes | |||
| @@ -334,8 +338,8 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) { | |||
| addEdge({ | |||
| source: nodeId, | |||
| target: newNode.id, | |||
| sourceHandle: 'e', | |||
| targetHandle: 'f', | |||
| sourceHandle: NodeHandleId.AgentBottom, | |||
| targetHandle: NodeHandleId.AgentTop, | |||
| }); | |||
| } | |||
| } else if (type === Operator.Tool) { | |||
| @@ -1,14 +1,18 @@ | |||
| import { RAGFlowNodeType } from '@/interfaces/database/flow'; | |||
| import { OnBeforeDelete } from '@xyflow/react'; | |||
| import { Node, OnBeforeDelete } from '@xyflow/react'; | |||
| import { Operator } from '../constant'; | |||
| import useGraphStore from '../store'; | |||
| import { deleteAllDownstreamAgentsAndTool } from '../utils/delete-node'; | |||
| const UndeletableNodes = [Operator.Begin, Operator.IterationStart]; | |||
| export function useBeforeDelete() { | |||
| const getOperatorTypeFromId = useGraphStore( | |||
| (state) => state.getOperatorTypeFromId, | |||
| ); | |||
| const { getOperatorTypeFromId, getNode } = useGraphStore((state) => state); | |||
| const agentPredicate = (node: Node) => { | |||
| return getOperatorTypeFromId(node.id) === Operator.Agent; | |||
| }; | |||
| const handleBeforeDelete: OnBeforeDelete<RAGFlowNodeType> = async ({ | |||
| nodes, // Nodes to be deleted | |||
| edges, // Edges to be deleted | |||
| @@ -47,6 +51,27 @@ export function useBeforeDelete() { | |||
| return true; | |||
| }); | |||
| // Delete the agent and tool nodes downstream of the agent node | |||
| if (nodes.some(agentPredicate)) { | |||
| nodes.filter(agentPredicate).forEach((node) => { | |||
| const { downstreamAgentAndToolEdges, downstreamAgentAndToolNodeIds } = | |||
| deleteAllDownstreamAgentsAndTool(node.id, edges); | |||
| downstreamAgentAndToolNodeIds.forEach((nodeId) => { | |||
| const currentNode = getNode(nodeId); | |||
| if (toBeDeletedNodes.every((x) => x.id !== nodeId) && currentNode) { | |||
| toBeDeletedNodes.push(currentNode); | |||
| } | |||
| }); | |||
| downstreamAgentAndToolEdges.forEach((edge) => { | |||
| if (toBeDeletedEdges.every((x) => x.id !== edge.id)) { | |||
| toBeDeletedEdges.push(edge); | |||
| } | |||
| }); | |||
| }, []); | |||
| } | |||
| return { | |||
| nodes: toBeDeletedNodes, | |||
| edges: toBeDeletedEdges, | |||
| @@ -21,7 +21,7 @@ import lodashSet from 'lodash/set'; | |||
| import { create } from 'zustand'; | |||
| import { devtools } from 'zustand/middleware'; | |||
| import { immer } from 'zustand/middleware/immer'; | |||
| import { Operator, SwitchElseTo } from './constant'; | |||
| import { NodeHandleId, Operator, SwitchElseTo } from './constant'; | |||
| import { | |||
| duplicateNodeForm, | |||
| generateDuplicateNode, | |||
| @@ -30,6 +30,7 @@ import { | |||
| isEdgeEqual, | |||
| mapEdgeMouseEvent, | |||
| } from './utils'; | |||
| import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node'; | |||
| export type RFState = { | |||
| nodes: RAGFlowNodeType[]; | |||
| @@ -70,6 +71,8 @@ export type RFState = { | |||
| deleteEdge: () => void; | |||
| deleteEdgeById: (id: string) => void; | |||
| deleteNodeById: (id: string) => void; | |||
| deleteAgentDownstreamNodesById: (id: string) => void; | |||
| deleteAgentToolNodeById: (id: string) => void; | |||
| deleteIterationNodeById: (id: string) => void; | |||
| deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void; | |||
| findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined; | |||
| @@ -370,7 +373,16 @@ const useGraphStore = create<RFState>()( | |||
| }); | |||
| }, | |||
| deleteNodeById: (id: string) => { | |||
| const { nodes, edges } = get(); | |||
| const { | |||
| nodes, | |||
| edges, | |||
| getOperatorTypeFromId, | |||
| deleteAgentDownstreamNodesById, | |||
| } = get(); | |||
| if (getOperatorTypeFromId(id) === Operator.Agent) { | |||
| deleteAgentDownstreamNodesById(id); | |||
| return; | |||
| } | |||
| set({ | |||
| nodes: nodes.filter((node) => node.id !== id), | |||
| edges: edges | |||
| @@ -378,6 +390,38 @@ const useGraphStore = create<RFState>()( | |||
| .filter((edge) => edge.target !== id), | |||
| }); | |||
| }, | |||
| deleteAgentDownstreamNodesById: (id) => { | |||
| const { edges, nodes } = get(); | |||
| const { downstreamAgentAndToolNodeIds, downstreamAgentAndToolEdges } = | |||
| deleteAllDownstreamAgentsAndTool(id, edges); | |||
| set({ | |||
| nodes: nodes.filter( | |||
| (node) => | |||
| !downstreamAgentAndToolNodeIds.some((x) => x === node.id) && | |||
| node.id !== id, | |||
| ), | |||
| edges: edges.filter( | |||
| (edge) => | |||
| edge.source !== id && | |||
| edge.target !== id && | |||
| !downstreamAgentAndToolEdges.some((x) => x.id === edge.id), | |||
| ), | |||
| }); | |||
| }, | |||
| deleteAgentToolNodeById: (id) => { | |||
| const { edges, deleteEdgeById, deleteNodeById } = get(); | |||
| const edge = edges.find( | |||
| (x) => x.source === id && x.sourceHandle === NodeHandleId.Tool, | |||
| ); | |||
| if (edge) { | |||
| deleteEdgeById(edge.id); | |||
| deleteNodeById(edge.target); | |||
| } | |||
| }, | |||
| deleteIterationNodeById: (id: string) => { | |||
| const { nodes, edges } = get(); | |||
| const children = nodes.filter((node) => node.parentId === id); | |||
| @@ -0,0 +1,34 @@ | |||
| import { Edge } from '@xyflow/react'; | |||
| import { filterAllDownstreamAgentAndToolNodeIds } from './filter-downstream-nodes'; | |||
| // Delete all downstream agent and tool operators of the current agent operator | |||
| export function deleteAllDownstreamAgentsAndTool( | |||
| nodeId: string, | |||
| edges: Edge[], | |||
| ) { | |||
| const downstreamAgentAndToolNodeIds = filterAllDownstreamAgentAndToolNodeIds( | |||
| edges, | |||
| [nodeId], | |||
| ); | |||
| const downstreamAgentAndToolEdges = downstreamAgentAndToolNodeIds.reduce< | |||
| Edge[] | |||
| >((pre, cur) => { | |||
| const relatedEdges = edges.filter( | |||
| (x) => x.source === cur || x.target === cur, | |||
| ); | |||
| relatedEdges.forEach((x) => { | |||
| if (!pre.some((y) => y.id !== x.id)) { | |||
| pre.push(x); | |||
| } | |||
| }); | |||
| return pre; | |||
| }, []); | |||
| return { | |||
| downstreamAgentAndToolNodeIds, | |||
| downstreamAgentAndToolEdges, | |||
| }; | |||
| } | |||
| @@ -0,0 +1,31 @@ | |||
| import { Edge } from '@xyflow/react'; | |||
| import { NodeHandleId } from '../constant'; | |||
| // Get all downstream agent operators of the current agent operator | |||
| export function filterAllDownstreamAgentAndToolNodeIds( | |||
| edges: Edge[], | |||
| nodeIds: string[], | |||
| ) { | |||
| return nodeIds.reduce<string[]>((pre, nodeId) => { | |||
| const currentEdges = edges.filter( | |||
| (x) => | |||
| x.source === nodeId && | |||
| (x.sourceHandle === NodeHandleId.AgentBottom || | |||
| x.sourceHandle === NodeHandleId.Tool), | |||
| ); | |||
| const downstreamNodeIds: string[] = currentEdges.map((x) => x.target); | |||
| const ids = downstreamNodeIds.concat( | |||
| filterAllDownstreamAgentAndToolNodeIds(edges, downstreamNodeIds), | |||
| ); | |||
| ids.forEach((x) => { | |||
| if (pre.every((y) => y !== x)) { | |||
| pre.push(x); | |||
| } | |||
| }); | |||
| return pre; | |||
| }, []); | |||
| } | |||