| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455 |
- import { RAGFlowNodeType } from '@/interfaces/database/flow';
- import type {} from '@redux-devtools/extension';
- import {
- Connection,
- Edge,
- EdgeChange,
- OnConnect,
- OnEdgesChange,
- OnNodesChange,
- OnSelectionChangeFunc,
- OnSelectionChangeParams,
- addEdge,
- applyEdgeChanges,
- applyNodeChanges,
- } from '@xyflow/react';
- import { omit } from 'lodash';
- import differenceWith from 'lodash/differenceWith';
- import intersectionWith from 'lodash/intersectionWith';
- import lodashSet from 'lodash/set';
- import { create } from 'zustand';
- import { devtools } from 'zustand/middleware';
- import { immer } from 'zustand/middleware/immer';
- import { Operator, SwitchElseTo } from './constant';
- import {
- duplicateNodeForm,
- generateDuplicateNode,
- generateNodeNamesWithIncreasingIndex,
- getOperatorIndex,
- isEdgeEqual,
- } from './utils';
-
/**
 * Shape of the flow-editor graph store: the React Flow `nodes` and `edges`,
 * the current canvas selection, and the actions that read or edit them.
 */
export type RFState = {
  nodes: RAGFlowNodeType[];
  edges: Edge[];
  /** Ids of the nodes in the current canvas selection (written by `onSelectionChange`). */
  selectedNodeIds: string[];
  /** Ids of the edges in the current canvas selection (written by `onSelectionChange`). */
  selectedEdgeIds: string[];
  /** Id of the currently clicked (selected) node, written by `setClickedNodeId`. */
  clickedNodeId: string;
  onNodesChange: OnNodesChange<RAGFlowNodeType>;
  onEdgesChange: OnEdgesChange;
  /** React Flow connect handler: adds the edge and syncs the source node's form. */
  onConnect: OnConnect;
  setNodes: (nodes: RAGFlowNodeType[]) => void;
  setEdges: (edges: Edge[]) => void;
  /** Replace the outgoing edges of `nodeId` with `edges` when they differ from the stored ones. */
  setEdgesByNodeId: (nodeId: string, edges: Edge[]) => void;
  /**
   * Update the form of node `nodeId`: shallow-merge `values` into the form
   * when `path` is empty, otherwise set `values` at `path` inside the form.
   * Returns the resulting node list.
   */
  updateNodeForm: (
    nodeId: string,
    values: any,
    path?: (string | number)[],
  ) => RAGFlowNodeType[];
  onSelectionChange: OnSelectionChangeFunc;
  /** Append a single node to the graph. */
  addNode: (node: RAGFlowNodeType) => void;
  /** The node with the given id, if any. */
  getNode: (id?: string | null) => RAGFlowNodeType | undefined;
  /** Add an edge for `connection` and sync the source node's form. */
  addEdge: (connection: Connection) => void;
  getEdge: (id: string) => Edge | undefined;
  /** Record the connection's target in the form of a Relevant, Categorize or Switch source node. */
  updateFormDataOnConnect: (connection: Connection) => void;
  /** Write `target` into the Switch form entry (else branch or a condition) behind `sourceHandle`. */
  updateSwitchFormData: (
    source: string,
    sourceHandle?: string | null,
    target?: string | null,
  ) => void;
  /** Remove an older edge leaving the same anchor of a Categorize/Relevant/Switch node. */
  deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;
  /** Append a copy of node `id` named after `name`; iteration nodes are copied with their children. */
  duplicateNode: (id: string, name: string) => void;
  /** Append a copy of iteration node `id` together with copies of its child nodes. */
  duplicateIterationNode: (id: string, name: string) => void;
  /** Delete every currently selected edge. */
  deleteEdge: () => void;
  /** Delete one edge; for anchored operators also clear the form field that referenced its target. */
  deleteEdgeById: (id: string) => void;
  /** Delete a node and every edge that starts or ends at it. */
  deleteNodeById: (id: string) => void;
  /** Delete an iteration node together with its direct child nodes. */
  deleteIterationNodeById: (id: string) => void;
  /** Delete every edge leaving `source` through `sourceHandle`. */
  deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
  /** First node whose operator label (`data.label`) equals `operatorName`. */
  findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
  /** Set `form[field]` on node `id` in place, without calling `set`. */
  updateMutableNodeFormItem: (id: string, field: string, value: any) => void;
  /** Operator label (`data.label`) of the node with the given id. */
  getOperatorTypeFromId: (id?: string | null) => string | undefined;
  /** `parentId` of the node with the given id. */
  getParentIdById: (id?: string | null) => string | undefined;
  updateNodeName: (id: string, name: string) => void;
  /** Derive a new node name from `name`, taking the names of existing nodes into account. */
  generateNodeName: (name: string) => string;
  setClickedNodeId: (id?: string) => void;
};
-
/**
 * Zustand hook that owns the flow editor's graph: the React Flow `nodes` and
 * `edges`, the canvas selection, and every action that edits them.
 * Components select slices of this store and call its actions.
 *
 * Every `set` below receives a plain partial-state object, which zustand
 * shallow-merges into the current state.
 */
