You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. import {
  2. Connection,
  3. Edge,
  4. getOutgoers,
  5. Node,
  6. Position,
  7. ReactFlowInstance,
  8. } from '@xyflow/react';
  9. import React, { useCallback, useEffect, useMemo, useState } from 'react';
  10. // import { shallow } from 'zustand/shallow';
  11. import { settledModelVariableMap } from '@/constants/knowledge';
  12. import { useFetchModelId } from '@/hooks/logic-hooks';
  13. import { RAGFlowNodeType } from '@/interfaces/database/flow';
  14. import { humanId } from 'human-id';
  15. import { get, lowerFirst, omit } from 'lodash';
  16. import { UseFormReturn } from 'react-hook-form';
  17. import { useTranslation } from 'react-i18next';
  18. import {
  19. initialAgentValues,
  20. initialAkShareValues,
  21. initialArXivValues,
  22. initialBaiduFanyiValues,
  23. initialBaiduValues,
  24. initialBeginValues,
  25. initialBingValues,
  26. initialCategorizeValues,
  27. initialCodeValues,
  28. initialConcentratorValues,
  29. initialCrawlerValues,
  30. initialDeepLValues,
  31. initialDuckValues,
  32. initialEmailValues,
  33. initialExeSqlValues,
  34. initialGithubValues,
  35. initialGoogleScholarValues,
  36. initialGoogleValues,
  37. initialInvokeValues,
  38. initialIterationValues,
  39. initialJin10Values,
  40. initialKeywordExtractValues,
  41. initialMessageValues,
  42. initialNoteValues,
  43. initialPubMedValues,
  44. initialQWeatherValues,
  45. initialRelevantValues,
  46. initialRetrievalValues,
  47. initialRewriteQuestionValues,
  48. initialStringTransformValues,
  49. initialSwitchValues,
  50. initialTavilyExtractValues,
  51. initialTavilyValues,
  52. initialTuShareValues,
  53. initialUserFillUpValues,
  54. initialWaitingDialogueValues,
  55. initialWenCaiValues,
  56. initialWikipediaValues,
  57. initialYahooFinanceValues,
  58. NodeMap,
  59. Operator,
  60. RestrictedUpstreamMap,
  61. } from './constant';
  62. import useGraphStore, { RFState } from './store';
  63. import {
  64. buildCategorizeObjectFromList,
  65. generateNodeNamesWithIncreasingIndex,
  66. getNodeDragHandle,
  67. getRelativePositionToIterationNode,
  68. replaceIdWithText,
  69. } from './utils';
  /**
   * Picks, from the graph store state, the data and handlers the canvas
   * component consumes: the node/edge lists plus their change, connection,
   * selection and edge-hover callbacks.
   */
  const selector = ({
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
    onEdgeMouseEnter,
    onEdgeMouseLeave,
  }: RFState) => ({
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
    onEdgeMouseEnter,
    onEdgeMouseLeave,
  });
  /**
   * Subscribes to the graph store and returns the canvas-facing slice chosen
   * by `selector` (nodes, edges and their interaction handlers).
   *
   * NOTE(review): `selector` builds a fresh object on every call and no
   * equality function is supplied. Depending on the zustand version, this may
   * re-render consumers on every store update. A previous attempt with
   * `useShallow(selector)` reportedly threw, and a `shallow`-comparator variant
   * was left disabled. Confirm before changing the subscription.
   */
  export const useSelectCanvasData = () => useGraphStore(selector);
  /**
   * Returns a lookup `(operatorName) => initial form values`, used when a new
   * operator node is created.
   *
   * The defaults for Categorize, Relevant, RewriteQuestion, KeywordExtract,
   * ExeSQL and Agent are copied with `llm_id` set to the id returned by
   * `useFetchModelId`. The lookup is rebuilt whenever that id changes.
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId();

    const initialFormValuesMap = useMemo(() => {
      // Shallow-copies an operator's defaults with the current model id set.
      const withLlmId = <T extends object>(defaults: T) => ({
        ...defaults,
        llm_id: llmId,
      });

      return {
        [Operator.Begin]: initialBeginValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Categorize]: withLlmId(initialCategorizeValues),
        [Operator.Relevant]: withLlmId(initialRelevantValues),
        [Operator.RewriteQuestion]: withLlmId(initialRewriteQuestionValues),
        [Operator.Message]: initialMessageValues,
        [Operator.KeywordExtract]: withLlmId(initialKeywordExtractValues),
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        [Operator.ArXiv]: initialArXivValues,
        [Operator.Google]: initialGoogleValues,
        [Operator.Bing]: initialBingValues,
        [Operator.GoogleScholar]: initialGoogleScholarValues,
        [Operator.DeepL]: initialDeepLValues,
        [Operator.GitHub]: initialGithubValues,
        [Operator.BaiduFanyi]: initialBaiduFanyiValues,
        [Operator.QWeather]: initialQWeatherValues,
        [Operator.ExeSQL]: withLlmId(initialExeSqlValues),
        [Operator.Switch]: initialSwitchValues,
        [Operator.WenCai]: initialWenCaiValues,
        [Operator.AkShare]: initialAkShareValues,
        [Operator.YahooFinance]: initialYahooFinanceValues,
        [Operator.Jin10]: initialJin10Values,
        [Operator.Concentrator]: initialConcentratorValues,
        [Operator.TuShare]: initialTuShareValues,
        [Operator.Note]: initialNoteValues,
        [Operator.Crawler]: initialCrawlerValues,
        [Operator.Invoke]: initialInvokeValues,
        [Operator.Email]: initialEmailValues,
        [Operator.Iteration]: initialIterationValues,
        // NOTE(review): IterationStart reuses the Iteration defaults — confirm intended.
        [Operator.IterationStart]: initialIterationValues,
        [Operator.Code]: initialCodeValues,
        [Operator.WaitingDialogue]: initialWaitingDialogueValues,
        [Operator.Agent]: withLlmId(initialAgentValues),
        [Operator.TavilySearch]: initialTavilyValues,
        [Operator.TavilyExtract]: initialTavilyExtractValues,
        [Operator.Tool]: {},
        [Operator.UserFillUp]: initialUserFillUpValues,
        [Operator.StringTransform]: initialStringTransformValues,
      };
    }, [llmId]);

    return useCallback(
      (operatorName: Operator) => initialFormValuesMap[operatorName],
      [initialFormValuesMap],
    );
  };
  /**
   * Provides `handleDragStart(operatorId)`, a factory for palette `dragstart`
   * listeners.
   *
   * Each listener stores the operator id under the
   * `application/@xyflow/react` data type, which the canvas drop handler reads
   * back, and allows only a "move" drag effect.
   */
  export const useHandleDrag = () => {
    const handleDragStart = useCallback(function createDragStartListener(
      operatorId: string,
    ) {
      return function onDragStart(ev: React.DragEvent<HTMLDivElement>) {
        const { dataTransfer } = ev;
        dataTransfer.setData('application/@xyflow/react', operatorId);
        dataTransfer.effectAllowed = 'move';
      };
    }, []);

    return { handleDragStart };
  };
  /**
   * Returns a function mapping an operator type (e.g. `'Retrieval'`) to its
   * localized display name, looked up under the `flow.<lowerFirst(type)>`
   * translation key.
   *
   * The returned function is memoized on `t`, so its identity stays stable
   * across renders. Callers list it in their own `useCallback` dependency
   * arrays; previously a fresh function on every render invalidated their
   * memoization each time.
   */
  export const useGetNodeName = () => {
    const { t } = useTranslation();

    return useCallback((type: string) => t(`flow.${lowerFirst(type)}`), [t]);
  };
  /**
   * Wires HTML5 drag-and-drop from the operator palette onto the canvas.
   *
   * Returns:
   * - `onDragOver`: permits dropping and shows the "move" cursor.
   * - `onDrop`: reads the operator type stored under
   *   `application/@xyflow/react`, converts the drop point to flow
   *   coordinates and adds a node with freshly initialized form values.
   *   - Dropping an Iteration also adds its IterationStart child.
   *   - Dropping any other operator where `getRelativePositionToIterationNode`
   *     reports an enclosing iteration makes it a child of that iteration.
   * - `setReactFlowInstance` / `reactFlowInstance`: the instance provides
   *   `screenToFlowPosition`. Until it is set, new nodes land at (0, 0).
   */
  export const useHandleDrop = () => {
    const addNode = useGraphStore((state) => state.addNode);
    const nodes = useGraphStore((state) => state.nodes);
    const [reactFlowInstance, setReactFlowInstance] =
      useState<ReactFlowInstance<any, any>>();
    const initializeOperatorParams = useInitializeOperatorParams();
    const getNodeName = useGetNodeName();

    const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
      event.preventDefault();
      event.dataTransfer.dropEffect = 'move';
    }, []);

    const onDrop = useCallback(
      (event: React.DragEvent<HTMLDivElement>) => {
        event.preventDefault();

        // `getData` returns '' (never undefined) when the drag carries no data
        // of this type, so a falsy check is enough to reject foreign drops.
        const type = event.dataTransfer.getData('application/@xyflow/react');
        if (!type) {
          return;
        }

        // reactFlowInstance.project was renamed to screenToFlowPosition, and
        // reactFlowBounds.left/top no longer need to be subtracted.
        // details: https://reactflow.dev/whats-new/2023-11-10
        const position = reactFlowInstance?.screenToFlowPosition({
          x: event.clientX,
          y: event.clientY,
        });

        const newNode: Node<any> = {
          id: `${type}:${humanId()}`,
          type: NodeMap[type as Operator] || 'ragNode',
          position: position || { x: 0, y: 0 },
          data: {
            label: `${type}`,
            name: generateNodeNamesWithIncreasingIndex(getNodeName(type), nodes),
            form: initializeOperatorParams(type as Operator),
          },
          sourcePosition: Position.Right,
          targetPosition: Position.Left,
          dragHandle: getNodeDragHandle(type),
        };

        if (type === Operator.Iteration) {
          // An iteration is a fixed-size container created together with its
          // IterationStart child; `extent: 'parent'` keeps the child inside it.
          newNode.width = 500;
          newNode.height = 250;
          const iterationStartNode: Node<any> = {
            id: `${Operator.IterationStart}:${humanId()}`,
            type: 'iterationStartNode',
            position: { x: 50, y: 100 },
            // draggable: false,
            data: {
              label: Operator.IterationStart,
              name: Operator.IterationStart,
              form: {},
            },
            parentId: newNode.id,
            extent: 'parent',
          };
          addNode(newNode);
          addNode(iterationStartNode);
          return;
        }

        // When dropped onto an iteration, re-parent the node and use the
        // position the helper returns (presumably relative to that iteration).
        const subNodeOfIteration = getRelativePositionToIterationNode(
          nodes,
          position,
        );
        if (subNodeOfIteration) {
          newNode.parentId = subNodeOfIteration.parentId;
          newNode.position = subNodeOfIteration.position;
          newNode.extent = 'parent';
        }
        addNode(newNode);
      },
      [reactFlowInstance, getNodeName, nodes, initializeOperatorParams, addNode],
    );

    return { onDrop, onDragOver, setReactFlowInstance, reactFlowInstance };
  };
  /**
   * Merges a preset from `settledModelVariableMap` into form values.
   *
   * When `parameter` (the model "freedom" selector) names an entry of
   * `settledModelVariableMap`, returns a new object with that entry's fields
   * spread over `values`. Otherwise returns `values` unchanged. This fixes the
   * related form fields not following the selected preset.
   */
  const mergeSettledModelVariables = (values: any, parameter: any): any =>
    parameter in settledModelVariableMap
      ? {
          ...values,
          ...settledModelVariableMap[
            parameter as keyof typeof settledModelVariableMap
          ],
        }
      : values;

  /**
   * Keeps a node's stored form (via the store's `updateNodeForm`) in sync with
   * its editor.
   *
   * @param operatorName - Operator type of the node. For Categorize, edits to
   *   an `items.<n>.name` field rewrite the stored form so that `items` is
   *   replaced by `category_description` built from the list.
   * @param id - Id of the node to update; nothing is written when absent.
   * @param form - Optional react-hook-form instance. When given, its changes
   *   are watched and those carrying an event `type` (user-triggered) are
   *   written back.
   * @returns `handleValuesChange(changedValues, values)`, a callback-style
   *   change handler applying the same preset merge.
   */
  export const useHandleFormValuesChange = (
    operatorName: Operator,
    id?: string,
    form?: UseFormReturn,
  ) => {
    const updateNodeForm = useGraphStore((state) => state.updateNodeForm);

    const handleValuesChange = useCallback(
      (changedValues: any, values: any) => {
        // The preset merge applies only when `parameter` is the sole change.
        const onlyParameterChanged =
          Object.keys(changedValues).length === 1 &&
          'parameter' in changedValues;
        const nextValues = onlyParameterChanged
          ? mergeSettledModelVariables(values, changedValues['parameter'])
          : values;
        if (id) {
          updateNodeForm(id, nextValues);
        }
      },
      [updateNodeForm, id],
    );

    useEffect(() => {
      // No `g` flag: `test` on a global regex is stateful via `lastIndex`.
      const categoryItemNamePattern = /items\.\d+\.name/;

      const subscription = form?.watch((value, { name, type }) => {
        if (!id || !name) {
          return;
        }

        let nextValues: any =
          name === 'parameter'
            ? mergeSettledModelVariables(value, value['parameter'])
            : value;

        if (
          operatorName === Operator.Categorize &&
          categoryItemNamePattern.test(name)
        ) {
          nextValues = {
            ...omit(value, 'items'),
            category_description: buildCategorizeObjectFromList(value.items),
          };
        }

        // Manually triggered form updates are synchronized to the canvas.
        if (type) {
          updateNodeForm(id, nextValues);
        }
      });
      return () => subscription?.unsubscribe();
    }, [form, form?.watch, id, operatorName, updateNodeForm]);

    return { handleValuesChange };
  };
  /**
   * Returns `isValidConnection(connection)` for the canvas.
   *
   * A connection is accepted only when all of these hold:
   * - it is not a self-loop;
   * - the target's operator type is not listed under the source's operator
   *   type in `RestrictedUpstreamMap`;
   * - both ends share the same parent node (or neither has one);
   * - adding it does not create a cycle in the graph.
   */
  export const useValidateConnection = () => {
    // Select members individually. Subscribing to the whole state object
    // re-renders the caller on every store update, including unrelated ones.
    const getOperatorTypeFromId = useGraphStore(
      (state) => state.getOperatorTypeFromId,
    );
    const getParentIdById = useGraphStore((state) => state.getParentIdById);
    const edges = useGraphStore((state) => state.edges);
    const nodes = useGraphStore((state) => state.nodes);

    // Children of a parent (e.g. an iteration) may only link to siblings;
    // top-level nodes may only link to top-level nodes.
    const isSameNodeChild = useCallback(
      (connection: Connection | Edge) => {
        const sourceParentId = getParentIdById(connection.source);
        const targetParentId = getParentIdById(connection.target);
        if (sourceParentId || targetParentId) {
          return sourceParentId === targetParentId;
        }
        return true;
      },
      [getParentIdById],
    );

    // True when adding `connection` keeps the graph acyclic. Also false when
    // the target node is unknown or equals the source.
    const isAcyclicConnection = useCallback(
      (connection: Connection | Edge): boolean => {
        const target = nodes.find((node) => node.id === connection.target);
        if (!target || target.id === connection.source) {
          return false;
        }

        // Depth-first walk downstream of `node`. Reaching the connection's
        // source means the new edge would close a loop.
        const reachesSource = (
          node: RAGFlowNodeType,
          visited: Set<string> = new Set(),
        ): boolean => {
          if (visited.has(node.id)) return false;
          visited.add(node.id);
          for (const outgoer of getOutgoers(node, nodes, edges)) {
            if (
              outgoer.id === connection.source ||
              reachesSource(outgoer, visited)
            ) {
              return true;
            }
          }
          return false;
        };

        return !reachesSource(target);
      },
      [edges, nodes],
    );

    // Restricted lines cannot be connected successfully.
    const isValidConnection = useCallback(
      (connection: Connection | Edge): boolean => {
        // A node cannot connect to itself.
        if (connection.target === connection.source) {
          return false;
        }

        const restrictedTargets =
          RestrictedUpstreamMap[
            getOperatorTypeFromId(connection.source) as Operator
          ];
        // NOTE(review): a source type with no entry in RestrictedUpstreamMap
        // rejects every connection (unchanged behavior) — confirm intended.
        const targetAllowed =
          restrictedTargets?.every(
            (x) => x !== getOperatorTypeFromId(connection.target),
          ) ?? false;

        return (
          targetAllowed &&
          isSameNodeChild(connection) &&
          isAcyclicConnection(connection)
        );
      },
      [getOperatorTypeFromId, isAcyclicConnection, isSameNodeChild],
    );

    return isValidConnection;
  };
  /**
   * Returns a callback mapping a node id to that node's `data.name`.
   * It yields `undefined` when the store's `getNode` finds no node for the id.
   */
  export const useReplaceIdWithName = () => {
    const getNode = useGraphStore((state) => state.getNode);

    return useCallback(
      (id?: string) => {
        const node = getNode(id);
        return node?.data.name;
      },
      [getNode],
    );
  };
  /**
   * Rewrites node ids in `output` into node names via `replaceIdWithText`.
   *
   * @param output - Value to rewrite. Its accepted shape is whatever
   *   `replaceIdWithText` in ./utils supports.
   * @returns `replacedOutput` (the rewritten value) and `getNameById` (the
   *   id-to-name resolver used for the rewrite).
   */
  export const useReplaceIdWithText = (output: unknown) => {
    const getNameById = useReplaceIdWithName();
    const replacedOutput = replaceIdWithText(output, getNameById);

    return { replacedOutput, getNameById };
  };
  /**
   * Returns `duplicateNode(id, label)`. It delegates to the store's
   * `duplicateNode`, passing the node id and the localized display name of
   * the operator type `label`.
   */
  export const useDuplicateNode = () => {
    const duplicateNodeById = useGraphStore((store) => store.duplicateNode);
    const getNodeName = useGetNodeName();

    return useCallback(
      (id: string, label: string) => {
        const localizedName = getNodeName(label);
        duplicateNodeById(id, localizedName);
      },
      [duplicateNodeById, getNodeName],
    );
  };
  /** Clipboard data type under which copied canvas nodes are stored. */
  const agentNodesClipboardType = 'agent:nodes';

  /**
   * True when a clipboard event targets `<body>`, i.e. no input or editor has
   * focus. Events aimed elsewhere are left to the browser.
   */
  const isClipboardEventOnBody = (event: ClipboardEvent) =>
    get(event, 'target.tagName') === 'BODY';

  /**
   * Parses the node list written by the copy handler.
   * Returns [] for missing, malformed, or unexpected payloads instead of
   * throwing.
   */
  const parseCopiedAgentNodes = (raw: string | undefined): RAGFlowNodeType[] => {
    if (!raw) return [];
    let parsed: unknown;
    try {
      parsed = JSON.parse(raw);
    } catch {
      return [];
    }
    if (!Array.isArray(parsed)) return [];
    // Only entries with the fields the paste handler reads are kept.
    return parsed.filter(
      (n): n is RAGFlowNodeType =>
        typeof n?.id === 'string' && typeof n?.data?.label === 'string',
    );
  };

  /**
   * Registers window-level copy/paste handlers for canvas nodes.
   *
   * - Copy: serializes the selected nodes (excluding Begin) to the clipboard.
   * - Paste: duplicates each previously copied node through `useDuplicateNode`.
   *
   * Both handlers act only when the event targets `<body>`. Copying or pasting
   * inside form fields therefore keeps its normal browser behavior; previously,
   * paste lacked this guard and cloned nodes while the user was typing.
   */
  export const useCopyPaste = () => {
    const nodes = useGraphStore((state) => state.nodes);
    const duplicateNode = useDuplicateNode();

    const onCopyCapture = useCallback(
      (event: ClipboardEvent) => {
        if (!isClipboardEventOnBody(event)) return;
        event.preventDefault();
        const copiedNodes = nodes.filter(
          (n) => n.selected && n.data.label !== Operator.Begin,
        );
        event.clipboardData?.setData(
          agentNodesClipboardType,
          JSON.stringify(copiedNodes),
        );
      },
      [nodes],
    );

    const onPasteCapture = useCallback(
      (event: ClipboardEvent) => {
        if (!isClipboardEventOnBody(event)) return;
        const pastedNodes = parseCopiedAgentNodes(
          event.clipboardData?.getData(agentNodesClipboardType),
        );
        if (pastedNodes.length === 0) return;
        event.preventDefault();
        pastedNodes.forEach((n) => {
          duplicateNode(n.id, n.data.label);
        });
      },
      [duplicateNode],
    );

    useEffect(() => {
      window.addEventListener('copy', onCopyCapture);
      return () => {
        window.removeEventListener('copy', onCopyCapture);
      };
    }, [onCopyCapture]);

    useEffect(() => {
      window.addEventListener('paste', onPasteCapture);
      return () => {
        window.removeEventListener('paste', onPasteCapture);
      };
    }, [onPasteCapture]);
  };