You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. import { RAGFlowNodeType } from '@/interfaces/database/flow';
  2. import type {} from '@redux-devtools/extension';
  3. import {
  4. Connection,
  5. Edge,
  6. EdgeChange,
  7. OnConnect,
  8. OnEdgesChange,
  9. OnNodesChange,
  10. OnSelectionChangeFunc,
  11. OnSelectionChangeParams,
  12. addEdge,
  13. applyEdgeChanges,
  14. applyNodeChanges,
  15. } from '@xyflow/react';
  16. import { omit } from 'lodash';
  17. import differenceWith from 'lodash/differenceWith';
  18. import intersectionWith from 'lodash/intersectionWith';
  19. import lodashSet from 'lodash/set';
  20. import { create } from 'zustand';
  21. import { devtools } from 'zustand/middleware';
  22. import { immer } from 'zustand/middleware/immer';
  23. import { Operator, SwitchElseTo } from './constant';
  24. import {
  25. duplicateNodeForm,
  26. generateDuplicateNode,
  27. generateNodeNamesWithIncreasingIndex,
  28. getOperatorIndex,
  29. isEdgeEqual,
  30. } from './utils';
  /**
   * Shape of the state held by `useGraphStore`: the flow's nodes and edges,
   * selection/click state, and the actions that read or edit them.
   */
  export type RFState = {
    /** Every node in the graph. */
    nodes: RAGFlowNodeType[];
    /** Every edge in the graph. */
    edges: Edge[];
    /** Ids of the selected nodes, as last reported to `onSelectionChange`. */
    selectedNodeIds: string[];
    /** Ids of the selected edges, as last reported to `onSelectionChange`. */
    selectedEdgeIds: string[];
    /** Currently selected (clicked) node id; the store initialises it to `''`. */
    clickedNodeId: string;
    /** Currently selected tool id; the store initialises it to `''`. */
    clickedToolId: string;
    /** React Flow handler that applies node changes to `nodes`. */
    onNodesChange: OnNodesChange<RAGFlowNodeType>;
    /** React Flow handler that applies edge changes to `edges`. */
    onEdgesChange: OnEdgesChange;
    /** React Flow connect handler; same effect as the `addEdge` action. */
    onConnect: OnConnect;
    /** Replaces all nodes. */
    setNodes: (nodes: RAGFlowNodeType[]) => void;
    /** Replaces all edges. */
    setEdges: (edges: Edge[]) => void;
    /**
     * Replaces the outgoing edges of `nodeId` with `edges`. The store is only
     * written when they differ (by source, target and sourceHandle) from the
     * node's current outgoing edges.
     */
    setEdgesByNodeId: (nodeId: string, edges: Edge[]) => void;
    /**
     * Updates a node's form. With an empty `path`, `values` is shallow-merged
     * into the form; otherwise `values` is written at `path` (lodash `set`
     * path semantics).
     * @returns The nodes array after the update.
     */
    updateNodeForm: (
      nodeId: string,
      values: unknown,
      path?: (string | number)[],
    ) => RAGFlowNodeType[];
    /** Records the ids of the selected nodes and edges. */
    onSelectionChange: OnSelectionChangeFunc;
    /** Appends a node. */
    addNode: (node: RAGFlowNodeType) => void;
    /** Looks up a node by id; `undefined` when there is none. */
    getNode: (id?: string | null) => RAGFlowNodeType | undefined;
    /**
     * Adds an edge, removes a superseded edge from the same Categorize/Relevant
     * anchor, and writes the new target into the source node's form.
     */
    addEdge: (connection: Connection) => void;
    /** Looks up an edge by id; `undefined` when there is none. */
    getEdge: (id: string) => Edge | undefined;
    /**
     * Writes a new connection's target into the source node's form
     * (Relevant, Categorize and Switch operators only).
     */
    updateFormDataOnConnect: (connection: Connection) => void;
    /**
     * Rebuilds, from the current edges, the list of targets that a Switch
     * node's form stores for one source handle. When `isConnecting` is false,
     * `target` is left out of the list.
     */
    updateSwitchFormData: (
      source: string,
      sourceHandle?: string | null,
      target?: string | null,
      isConnecting?: boolean,
    ) => void;
    /**
     * For Categorize/Relevant source nodes, deletes an existing edge leaving the
     * same handle towards a different target than `connection.target`.
     */
    deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;
    /** Adds a copy of a node; iteration nodes are copied with their children. */
    duplicateNode: (id: string, name: string) => void;
    /** Adds a copy of an iteration node and of its child nodes. */
    duplicateIterationNode: (id: string, name: string) => void;
    /** Removes the currently selected edges. */
    deleteEdge: () => void;
    /** Removes one edge and clears the source node's form field that pointed at its target. */
    deleteEdgeById: (id: string) => void;
    /** Removes a node and every edge that starts or ends at it. */
    deleteNodeById: (id: string) => void;
    /**
     * Removes an iteration node, its child nodes (those whose `parentId` is
     * `id`), and the edges attached to the iteration node.
     */
    deleteIterationNodeById: (id: string) => void;
    /** Removes every edge leaving `source` through `sourceHandle`. */
    deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void;
    /** First node whose `data.label` (operator type) equals `operatorName`. */
    findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined;
    /**
     * Writes `value` to `data.form[field]` of a node by mutating the current
     * nodes in place; no `set` call is made.
     */
    updateMutableNodeFormItem: (id: string, field: string, value: unknown) => void;
    /** Operator type (`data.label`) of the node with this id. */
    getOperatorTypeFromId: (id?: string | null) => string | undefined;
    /** `parentId` of the node with this id. */
    getParentIdById: (id?: string | null) => string | undefined;
    /** Sets `data.name` of the node with this id. */
    updateNodeName: (id: string, name: string) => void;
    /**
     * Derives a node name from `name` against the existing nodes
     * (presumably unique — see `generateNodeNamesWithIncreasingIndex`).
     */
    generateNodeName: (name: string) => string;
    /** Sets `clickedNodeId`. */
    setClickedNodeId: (id?: string) => void;
    /** Sets `clickedToolId`. */
    setClickedToolId: (id?: string) => void;
  };
  /**
   * Shallow-copies every object/array container that `path` walks through
   * inside `root` (all keys except the last). A following
   * `lodashSet(root, path, value)` then writes only into fresh containers and
   * never mutates objects still referenced by the previous store state.
   * Copying stops at the first missing or non-object value; from there on
   * lodash `set` creates new containers itself.
   */
  const copyContainersAlongPath = (
    root: Record<string, unknown>,
    path: (string | number)[],
  ): void => {
    let cursor: Record<string, unknown> = root;
    for (const key of path.slice(0, -1)) {
      const child = cursor[key];
      if (child === null || typeof child !== 'object') {
        break;
      }
      const copy = (
        Array.isArray(child) ? [...child] : { ...child }
      ) as Record<string, unknown>;
      cursor[key] = copy;
      cursor = copy;
    }
  };

  /**
   * Graph store hook (zustand with the immer and devtools middlewares) holding
   * the flow's nodes, edges and selection state plus the actions that edit
   * them. Components use it to read parts of the store and to call actions.
   */
  const useGraphStore = create<RFState>()(
    devtools(
      immer((set, get) => ({
        nodes: [] as RAGFlowNodeType[],
        edges: [] as Edge[],
        selectedNodeIds: [] as string[],
        selectedEdgeIds: [] as string[],
        clickedNodeId: '',
        clickedToolId: '',

        onNodesChange: (changes) => {
          set({ nodes: applyNodeChanges(changes, get().nodes) });
        },

        onEdgesChange: (changes: EdgeChange[]) => {
          set({ edges: applyEdgeChanges(changes, get().edges) });
        },

        // Identical to the `addEdge` action, so delegate instead of duplicating it.
        onConnect: (connection: Connection) => {
          get().addEdge(connection);
        },

        onSelectionChange: ({ nodes, edges }: OnSelectionChangeParams) => {
          set({
            selectedEdgeIds: edges.map((x) => x.id),
            selectedNodeIds: nodes.map((x) => x.id),
          });
        },

        setNodes: (nodes: RAGFlowNodeType[]) => {
          set({ nodes });
        },

        setEdges: (edges: Edge[]) => {
          set({ edges });
        },

        /**
         * Replaces the outgoing edges of `nodeId` with `currentDownstreamEdges`.
         * Nothing is written when both sets already match on
         * source/target/sourceHandle.
         */
        setEdgesByNodeId: (nodeId: string, currentDownstreamEdges: Edge[]) => {
          const { edges, setEdges } = get();
          const sameConnection = (a: Edge, b: Edge) =>
            a.source === b.source &&
            a.target === b.target &&
            a.sourceHandle === b.sourceHandle;
          // True when every edge of `needles` has a matching edge in `haystack`.
          const containsAll = (haystack: Edge[], needles: Edge[]) =>
            needles.every((x) => haystack.some((y) => sameConnection(y, x)));

          // The node's outgoing edges as currently stored.
          const previousDownstreamEdges = edges.filter(
            (x) => x.source === nodeId,
          );
          const isDifferent =
            previousDownstreamEdges.length !== currentDownstreamEdges.length ||
            !containsAll(currentDownstreamEdges, previousDownstreamEdges) ||
            !containsAll(previousDownstreamEdges, currentDownstreamEdges);
          if (!isDifferent) {
            return;
          }

          // Previous edges that survive, followed by the newly added ones.
          const intersectionDownstreamEdges = intersectionWith(
            previousDownstreamEdges,
            currentDownstreamEdges,
            isEdgeEqual,
          );
          const selfAddedDownstreamEdges = differenceWith(
            currentDownstreamEdges,
            intersectionDownstreamEdges,
            isEdgeEqual,
          );
          setEdges([
            // other operators' edges are kept untouched
            ...edges.filter((x) => x.source !== nodeId),
            ...intersectionDownstreamEdges,
            ...selfAddedDownstreamEdges,
          ]);
        },

        addNode: (node: RAGFlowNodeType) => {
          set({ nodes: get().nodes.concat(node) });
        },

        getNode: (id?: string | null) => {
          return get().nodes.find((x) => x.id === id);
        },

        getOperatorTypeFromId: (id?: string | null) => {
          return get().getNode(id)?.data?.label;
        },

        getParentIdById: (id?: string | null) => {
          return get().getNode(id)?.parentId;
        },

        addEdge: (connection: Connection) => {
          // Store the edge first: updateSwitchFormData (reached through
          // updateFormDataOnConnect) rebuilds Switch targets from stored edges.
          set({ edges: addEdge(connection, get().edges) });
          // Deleting the superseded edge clears the handle's form field, so it
          // must run before the new target is written below.
          get().deletePreviousEdgeOfClassificationNode(connection);
          // TODO: This may not be reasonable. You need to choose between listening to changes in the form.
          get().updateFormDataOnConnect(connection);
        },

        getEdge: (id: string) => {
          return get().edges.find((x) => x.id === id);
        },

        /** Records a new connection's target in the source node's form. */
        updateFormDataOnConnect: (connection: Connection) => {
          const { getOperatorTypeFromId, updateNodeForm, updateSwitchFormData } =
            get();
          const { source, target, sourceHandle } = connection;
          if (!source) {
            return;
          }
          switch (getOperatorTypeFromId(source)) {
            case Operator.Relevant:
              updateNodeForm(source, { [sourceHandle as string]: target });
              break;
            case Operator.Categorize:
              if (sourceHandle) {
                updateNodeForm(source, target, [
                  'category_description',
                  sourceHandle,
                  'to',
                ]);
              }
              break;
            case Operator.Switch:
              updateSwitchFormData(source, sourceHandle, target, true);
              break;
            default:
              break;
          }
        },

        /**
         * Categorize/Relevant anchors record a single target in the form, so
         * connecting an anchor to another node deletes the edge it had before.
         */
        deletePreviousEdgeOfClassificationNode: (connection: Connection) => {
          const { edges, getOperatorTypeFromId, deleteEdgeById } = get();
          // operators whose anchors hold a single target
          const anchoredNodes = [
            Operator.Categorize,
            Operator.Relevant,
            // Operator.Switch, (Switch handles store a list of targets)
          ];
          const operatorType = getOperatorTypeFromId(connection.source);
          if (!anchoredNodes.some((x) => x === operatorType)) {
            return;
          }
          const previousEdge = edges.find(
            (x) =>
              x.source === connection.source &&
              x.sourceHandle === connection.sourceHandle &&
              x.target !== connection.target,
          );
          if (previousEdge) {
            deleteEdgeById(previousEdge.id);
          }
        },

        /** Adds a copy of node `id`; unknown ids are ignored. */
        duplicateNode: (id: string, name: string) => {
          const { getNode, addNode, generateNodeName, duplicateIterationNode } =
            get();
          const node = getNode(id);
          // Without this guard a node made only of generated fields was added.
          if (!node) {
            return;
          }
          if (node.data.label === Operator.Iteration) {
            duplicateIterationNode(id, name);
            return;
          }
          addNode({
            ...node,
            data: {
              ...duplicateNodeForm(node.data),
              name: generateNodeName(name),
            },
            ...generateDuplicateNode(node.position, node.data.label),
          });
        },

        /**
         * Adds a copy of an iteration node plus copies of its child nodes.
         * Edges between the children are not copied. Unknown ids are ignored.
         */
        duplicateIterationNode: (id: string, name: string) => {
          const { getNode, generateNodeName, nodes } = get();
          const node = getNode(id);
          // Without this guard, `parentId === undefined` matched every
          // top-level node and all of them were cloned as children.
          if (!node) {
            return;
          }
          const iterationNode: RAGFlowNodeType = {
            ...node,
            data: {
              ...node.data,
              name: generateNodeName(name),
            },
            ...generateDuplicateNode(node.position, node.data.label),
          };
          // NOTE(review): child names are generated against the node list as it
          // was before this duplication, so sibling copies are not checked
          // against each other's new names — confirm
          // generateNodeNamesWithIncreasingIndex cannot return duplicates here.
          const children = nodes
            .filter((x) => x.parentId === node.id)
            .map((x) => ({
              ...x,
              data: {
                ...duplicateNodeForm(x.data),
                name: generateNodeName(x.data.name),
              },
              // children keep their original position; only the generated one is dropped
              ...omit(generateDuplicateNode(x.position, x.data.label), [
                'position',
              ]),
              parentId: iterationNode.id,
            }));
          set({ nodes: nodes.concat(iterationNode, ...children) });
        },

        /**
         * Removes the selected edges. Unlike deleteEdgeById, this does not clear
         * the form fields on their source nodes.
         */
        deleteEdge: () => {
          const { edges, selectedEdgeIds } = get();
          const selected = new Set(selectedEdgeIds);
          set({ edges: edges.filter((edge) => !selected.has(edge.id)) });
        },

        /** Removes an edge after clearing the form field that referenced its target. */
        deleteEdgeById: (id: string) => {
          const {
            edges,
            updateNodeForm,
            getOperatorTypeFromId,
            updateSwitchFormData,
          } = get();
          const currentEdge = edges.find((x) => x.id === id);
          if (currentEdge) {
            const { source, sourceHandle, target } = currentEdge;
            // After deleting the edge, set the corresponding field in the node's form field to undefined
            switch (getOperatorTypeFromId(source)) {
              case Operator.Relevant:
                updateNodeForm(source, {
                  [sourceHandle as string]: undefined,
                });
                break;
              case Operator.Categorize:
                if (sourceHandle) {
                  updateNodeForm(source, undefined, [
                    'category_description',
                    sourceHandle,
                    'to',
                  ]);
                }
                break;
              case Operator.Switch:
                // The edge is still stored at this point, so ask for `target` to be excluded.
                updateSwitchFormData(source, sourceHandle, target, false);
                break;
              default:
                break;
            }
          }
          set({ edges: edges.filter((edge) => edge.id !== id) });
        },

        deleteEdgeBySourceAndSourceHandle: ({
          source,
          sourceHandle,
        }: Partial<Connection>) => {
          const { edges } = get();
          set({
            edges: edges.filter(
              (edge) =>
                edge.source !== source || edge.sourceHandle !== sourceHandle,
            ),
          });
        },

        /** Removes a node and every edge that starts or ends at it. */
        deleteNodeById: (id: string) => {
          const { nodes, edges } = get();
          set({
            nodes: nodes.filter((node) => node.id !== id),
            edges: edges.filter(
              (edge) => edge.source !== id && edge.target !== id,
            ),
          });
        },

        /** Removes an iteration node, its children and every edge attached to any of them. */
        deleteIterationNodeById: (id: string) => {
          const { nodes, edges } = get();
          const removedIds = new Set<string>([
            id,
            ...nodes.filter((node) => node.parentId === id).map((node) => node.id),
          ]);
          set({
            nodes: nodes.filter((node) => node.id !== id && node.parentId !== id),
            // An edge touching any removed node (the iteration node or one of its
            // children) would otherwise be left dangling.
            edges: edges.filter(
              (edge) => !removedIds.has(edge.source) && !removedIds.has(edge.target),
            ),
          });
        },

        findNodeByName: (name: Operator) => {
          return get().nodes.find((x) => x.data.label === name);
        },

        /**
         * Merges `values` into the node's form (empty `path`) or writes it at
         * `path`, producing new objects along the way; returns the new nodes.
         */
        updateNodeForm: (
          nodeId: string,
          values: unknown,
          path: (string | number)[] = [],
        ) => {
          const nextNodes = get().nodes.map((node) => {
            if (node.id !== nodeId) {
              return node;
            }
            const nextForm: Record<string, unknown> = { ...node.data.form };
            if (path.length === 0) {
              Object.assign(nextForm, values);
            } else {
              // The spread above is shallow: copy the nested containers on the
              // path so lodashSet does not mutate the previous state's objects.
              copyContainersAlongPath(nextForm, path);
              lodashSet(nextForm, path, values);
            }
            // NOTE(review): cast kept from the original; the rebuilt form is untyped.
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
            return { ...node, data: { ...node.data, form: nextForm } } as any;
          });
          set({ nodes: nextNodes });
          return nextNodes;
        },

        /** Stores the current targets of one Switch handle in the Switch node's form. */
        updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
          if (!sourceHandle) {
            return;
          }
          const { updateNodeForm, edges } = get();
          // A handle will connect to multiple downstream nodes
          const handleTargets = edges
            .filter(
              (x) =>
                x.source === source &&
                x.sourceHandle === sourceHandle &&
                typeof x.target === 'string',
            )
            .map((x) => x.target);
          // When disconnecting, the edge being removed is still stored; drop its target.
          const targets: string[] =
            target && !isConnecting
              ? handleTargets.filter((x) => x !== target)
              : handleTargets;
          if (sourceHandle === SwitchElseTo) {
            updateNodeForm(source, targets, [SwitchElseTo]);
            return;
          }
          const operatorIndex = getOperatorIndex(sourceHandle);
          if (operatorIndex) {
            updateNodeForm(source, targets, [
              'conditions',
              Number(operatorIndex) - 1, // The index is the conditions form index
              'to',
            ]);
          }
        },

        /**
         * Writes `value` to `data.form[field]` by mutating the current nodes in
         * place; `set` is not called, so subscribers are not notified.
         */
        updateMutableNodeFormItem: (id: string, field: string, value: unknown) => {
          const { nodes } = get();
          const idx = nodes.findIndex((x) => x.id === id);
          // findIndex yields -1 when missing; 0 is a valid index.
          if (idx !== -1) {
            lodashSet(nodes, [idx, 'data', 'form', field], value);
          }
        },

        /** Renames a node by producing a new node object (no in-place mutation). */
        updateNodeName: (id, name) => {
          if (!id) {
            return;
          }
          set({
            nodes: get().nodes.map(
              (node): RAGFlowNodeType =>
                node.id === id ? { ...node, data: { ...node.data, name } } : node,
            ),
          });
        },

        setClickedNodeId: (id?: string) => {
          // keep the field a string, matching its '' initial value
          set({ clickedNodeId: id ?? '' });
        },

        generateNodeName: (name: string) => {
          const { nodes } = get();
          return generateNodeNamesWithIncreasingIndex(name, nodes);
        },

        setClickedToolId: (id?: string) => {
          // keep the field a string, matching its '' initial value
          set({ clickedToolId: id ?? '' });
        },
      })),
      { name: 'graph', trace: true },
    ),
  );

  export default useGraphStore;