const useGraphStore = create<RFState>()(
  devtools(
    immer((set, get) => {
      /**
       * Mirror an edge leaving `source` through the anchor `sourceHandle`
       * into the source node's form, so the form records which node that
       * anchor points at. Pass `undefined` as `target` to clear the field
       * once the edge is removed. Only Relevant, Categorize and Switch nodes
       * keep such fields; other operator types are left untouched.
       */
      const syncAnchorFormField = (
        source: string | null | undefined,
        sourceHandle: string | null | undefined,
        target: string | null | undefined,
      ) => {
        if (!source) return;
        const { getOperatorTypeFromId, updateNodeForm, updateSwitchFormData } =
          get();

        switch (getOperatorTypeFromId(source)) {
          case Operator.Relevant:
            // The Relevant form stores the downstream node id under the handle id.
            // A missing handle would otherwise produce a literal "null" key.
            if (sourceHandle) {
              updateNodeForm(source, { [sourceHandle]: target });
            }
            break;
          case Operator.Categorize:
            // Each category keeps its downstream node id at category_description[handle].to.
            if (sourceHandle) {
              updateNodeForm(source, target, [
                'category_description',
                sourceHandle,
                'to',
              ]);
            }
            break;
          case Operator.Switch:
            updateSwitchFormData(source, sourceHandle, target);
            break;
          default:
            break;
        }
      };

      return {
        nodes: [] as RAGFlowNodeType[],
        edges: [] as Edge[],
        selectedNodeIds: [] as string[],
        selectedEdgeIds: [] as string[],
        clickedNodeId: '',

        onNodesChange: (changes) => {
          set({ nodes: applyNodeChanges(changes, get().nodes) });
        },
        onEdgesChange: (changes: EdgeChange[]) => {
          set({ edges: applyEdgeChanges(changes, get().edges) });
        },
        // Canvas connections take the same path as programmatic ones.
        onConnect: (connection: Connection) => {
          get().addEdge(connection);
        },
        onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => {
          set({
            selectedEdgeIds: edges.map((x) => x.id),
            selectedNodeIds: nodes.map((x) => x.id),
          });
        },
        setNodes: (nodes: RAGFlowNodeType[]) => {
          set({ nodes });
        },
        setEdges: (edges: Edge[]) => {
          set({ edges });
        },

        /**
         * Make `currentDownstreamEdges` the outgoing edges of `nodeId`.
         * Nothing is written when both lists describe the same links (same
         * source, target and source handle). Otherwise, outgoing edges that
         * also appear in the new list (per `isEdgeEqual`) keep their existing
         * objects, the remaining new edges are appended, and edges leaving
         * other nodes are kept unchanged.
         */
        setEdgesByNodeId: (nodeId: string, currentDownstreamEdges: Edge[]) => {
          const { edges, setEdges } = get();
          const previousDownstreamEdges = edges.filter(
            (x) => x.source === nodeId,
          );

          const isSameLink = (a: Edge, b: Edge) =>
            a.source === b.source &&
            a.target === b.target &&
            a.sourceHandle === b.sourceHandle;
          const allFoundIn = (from: Edge[], pool: Edge[]) =>
            from.every((x) => pool.some((y) => isSameLink(x, y)));

          const isDifferent =
            previousDownstreamEdges.length !== currentDownstreamEdges.length ||
            !allFoundIn(previousDownstreamEdges, currentDownstreamEdges) ||
            !allFoundIn(currentDownstreamEdges, previousDownstreamEdges);
          if (!isDifferent) return;

          // lodash takes result values from the first array, so these are the
          // existing edge objects.
          const keptDownstreamEdges = intersectionWith(
            previousDownstreamEdges,
            currentDownstreamEdges,
            isEdgeEqual,
          );
          const addedDownstreamEdges = differenceWith(
            currentDownstreamEdges,
            keptDownstreamEdges,
            isEdgeEqual,
          );

          setEdges([
            ...edges.filter((x) => x.source !== nodeId),
            ...keptDownstreamEdges,
            ...addedDownstreamEdges,
          ]);
        },

        addNode: (node: RAGFlowNodeType) => {
          set({ nodes: get().nodes.concat(node) });
        },
        getNode: (id?: string | null) => get().nodes.find((x) => x.id === id),
        getOperatorTypeFromId: (id?: string | null) =>
          get().getNode(id)?.data?.label,
        getParentIdById: (id?: string | null) => get().getNode(id)?.parentId,

        /**
         * Add an edge for `connection`. Then drop an older edge leaving the
         * same anchor of a Categorize/Relevant/Switch node, and record the
         * new target in the source node's form.
         */
        addEdge: (connection: Connection) => {
          set({ edges: addEdge(connection, get().edges) });
          get().deletePreviousEdgeOfClassificationNode(connection);
          // TODO: This may not be reasonable. You need to choose between listening to changes in the form.
          get().updateFormDataOnConnect(connection);
        },
        getEdge: (id: string) => get().edges.find((x) => x.id === id),

        updateFormDataOnConnect: (connection: Connection) => {
          const { source, sourceHandle, target } = connection;
          syncAnchorFormField(source, sourceHandle, target);
        },

        /**
         * An anchor (source handle) of a Categorize, Relevant or Switch node
         * is meant to have a single outgoing edge. After it is connected to a
         * new target, an older edge from the same anchor to a different target
         * is removed via `deleteEdgeById`, which also clears its form field.
         */
        deletePreviousEdgeOfClassificationNode: (connection: Connection) => {
          const { edges, getOperatorTypeFromId, deleteEdgeById } = get();
          const anchoredOperators = [
            Operator.Categorize,
            Operator.Relevant,
            Operator.Switch,
          ];
          const operatorType = getOperatorTypeFromId(connection.source);
          if (!anchoredOperators.some((x) => x === operatorType)) return;

          const previousEdge = edges.find(
            (x) =>
              x.source === connection.source &&
              x.sourceHandle === connection.sourceHandle &&
              x.target !== connection.target,
          );
          if (previousEdge) {
            deleteEdgeById(previousEdge.id);
          }
        },

        /**
         * Append a copy of node `id` whose name is derived from `name`.
         * Iteration nodes are handed to `duplicateIterationNode` so their
         * children are copied as well. Unknown ids are ignored.
         */
        duplicateNode: (id: string, name: string) => {
          const { getNode, addNode, generateNodeName, duplicateIterationNode } =
            get();
          const node = getNode(id);
          if (!node) return;

          if (node.data.label === Operator.Iteration) {
            duplicateIterationNode(id, name);
            return;
          }

          addNode({
            ...node,
            data: {
              ...duplicateNodeForm(node.data),
              name: generateNodeName(name),
            },
            ...generateDuplicateNode(node.position, node.data.label),
          });
        },

        /**
         * Append a copy of iteration node `id` plus a copy of each of its
         * direct children, re-parented onto the new iteration node. Unknown
         * ids are ignored; otherwise `parentId === undefined` would match
         * every top-level node.
         */
        duplicateIterationNode: (id: string, name: string) => {
          const { getNode, generateNodeName, nodes } = get();
          const node = getNode(id);
          if (!node) return;

          const iterationNode: RAGFlowNodeType = {
            ...node,
            data: {
              ...node.data,
              name: generateNodeName(name),
            },
            ...generateDuplicateNode(node.position, node.data.label),
          };

          const children = nodes
            .filter((x) => x.parentId === node.id)
            .map((x) => ({
              ...x,
              data: {
                ...duplicateNodeForm(x.data),
                name: generateNodeName(x.data.name),
              },
              // Keep each child's original position: in React Flow a child's
              // position is relative to its parent node.
              ...omit(generateDuplicateNode(x.position, x.data.label), [
                'position',
              ]),
              parentId: iterationNode.id,
            }));

          set({ nodes: nodes.concat(iterationNode, ...children) });
        },

        deleteEdge: () => {
          const { edges, selectedEdgeIds } = get();
          set({
            edges: edges.filter((edge) => !selectedEdgeIds.includes(edge.id)),
          });
        },

        /**
         * Delete edge `id`. For Relevant/Categorize/Switch sources, the form
         * field that referenced the edge's target is reset to `undefined`.
         */
        deleteEdgeById: (id: string) => {
          const { edges } = get();
          const currentEdge = edges.find((x) => x.id === id);
          if (currentEdge) {
            syncAnchorFormField(
              currentEdge.source,
              currentEdge.sourceHandle,
              undefined,
            );
          }
          set({ edges: edges.filter((edge) => edge.id !== id) });
        },

        /** Delete every edge leaving `source` through `sourceHandle` (form data is not touched). */
        deleteEdgeBySourceAndSourceHandle: ({
          source,
          sourceHandle,
        }: Partial<Connection>) => {
          const { edges } = get();
          set({
            edges: edges.filter(
              (edge) =>
                edge.source !== source || edge.sourceHandle !== sourceHandle,
            ),
          });
        },

        /** Delete node `id` and every edge that starts or ends at it. */
        deleteNodeById: (id: string) => {
          const { nodes, edges } = get();
          set({
            nodes: nodes.filter((node) => node.id !== id),
            edges: edges.filter(
              (edge) => edge.source !== id && edge.target !== id,
            ),
          });
        },

        /**
         * Delete iteration node `id`, its direct children, and every edge
         * that starts or ends at any of them (so no edge is left dangling).
         */
        deleteIterationNodeById: (id: string) => {
          const { nodes, edges } = get();
          const removedIds = new Set<string>([
            id,
            ...nodes.filter((node) => node.parentId === id).map((node) => node.id),
          ]);
          set({
            nodes: nodes.filter((node) => !removedIds.has(node.id)),
            edges: edges.filter(
              (edge) => !removedIds.has(edge.source) && !removedIds.has(edge.target),
            ),
          });
        },

        findNodeByName: (name: Operator) =>
          get().nodes.find((x) => x.data.label === name),

        /**
         * Update node `nodeId`'s form without mutating the existing form:
         * shallow-merge `values` when `path` is empty, otherwise set `values`
         * at `path` inside a shallow copy of the form. Returns the new nodes.
         */
        updateNodeForm: (
          nodeId: string,
          values: any,
          path: (string | number)[] = [],
        ) => {
          const nextNodes = get().nodes.map((node) => {
            if (node.id !== nodeId) return node;

            let nextForm: Record<string, unknown> = { ...node.data.form };
            if (path.length === 0) {
              nextForm = Object.assign(nextForm, values);
            } else {
              lodashSet(nextForm, path, values);
            }
            return {
              ...node,
              data: {
                ...node.data,
                form: nextForm,
              },
            } as any;
          });
          set({ nodes: nextNodes });

          return nextNodes;
        },

        /**
         * Write `target` into a Switch node's form. The `SwitchElseTo` handle
         * maps to `form[SwitchElseTo]`; any other handle is turned into an
         * index by `getOperatorIndex`, shifted down by one, and written to
         * `form.conditions[index].to`.
         */
        updateSwitchFormData: (source, sourceHandle, target) => {
          const { updateNodeForm } = get();
          if (!sourceHandle) return;

          if (sourceHandle === SwitchElseTo) {
            updateNodeForm(source, target, [SwitchElseTo]);
            return;
          }
          const operatorIndex = getOperatorIndex(sourceHandle);
          if (operatorIndex) {
            updateNodeForm(source, target, [
              'conditions',
              Number(operatorIndex) - 1, // The index is the conditions form index
              'to',
            ]);
          }
        },

        /**
         * Set `form[field]` on node `id` by mutating the current nodes array in
         * place. `set` is not called, so subscribers are not notified.
         */
        updateMutableNodeFormItem: (id: string, field: string, value: any) => {
          const { nodes } = get();
          const idx = nodes.findIndex((x) => x.id === id);
          // findIndex yields -1 when absent and 0 for the first node, so a
          // truthiness check would be wrong in both cases.
          if (idx !== -1) {
            lodashSet(nodes, [idx, 'data', 'form', field], value);
          }
        },

        /** Rename node `id`, replacing the node object instead of mutating it. */
        updateNodeName: (id, name) => {
          if (!id) return;
          set({
            nodes: get().nodes.map((node) =>
              node.id === id ? { ...node, data: { ...node.data, name } } : node,
            ),
          });
        },

        setClickedNodeId: (id?: string) => {
          set({ clickedNodeId: id });
        },

        generateNodeName: (name: string) =>
          generateNodeNamesWithIncreasingIndex(name, get().nodes),
      };
    }),
    { name: 'graph' },
  ),
);

export default useGraphStore;
|