You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

store.ts 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. import { RAGFlowNodeType } from '@/interfaces/database/flow';
  2. import type {} from '@redux-devtools/extension';
  3. import {
  4. Connection,
  5. Edge,
  6. EdgeChange,
  7. OnConnect,
  8. OnEdgesChange,
  9. OnNodesChange,
  10. OnSelectionChangeFunc,
  11. OnSelectionChangeParams,
  12. addEdge,
  13. applyEdgeChanges,
  14. applyNodeChanges,
  15. } from '@xyflow/react';
  16. import { omit } from 'lodash';
  17. import differenceWith from 'lodash/differenceWith';
  18. import intersectionWith from 'lodash/intersectionWith';
  19. import lodashSet from 'lodash/set';
  20. import { create } from 'zustand';
  21. import { devtools } from 'zustand/middleware';
  22. import { immer } from 'zustand/middleware/immer';
  23. import { Operator, SwitchElseTo } from './constant';
  24. import {
  25. duplicateNodeForm,
  26. generateDuplicateNode,
  27. generateNodeNamesWithIncreasingIndex,
  28. getOperatorIndex,
  29. isEdgeEqual,
  30. } from './utils';
  /**
   * State and actions of the graph editor store (see `useGraphStore`).
   *
   * Holds the React Flow nodes/edges, the current selection, and actions that
   * keep node forms in sync with edge changes.
   */
  export type RFState = {
    nodes: RAGFlowNodeType[];
    edges: Edge[];
    /** Ids of the nodes in the current React Flow selection. */
    selectedNodeIds: string[];
    /** Ids of the edges in the current React Flow selection. */
    selectedEdgeIds: string[];
    /** Id of the currently clicked node. */
    clickedNodeId: string;
    onNodesChange: OnNodesChange<RAGFlowNodeType>;
    onEdgesChange: OnEdgesChange;
    onConnect: OnConnect;
    setNodes: (nodes: RAGFlowNodeType[]) => void;
    setEdges: (edges: Edge[]) => void;
    /** Replaces the outgoing edges of `nodeId` with `edges`. */
    setEdgesByNodeId: (nodeId: string, edges: Edge[]) => void;
    /**
     * Updates a node's form.
     * With an empty/omitted `path`, `values` is shallow-merged into the form.
     * Otherwise `values` is written at `path` inside the form.
     * @returns the next nodes array
     */
    updateNodeForm: (
      nodeId: string,
      values: unknown,
      path?: (string | number)[],
    ) => RAGFlowNodeType[];
    onSelectionChange: OnSelectionChangeFunc;
    /** Appends a single node. */
    addNode: (node: RAGFlowNodeType) => void;
    getNode: (id?: string | null) => RAGFlowNodeType | undefined;
    addEdge: (connection: Connection) => void;
    getEdge: (id: string) => Edge | undefined;
    updateFormDataOnConnect: (connection: Connection) => void;
    updateSwitchFormData: (
      source: string,
      sourceHandle?: string | null,
      target?: string | null,
    ) => void;
    deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;
    duplicateNode: (id: string, name: string) => void;
    duplicateIterationNode: (id: string, name: string) => void;
    /** Deletes all currently selected edges. */
    deleteEdge: () => void;
    deleteEdgeById: (id: string) => void;
    deleteNodeById: (id: string) => void;
    /** Deletes an iteration node together with its child nodes. */
    deleteIterationNodeById: (id: string) => void;
    deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
    /** Finds the first node whose `data.label` (operator type) equals `operatorName`. */
    findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
    /** Sets `form[field]` on a node in place (mutable update). */
    updateMutableNodeFormItem: (id: string, field: string, value: unknown) => void;
    /** Returns the node's `data.label`, i.e. its operator type. */
    getOperatorTypeFromId: (id?: string | null) => string | undefined;
    getParentIdById: (id?: string | null) => string | undefined;
    updateNodeName: (id: string, name: string) => void;
    /** Derives a node name from `name` that does not clash with existing nodes. */
    generateNodeName: (name: string) => string;
    setClickedNodeId: (id?: string) => void;
  };
  /**
   * Shallow-copies every existing object/array lying along `path` (all keys but
   * the last) inside `root`, replacing each one in its parent with its copy.
   *
   * `updateNodeForm` only shallow-copies the top-level form. Without this step,
   * `lodashSet(root, path, value)` would write into nested objects that are still
   * shared with the previous store state. Traversal stops at the first missing or
   * non-object segment; `lodashSet` then creates fresh containers from there on.
   */
  const copyContainersAlongPath = (
    root: Record<string, unknown>,
    path: (string | number)[],
  ): void => {
    let cursor = root;
    for (const key of path.slice(0, -1)) {
      // Only follow own properties: a missing key is left for lodashSet to create,
      // and inherited keys such as `__proto__` are never copied or reassigned.
      if (!Object.prototype.hasOwnProperty.call(cursor, key)) return;
      const child = cursor[key];
      if (child === null || typeof child !== 'object') return;
      const copy = (
        Array.isArray(child) ? [...child] : { ...child }
      ) as Record<string, unknown>;
      cursor[key] = copy;
      cursor = copy;
    }
  };

  /**
   * Graph store hook. Components select parts of the state and call actions
   * through it, e.g. `useGraphStore((state) => state.nodes)`.
   *
   * Wrapped in the `devtools` (store name "graph") and `immer` middlewares.
   * Every action calls `set` with a plain object. Apart from
   * `updateMutableNodeFormItem`, they build new arrays/objects instead of
   * mutating the current state.
   */
  const useGraphStore = create<RFState>()(
    devtools(
      immer((set, get) => ({
        nodes: [] as RAGFlowNodeType[],
        edges: [] as Edge[],
        selectedNodeIds: [] as string[],
        selectedEdgeIds: [] as string[],
        // id of the currently clicked node; starts out as ''
        clickedNodeId: '',

        /** Applies React Flow node changes to `nodes`. */
        onNodesChange: (changes) => {
          set({ nodes: applyNodeChanges(changes, get().nodes) });
        },

        /** Applies React Flow edge changes to `edges`. */
        onEdgesChange: (changes: EdgeChange[]) => {
          set({ edges: applyEdgeChanges(changes, get().edges) });
        },

        /** React Flow connect handler; performs exactly the `addEdge` action. */
        onConnect: (connection: Connection) => {
          get().addEdge(connection);
        },

        /** Mirrors the React Flow selection into id lists. */
        onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => {
          set({
            selectedEdgeIds: edges.map((x) => x.id),
            selectedNodeIds: nodes.map((x) => x.id),
          });
        },

        setNodes: (nodes: RAGFlowNodeType[]) => {
          set({ nodes });
        },

        setEdges: (edges: Edge[]) => {
          set({ edges });
        },

        /**
         * Replaces the outgoing edges of `nodeId` with `currentDownstreamEdges`.
         *
         * Nothing is written when both lists describe the same set of
         * (source, target, sourceHandle) links. Otherwise the result is:
         * - all edges of other nodes, then
         * - the previous edge objects that `isEdgeEqual` matches in the new list
         *   (lodash `intersectionWith` keeps references from its first array), then
         * - the newly added edges.
         */
        setEdgesByNodeId: (nodeId: string, currentDownstreamEdges: Edge[]) => {
          const { edges, setEdges } = get();
          const linksMatch = (a: Edge, b: Edge) =>
            a.source === b.source &&
            a.target === b.target &&
            a.sourceHandle === b.sourceHandle;

          // the previous downstream edges of this node
          const previousDownstreamEdges = edges.filter(
            (x) => x.source === nodeId,
          );
          const isDifferent =
            previousDownstreamEdges.length !== currentDownstreamEdges.length ||
            !previousDownstreamEdges.every((x) =>
              currentDownstreamEdges.some((y) => linksMatch(y, x)),
            ) ||
            !currentDownstreamEdges.every((x) =>
              previousDownstreamEdges.some((y) => linksMatch(y, x)),
            );
          if (!isDifferent) {
            return;
          }

          const intersectionDownstreamEdges = intersectionWith(
            previousDownstreamEdges,
            currentDownstreamEdges,
            isEdgeEqual,
          );
          // other operators' edges
          const irrelevantEdges = edges.filter((x) => x.source !== nodeId);
          // the newly added downstream edges
          const selfAddedDownstreamEdges = differenceWith(
            currentDownstreamEdges,
            intersectionDownstreamEdges,
            isEdgeEqual,
          );
          setEdges([
            ...irrelevantEdges,
            ...intersectionDownstreamEdges,
            ...selfAddedDownstreamEdges,
          ]);
        },

        addNode: (node: RAGFlowNodeType) => {
          set({ nodes: get().nodes.concat(node) });
        },

        getNode: (id?: string | null) => {
          return get().nodes.find((x) => x.id === id);
        },

        /** The node's `data.label`, which this store treats as its operator type. */
        getOperatorTypeFromId: (id?: string | null) => {
          return get().getNode(id)?.data?.label;
        },

        getParentIdById: (id?: string | null) => {
          return get().getNode(id)?.parentId;
        },

        /**
         * Adds an edge for `connection`.
         * For anchored operators it then removes the edge that previously left
         * the same handle. Finally it writes the new target into the source
         * node's form.
         */
        addEdge: (connection: Connection) => {
          // `addEdge` here is the @xyflow/react helper, not this action.
          set({ edges: addEdge(connection, get().edges) });
          // Must run before the form sync below: deleting the old edge resets
          // the same form field that updateFormDataOnConnect then sets.
          get().deletePreviousEdgeOfClassificationNode(connection);
          // TODO: This may not be reasonable. You need to choose between listening to changes in the form.
          get().updateFormDataOnConnect(connection);
        },

        getEdge: (id: string) => {
          return get().edges.find((x) => x.id === id);
        },

        /**
         * Writes `connection.target` into the form field of the source node that
         * corresponds to `connection.sourceHandle`:
         * - Relevant: `form[sourceHandle]`
         * - Categorize: `form.category_description[sourceHandle].to`
         * - Switch: delegated to `updateSwitchFormData`
         */
        updateFormDataOnConnect: (connection: Connection) => {
          const { getOperatorTypeFromId, updateNodeForm, updateSwitchFormData } =
            get();
          const { source, target, sourceHandle } = connection;
          const operatorType = getOperatorTypeFromId(source);
          if (!source) {
            return;
          }
          switch (operatorType) {
            case Operator.Relevant:
              updateNodeForm(source, { [sourceHandle as string]: target });
              break;
            case Operator.Categorize:
              if (sourceHandle) {
                updateNodeForm(source, target, [
                  'category_description',
                  sourceHandle,
                  'to',
                ]);
              }
              break;
            case Operator.Switch:
              updateSwitchFormData(source, sourceHandle, target);
              break;
            default:
              break;
          }
        },

        /**
         * When the source is an anchored operator (Categorize, Relevant, Switch),
         * deletes the existing edge that leaves the same handle toward a
         * different target. Each such handle therefore keeps a single outgoing
         * edge.
         */
        deletePreviousEdgeOfClassificationNode: (connection: Connection) => {
          const { edges, getOperatorTypeFromId, deleteEdgeById } = get();
          // operators whose handles are anchors
          const anchoredNodes = [
            Operator.Categorize,
            Operator.Relevant,
            Operator.Switch,
          ];
          const sourceType = getOperatorTypeFromId(connection.source);
          if (!anchoredNodes.some((x) => x === sourceType)) {
            return;
          }
          const previousEdge = edges.find(
            (x) =>
              x.source === connection.source &&
              x.sourceHandle === connection.sourceHandle &&
              x.target !== connection.target,
          );
          if (previousEdge) {
            deleteEdgeById(previousEdge.id);
          }
        },

        /**
         * Adds a copy of node `id`, with a fresh name derived from `name`.
         * Iteration nodes are delegated to `duplicateIterationNode` so that their
         * children are copied too.
         */
        duplicateNode: (id: string, name: string) => {
          const { getNode, addNode, generateNodeName, duplicateIterationNode } =
            get();
          const node = getNode(id);
          if (node?.data.label === Operator.Iteration) {
            duplicateIterationNode(id, name);
            return;
          }
          addNode({
            ...(node || {}),
            data: {
              ...duplicateNodeForm(node?.data),
              name: generateNodeName(name),
            },
            ...generateDuplicateNode(node?.position, node?.data?.label),
          });
        },

        /**
         * Duplicates an iteration node and all nodes whose `parentId` is that
         * node. The copied children are re-parented to the new iteration node
         * and keep their original (parent-relative) positions.
         */
        duplicateIterationNode: (id: string, name: string) => {
          const { getNode, generateNodeName, nodes } = get();
          const node = getNode(id);
          const iterationNode: RAGFlowNodeType = {
            ...(node || {}),
            data: {
              ...(node?.data || { label: Operator.Iteration, form: {} }),
              name: generateNodeName(name),
            },
            ...generateDuplicateNode(node?.position, node?.data?.label),
          };
          const children = nodes
            .filter((x) => x.parentId === node?.id)
            .map((x) => ({
              ...(x || {}),
              data: {
                ...duplicateNodeForm(x?.data),
                name: generateNodeName(x.data.name),
              },
              ...omit(generateDuplicateNode(x?.position, x?.data?.label), [
                'position',
              ]),
              parentId: iterationNode.id,
            }));
          set({ nodes: nodes.concat(iterationNode, ...children) });
        },

        /**
         * Deletes all currently selected edges. Unlike `deleteEdgeById`, it does
         * not reset the corresponding form fields of the source nodes.
         */
        deleteEdge: () => {
          const { edges, selectedEdgeIds } = get();
          set({
            edges: edges.filter((edge) => !selectedEdgeIds.includes(edge.id)),
          });
        },

        /**
         * Deletes edge `id` and clears the form field that referenced its target
         * on the source node (the inverse of `updateFormDataOnConnect`).
         */
        deleteEdgeById: (id: string) => {
          const {
            edges,
            updateNodeForm,
            getOperatorTypeFromId,
            updateSwitchFormData,
          } = get();
          const currentEdge = edges.find((x) => x.id === id);
          if (currentEdge) {
            const { source, sourceHandle } = currentEdge;
            switch (getOperatorTypeFromId(source)) {
              case Operator.Relevant:
                updateNodeForm(source, {
                  [sourceHandle as string]: undefined,
                });
                break;
              case Operator.Categorize:
                if (sourceHandle) {
                  updateNodeForm(source, undefined, [
                    'category_description',
                    sourceHandle,
                    'to',
                  ]);
                }
                break;
              case Operator.Switch:
                updateSwitchFormData(source, sourceHandle, undefined);
                break;
              default:
                break;
            }
          }
          set({ edges: edges.filter((edge) => edge.id !== id) });
        },

        /** Deletes every edge leaving `source` through `sourceHandle`. */
        deleteEdgeBySourceAndSourceHandle: ({
          source,
          sourceHandle,
        }: Partial<Connection>) => {
          const { edges } = get();
          set({
            edges: edges.filter(
              (edge) =>
                edge.source !== source || edge.sourceHandle !== sourceHandle,
            ),
          });
        },

        /** Deletes node `id` and every edge that starts or ends at it. */
        deleteNodeById: (id: string) => {
          const { nodes, edges } = get();
          set({
            nodes: nodes.filter((node) => node.id !== id),
            edges: edges.filter(
              (edge) => edge.source !== id && edge.target !== id,
            ),
          });
        },

        /**
         * Deletes an iteration node, its direct children (nodes whose `parentId`
         * is `id`), and every edge touching any of them. Edges between two
         * children are included, so no edge is left pointing at a removed node.
         */
        deleteIterationNodeById: (id: string) => {
          const { nodes, edges } = get();
          const removedIds = new Set<string>([id]);
          nodes.forEach((node) => {
            if (node.parentId === id) {
              removedIds.add(node.id);
            }
          });
          set({
            nodes: nodes.filter((node) => !removedIds.has(node.id)),
            edges: edges.filter(
              (edge) =>
                !removedIds.has(edge.source) && !removedIds.has(edge.target),
            ),
          });
        },

        /** Finds the first node whose `data.label` (operator type) equals `name`. */
        findNodeByName: (name: Operator) => {
          return get().nodes.find((x) => x.data.label === name);
        },

        /**
         * Returns the next nodes array with node `nodeId`'s form updated, and
         * stores it.
         * - Empty `path`: `values` is shallow-merged into a copy of the form.
         * - Otherwise: `values` is set at `path` via lodash `set`. Every
         *   container on the way is copied first, so the previous state is
         *   never mutated.
         */
        updateNodeForm: (
          nodeId: string,
          values: any,
          path: (string | number)[] = [],
        ) => {
          const nextNodes = get().nodes.map((node) => {
            if (node.id !== nodeId) {
              return node;
            }
            const nextForm: Record<string, unknown> = { ...node.data.form };
            if (path.length === 0) {
              Object.assign(nextForm, values);
            } else {
              copyContainersAlongPath(nextForm, path);
              lodashSet(nextForm, path, values);
            }
            return {
              ...node,
              data: {
                ...node.data,
                form: nextForm,
              },
            } as any;
          });
          set({ nodes: nextNodes });
          return nextNodes;
        },

        /**
         * Writes `target` into a Switch node's form for the given handle:
         * - the else handle (`SwitchElseTo`) maps to `form[SwitchElseTo]`;
         * - other handles map to `form.conditions[index - 1].to`, where `index`
         *   comes from `getOperatorIndex(sourceHandle)`.
         */
        updateSwitchFormData: (source, sourceHandle, target) => {
          const { updateNodeForm } = get();
          if (!sourceHandle) {
            return;
          }
          if (sourceHandle === SwitchElseTo) {
            updateNodeForm(source, target, [SwitchElseTo]);
            return;
          }
          const operatorIndex = getOperatorIndex(sourceHandle);
          if (operatorIndex) {
            updateNodeForm(source, target, [
              'conditions',
              Number(operatorIndex) - 1, // The index is the conditions form index
              'to',
            ]);
          }
        },

        /**
         * Sets `form[field] = value` on node `id` by mutating the current nodes
         * array in place. No `set` call is made. Does nothing if the node is not
         * found.
         */
        updateMutableNodeFormItem: (id: string, field: string, value: any) => {
          const { nodes } = get();
          const idx = nodes.findIndex((x) => x.id === id);
          // findIndex yields -1 when missing; 0 (the first node) is a valid index.
          if (idx !== -1) {
            lodashSet(nodes, [idx, 'data', 'form', field], value);
          }
        },

        /** Renames node `id`, replacing the node object rather than mutating it. */
        updateNodeName: (id, name) => {
          if (!id) {
            return;
          }
          set({
            nodes: get().nodes.map((node) =>
              node.id === id ? { ...node, data: { ...node.data, name } } : node,
            ),
          });
        },

        setClickedNodeId: (id?: string) => {
          set({ clickedNodeId: id });
        },

        /** Derives a name from `name` that is unique among the current nodes. */
        generateNodeName: (name: string) => {
          const { nodes } = get();
          return generateNodeNamesWithIncreasingIndex(name, nodes);
        },
      })),
      { name: 'graph' },
    ),
  );
  export default useGraphStore;