瀏覽代碼

Feat: Delete the agent and tool nodes downstream of the agent node #3221 (#8450)

### What problem does this PR solve?

Feat: Delete the agent and tool nodes downstream of the agent node #3221

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
tags/v0.20.0
balibabu 4 月之前
父節點
當前提交
07545fbfd3
沒有連結到貢獻者的電子郵件帳戶。

+ 2
- 2
web/src/pages/agent/canvas/node/agent-node.tsx 查看文件

type="target" type="target"
position={Position.Top} position={Position.Top}
isConnectable={false} isConnectable={false}
id="f"
id={NodeHandleId.AgentTop}
></Handle> ></Handle>
<Handle <Handle
type="source" type="source"
position={Position.Bottom} position={Position.Bottom}
isConnectable={false} isConnectable={false}
id="e"
id={NodeHandleId.AgentBottom}
style={{ left: 180 }} style={{ left: 180 }}
></Handle> ></Handle>
<Handle <Handle

+ 2
- 0
web/src/pages/agent/constant.tsx 查看文件

Start = 'start', Start = 'start',
End = 'end', End = 'end',
Tool = 'tool', Tool = 'tool',
AgentTop = 'agentTop',
AgentBottom = 'agentBottom',
} }

+ 6
- 4
web/src/pages/agent/form/agent-form/tool-popover/index.tsx 查看文件

} from '@/components/ui/popover'; } from '@/components/ui/popover';
import { Operator } from '@/pages/agent/constant'; import { Operator } from '@/pages/agent/constant';
import { AgentFormContext, AgentInstanceContext } from '@/pages/agent/context'; import { AgentFormContext, AgentInstanceContext } from '@/pages/agent/context';
import useGraphStore from '@/pages/agent/store';
import { Position } from '@xyflow/react'; import { Position } from '@xyflow/react';
import { PropsWithChildren, useCallback, useContext } from 'react'; import { PropsWithChildren, useCallback, useContext } from 'react';
import { useDeleteToolNode } from '../use-delete-tool-node';
import { useGetAgentToolNames } from '../use-get-tools'; import { useGetAgentToolNames } from '../use-get-tools';
import { ToolCommand } from './tool-command'; import { ToolCommand } from './tool-command';
import { useUpdateAgentNodeTools } from './use-update-tools'; import { useUpdateAgentNodeTools } from './use-update-tools';
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);
const { updateNodeTools } = useUpdateAgentNodeTools(); const { updateNodeTools } = useUpdateAgentNodeTools();
const { toolNames } = useGetAgentToolNames(); const { toolNames } = useGetAgentToolNames();
const { deleteToolNode } = useDeleteToolNode();
const deleteAgentToolNodeById = useGraphStore(
(state) => state.deleteAgentToolNodeById,
);


const handleChange = useCallback( const handleChange = useCallback(
(value: string[]) => { (value: string[]) => {
nodeId: node?.id, nodeId: node?.id,
})(); })();
} else { } else {
deleteToolNode(node.id); // TODO: The tool node should be derived from the agent tools data
deleteAgentToolNodeById(node.id); // TODO: The tool node should be derived from the agent tools data
} }
} }
}, },
[addCanvasNode, deleteToolNode, node?.id, updateNodeTools],
[addCanvasNode, deleteAgentToolNodeById, node?.id, updateNodeTools],
); );


return ( return (

+ 5
- 4
web/src/pages/agent/form/agent-form/tool-popover/use-update-tools.ts 查看文件

import useGraphStore from '@/pages/agent/store'; import useGraphStore from '@/pages/agent/store';
import { get } from 'lodash'; import { get } from 'lodash';
import { useCallback, useContext, useMemo } from 'react'; import { useCallback, useContext, useMemo } from 'react';
import { useDeleteToolNode } from '../use-delete-tool-node';


export function useGetNodeTools() { export function useGetNodeTools() {
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);
const { updateNodeForm } = useGraphStore((state) => state); const { updateNodeForm } = useGraphStore((state) => state);
const tools = useGetNodeTools(); const tools = useGetNodeTools();
const node = useContext(AgentFormContext); const node = useContext(AgentFormContext);
const { deleteToolNode } = useDeleteToolNode();
const deleteAgentToolNodeById = useGraphStore(
(state) => state.deleteAgentToolNodeById,
);


const deleteNodeTool = useCallback( const deleteNodeTool = useCallback(
(value: string) => () => { (value: string) => () => {
if (node?.id) { if (node?.id) {
updateNodeForm(node?.id, nextTools, ['tools']); updateNodeForm(node?.id, nextTools, ['tools']);
if (nextTools.length === 0) { if (nextTools.length === 0) {
deleteToolNode(node?.id);
deleteAgentToolNodeById(node?.id);
} }
} }
}, },
[deleteToolNode, node?.id, tools, updateNodeForm],
[deleteAgentToolNodeById, node?.id, tools, updateNodeForm],
); );


return { deleteNodeTool }; return { deleteNodeTool };

+ 0
- 24
web/src/pages/agent/form/agent-form/use-delete-tool-node.ts 查看文件

import { useCallback } from 'react';
import { NodeHandleId } from '../../constant';
import useGraphStore from '../../store';

export function useDeleteToolNode() {
const { edges, deleteEdgeById, deleteNodeById } = useGraphStore(
(state) => state,
);
const deleteToolNode = useCallback(
(agentNodeId: string) => {
const edge = edges.find(
(x) => x.source === agentNodeId && x.sourceHandle === NodeHandleId.Tool,
);

if (edge) {
deleteEdgeById(edge.id);
deleteNodeById(edge.target);
}
},
[deleteEdgeById, deleteNodeById, edges],
);

return { deleteToolNode };
}

+ 7
- 3
web/src/pages/agent/hooks/use-add-node.ts 查看文件

if (agentNode) { if (agentNode) {
// Calculate the coordinates of child nodes to prevent newly added child nodes from covering other child nodes // Calculate the coordinates of child nodes to prevent newly added child nodes from covering other child nodes
const allChildAgentNodeIds = edges const allChildAgentNodeIds = edges
.filter((x) => x.source === nodeId && x.sourceHandle === 'e')
.filter(
(x) =>
x.source === nodeId &&
x.sourceHandle === NodeHandleId.AgentBottom,
)
.map((x) => x.target); .map((x) => x.target);


const xAxises = nodes const xAxises = nodes
addEdge({ addEdge({
source: nodeId, source: nodeId,
target: newNode.id, target: newNode.id,
sourceHandle: 'e',
targetHandle: 'f',
sourceHandle: NodeHandleId.AgentBottom,
targetHandle: NodeHandleId.AgentTop,
}); });
} }
} else if (type === Operator.Tool) { } else if (type === Operator.Tool) {

+ 29
- 4
web/src/pages/agent/hooks/use-before-delete.tsx 查看文件

import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { RAGFlowNodeType } from '@/interfaces/database/flow';
import { OnBeforeDelete } from '@xyflow/react';
import { Node, OnBeforeDelete } from '@xyflow/react';
import { Operator } from '../constant'; import { Operator } from '../constant';
import useGraphStore from '../store'; import useGraphStore from '../store';
import { deleteAllDownstreamAgentsAndTool } from '../utils/delete-node';


const UndeletableNodes = [Operator.Begin, Operator.IterationStart]; const UndeletableNodes = [Operator.Begin, Operator.IterationStart];


export function useBeforeDelete() { export function useBeforeDelete() {
const getOperatorTypeFromId = useGraphStore(
(state) => state.getOperatorTypeFromId,
);
const { getOperatorTypeFromId, getNode } = useGraphStore((state) => state);

const agentPredicate = (node: Node) => {
return getOperatorTypeFromId(node.id) === Operator.Agent;
};

const handleBeforeDelete: OnBeforeDelete<RAGFlowNodeType> = async ({ const handleBeforeDelete: OnBeforeDelete<RAGFlowNodeType> = async ({
nodes, // Nodes to be deleted nodes, // Nodes to be deleted
edges, // Edges to be deleted edges, // Edges to be deleted
return true; return true;
}); });


// Delete the agent and tool nodes downstream of the agent node
if (nodes.some(agentPredicate)) {
nodes.filter(agentPredicate).forEach((node) => {
const { downstreamAgentAndToolEdges, downstreamAgentAndToolNodeIds } =
deleteAllDownstreamAgentsAndTool(node.id, edges);

downstreamAgentAndToolNodeIds.forEach((nodeId) => {
const currentNode = getNode(nodeId);
if (toBeDeletedNodes.every((x) => x.id !== nodeId) && currentNode) {
toBeDeletedNodes.push(currentNode);
}
});

downstreamAgentAndToolEdges.forEach((edge) => {
if (toBeDeletedEdges.every((x) => x.id !== edge.id)) {
toBeDeletedEdges.push(edge);
}
});
}, []);
}

return { return {
nodes: toBeDeletedNodes, nodes: toBeDeletedNodes,
edges: toBeDeletedEdges, edges: toBeDeletedEdges,

+ 46
- 2
web/src/pages/agent/store.ts 查看文件

import { create } from 'zustand'; import { create } from 'zustand';
import { devtools } from 'zustand/middleware'; import { devtools } from 'zustand/middleware';
import { immer } from 'zustand/middleware/immer'; import { immer } from 'zustand/middleware/immer';
import { Operator, SwitchElseTo } from './constant';
import { NodeHandleId, Operator, SwitchElseTo } from './constant';
import { import {
duplicateNodeForm, duplicateNodeForm,
generateDuplicateNode, generateDuplicateNode,
isEdgeEqual, isEdgeEqual,
mapEdgeMouseEvent, mapEdgeMouseEvent,
} from './utils'; } from './utils';
import { deleteAllDownstreamAgentsAndTool } from './utils/delete-node';


export type RFState = { export type RFState = {
nodes: RAGFlowNodeType[]; nodes: RAGFlowNodeType[];
deleteEdge: () => void; deleteEdge: () => void;
deleteEdgeById: (id: string) => void; deleteEdgeById: (id: string) => void;
deleteNodeById: (id: string) => void; deleteNodeById: (id: string) => void;
deleteAgentDownstreamNodesById: (id: string) => void;
deleteAgentToolNodeById: (id: string) => void;
deleteIterationNodeById: (id: string) => void; deleteIterationNodeById: (id: string) => void;
deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void; deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined; findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
}); });
}, },
deleteNodeById: (id: string) => { deleteNodeById: (id: string) => {
const { nodes, edges } = get();
const {
nodes,
edges,
getOperatorTypeFromId,
deleteAgentDownstreamNodesById,
} = get();
if (getOperatorTypeFromId(id) === Operator.Agent) {
deleteAgentDownstreamNodesById(id);
return;
}
set({ set({
nodes: nodes.filter((node) => node.id !== id), nodes: nodes.filter((node) => node.id !== id),
edges: edges edges: edges
.filter((edge) => edge.target !== id), .filter((edge) => edge.target !== id),
}); });
}, },
deleteAgentDownstreamNodesById: (id) => {
const { edges, nodes } = get();

const { downstreamAgentAndToolNodeIds, downstreamAgentAndToolEdges } =
deleteAllDownstreamAgentsAndTool(id, edges);

set({
nodes: nodes.filter(
(node) =>
!downstreamAgentAndToolNodeIds.some((x) => x === node.id) &&
node.id !== id,
),
edges: edges.filter(
(edge) =>
edge.source !== id &&
edge.target !== id &&
!downstreamAgentAndToolEdges.some((x) => x.id === edge.id),
),
});
},
deleteAgentToolNodeById: (id) => {
const { edges, deleteEdgeById, deleteNodeById } = get();

const edge = edges.find(
(x) => x.source === id && x.sourceHandle === NodeHandleId.Tool,
);

if (edge) {
deleteEdgeById(edge.id);
deleteNodeById(edge.target);
}
},
deleteIterationNodeById: (id: string) => { deleteIterationNodeById: (id: string) => {
const { nodes, edges } = get(); const { nodes, edges } = get();
const children = nodes.filter((node) => node.parentId === id); const children = nodes.filter((node) => node.parentId === id);

+ 34
- 0
web/src/pages/agent/utils/delete-node.ts 查看文件

import { Edge } from '@xyflow/react';
import { filterAllDownstreamAgentAndToolNodeIds } from './filter-downstream-nodes';

// Delete all downstream agent and tool operators of the current agent operator
export function deleteAllDownstreamAgentsAndTool(
nodeId: string,
edges: Edge[],
) {
const downstreamAgentAndToolNodeIds = filterAllDownstreamAgentAndToolNodeIds(
edges,
[nodeId],
);

const downstreamAgentAndToolEdges = downstreamAgentAndToolNodeIds.reduce<
Edge[]
>((pre, cur) => {
const relatedEdges = edges.filter(
(x) => x.source === cur || x.target === cur,
);

relatedEdges.forEach((x) => {
if (!pre.some((y) => y.id !== x.id)) {
pre.push(x);
}
});

return pre;
}, []);

return {
downstreamAgentAndToolNodeIds,
downstreamAgentAndToolEdges,
};
}

+ 31
- 0
web/src/pages/agent/utils/filter-downstream-nodes.ts 查看文件

import { Edge } from '@xyflow/react';
import { NodeHandleId } from '../constant';

// Get all downstream agent operators of the current agent operator
export function filterAllDownstreamAgentAndToolNodeIds(
edges: Edge[],
nodeIds: string[],
) {
return nodeIds.reduce<string[]>((pre, nodeId) => {
const currentEdges = edges.filter(
(x) =>
x.source === nodeId &&
(x.sourceHandle === NodeHandleId.AgentBottom ||
x.sourceHandle === NodeHandleId.Tool),
);

const downstreamNodeIds: string[] = currentEdges.map((x) => x.target);

const ids = downstreamNodeIds.concat(
filterAllDownstreamAgentAndToolNodeIds(edges, downstreamNodeIds),
);

ids.forEach((x) => {
if (pre.every((y) => y !== x)) {
pre.push(x);
}
});

return pre;
}, []);
}

Loading…
取消
儲存