Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. import React, {
  2. ChangeEvent,
  3. useCallback,
  4. useEffect,
  5. useMemo,
  6. useState,
  7. } from 'react';
  8. import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
  9. // import { shallow } from 'zustand/shallow';
  10. import { variableEnabledFieldMap } from '@/constants/chat';
  11. import {
  12. ModelVariableType,
  13. settledModelVariableMap,
  14. } from '@/constants/knowledge';
  15. import { useFetchModelId } from '@/hooks/logic-hooks';
  16. import { Variable } from '@/interfaces/database/chat';
  17. import { FormInstance, message } from 'antd';
  18. import { humanId } from 'human-id';
  19. import { get, isEmpty, lowerFirst, pick } from 'lodash';
  20. import trim from 'lodash/trim';
  21. import { useTranslation } from 'react-i18next';
  22. import { v4 as uuid } from 'uuid';
  23. import {
  24. NodeMap,
  25. Operator,
  26. RestrictedUpstreamMap,
  27. SwitchElseTo,
  28. initialAkShareValues,
  29. initialArXivValues,
  30. initialBaiduFanyiValues,
  31. initialBaiduValues,
  32. initialBeginValues,
  33. initialBingValues,
  34. initialCategorizeValues,
  35. initialConcentratorValues,
  36. initialCrawlerValues,
  37. initialDeepLValues,
  38. initialDuckValues,
  39. initialEmailValues,
  40. initialExeSqlValues,
  41. initialGenerateValues,
  42. initialGithubValues,
  43. initialGoogleScholarValues,
  44. initialGoogleValues,
  45. initialInvokeValues,
  46. initialIterationValues,
  47. initialJin10Values,
  48. initialKeywordExtractValues,
  49. initialMessageValues,
  50. initialNoteValues,
  51. initialPubMedValues,
  52. initialQWeatherValues,
  53. initialRelevantValues,
  54. initialRetrievalValues,
  55. initialRewriteQuestionValues,
  56. initialSwitchValues,
  57. initialTemplateValues,
  58. initialTuShareValues,
  59. initialWenCaiValues,
  60. initialWikipediaValues,
  61. initialYahooFinanceValues,
  62. } from './constant';
  63. import { ICategorizeForm, IRelevantForm, ISwitchForm } from './interface';
  64. import useGraphStore, { RFState } from './store';
  65. import {
  66. generateNodeNamesWithIncreasingIndex,
  67. generateSwitchHandleText,
  68. getNodeDragHandle,
  69. getRelativePositionToIterationNode,
  70. replaceIdWithText,
  71. } from './utils';
  /**
   * Picks the slice of graph-store state the canvas needs: the node and edge
   * arrays plus the change handlers / setters handed to React Flow.
   */
  const selector = ({
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
  }: RFState) => ({
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
  });
  /**
   * Subscribes the calling component to the canvas slice of the graph store
   * (nodes, edges and their React Flow handlers) as picked by `selector`.
   *
   * NOTE(review): `selector` returns a new object on every call and no equality
   * function is passed, so consumers likely re-render on every store update.
   * The shallow-compare variants below were left commented out by the author
   * (one is marked as throwing); kept for reference.
   */
  export const useSelectCanvasData = () => {
    // return useStore(useShallow(selector)); // throw error
    // return useStore(selector, shallow);
    const canvasData = useGraphStore(selector);
    return canvasData;
  };
  /**
   * Returns a lookup `(operatorName) => initialFormValues` for newly created nodes.
   *
   * LLM-backed operators (Generate, Categorize, Relevant, RewriteQuestion,
   * KeywordExtract) get a copy of their defaults with `llm_id` set to the model
   * id from `useFetchModelId`; the table is rebuilt only when that id changes.
   * Answer starts with an empty form, and IterationStart reuses the Iteration
   * defaults.
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId();

    const defaultsByOperator = useMemo(
      () => ({
        [Operator.Begin]: initialBeginValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
        [Operator.Answer]: {},
        [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
        [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
        [Operator.RewriteQuestion]: {
          ...initialRewriteQuestionValues,
          llm_id: llmId,
        },
        [Operator.Message]: initialMessageValues,
        [Operator.KeywordExtract]: {
          ...initialKeywordExtractValues,
          llm_id: llmId,
        },
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        [Operator.ArXiv]: initialArXivValues,
        [Operator.Google]: initialGoogleValues,
        [Operator.Bing]: initialBingValues,
        [Operator.GoogleScholar]: initialGoogleScholarValues,
        [Operator.DeepL]: initialDeepLValues,
        [Operator.GitHub]: initialGithubValues,
        [Operator.BaiduFanyi]: initialBaiduFanyiValues,
        [Operator.QWeather]: initialQWeatherValues,
        [Operator.ExeSQL]: initialExeSqlValues,
        [Operator.Switch]: initialSwitchValues,
        [Operator.WenCai]: initialWenCaiValues,
        [Operator.AkShare]: initialAkShareValues,
        [Operator.YahooFinance]: initialYahooFinanceValues,
        [Operator.Jin10]: initialJin10Values,
        [Operator.Concentrator]: initialConcentratorValues,
        [Operator.TuShare]: initialTuShareValues,
        [Operator.Note]: initialNoteValues,
        [Operator.Crawler]: initialCrawlerValues,
        [Operator.Invoke]: initialInvokeValues,
        [Operator.Template]: initialTemplateValues,
        [Operator.Email]: initialEmailValues,
        [Operator.Iteration]: initialIterationValues,
        [Operator.IterationStart]: initialIterationValues,
      }),
      [llmId],
    );

    return useCallback(
      (operatorName: Operator) => defaultsByOperator[operatorName],
      [defaultsByOperator],
    );
  };
  /**
   * Drag-start support for the operator palette.
   *
   * `handleDragStart(operatorId)` returns a `dragstart` listener that stores the
   * operator id under the `application/reactflow` data type and restricts the
   * drag to a "move" operation.
   */
  export const useHandleDrag = () => {
    const handleDragStart = useCallback(
      (operatorId: string) =>
        ({ dataTransfer }: React.DragEvent<HTMLDivElement>) => {
          dataTransfer.setData('application/reactflow', operatorId);
          dataTransfer.effectAllowed = 'move';
        },
      [],
    );

    return { handleDragStart };
  };
  /**
   * Returns a function mapping an operator type (e.g. `Retrieval`) to its
   * localized display name, looked up under the i18n key `flow.<lowerFirst(type)>`.
   *
   * The returned function is memoized on `t`, so its identity is stable across
   * renders. Callers such as `useHandleDrop` and `useDuplicateNode` list it in
   * their own `useCallback` dependency arrays. A fresh function on every render
   * (the previous behaviour) silently defeated that memoization.
   */
  export const useGetNodeName = () => {
    const { t } = useTranslation();

    return useCallback((type: string) => t(`flow.${lowerFirst(type)}`), [t]);
  };
  /**
   * Drag-and-drop handling for adding operators to the React Flow canvas.
   *
   * Returns:
   * - `onDragOver`: accepts the drag and shows a "move" drop effect.
   * - `onDrop`: reads the operator type stored by `useHandleDrag`, projects the
   *   pointer to flow coordinates and adds a node with that operator's initial
   *   form. Dropping an `Iteration` operator also adds its child
   *   `IterationStart` node. Dropping any other operator where
   *   `getRelativePositionToIterationNode` finds an iteration node makes it a
   *   child of that node.
   * - `setReactFlowInstance`: must receive the React Flow instance so drop
   *   positions can be projected; without it the node lands at (0, 0).
   */
  export const useHandleDrop = () => {
    const addNode = useGraphStore((state) => state.addNode);
    const nodes = useGraphStore((state) => state.nodes);
    const [reactFlowInstance, setReactFlowInstance] =
      useState<ReactFlowInstance<any, any>>();
    const initializeOperatorParams = useInitializeOperatorParams();
    const getNodeName = useGetNodeName();

    const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
      event.preventDefault();
      event.dataTransfer.dropEffect = 'move';
    }, []);

    const onDrop = useCallback(
      (event: React.DragEvent<HTMLDivElement>) => {
        event.preventDefault();

        const type = event.dataTransfer.getData('application/reactflow');
        // getData yields '' when nothing was stored under this type, i.e. the
        // dragged item did not originate from the operator palette.
        if (!type) {
          return;
        }

        // screenToFlowPosition supersedes the old `project` API and already
        // accounts for the canvas bounds.
        // details: https://reactflow.dev/whats-new/2023-11-10
        const position = reactFlowInstance?.screenToFlowPosition({
          x: event.clientX,
          y: event.clientY,
        });

        const operator = type as Operator;
        const newNode: Node<any> = {
          id: `${type}:${humanId()}`,
          type: NodeMap[operator] || 'ragNode',
          position: position || { x: 0, y: 0 },
          data: {
            label: `${type}`,
            name: generateNodeNamesWithIncreasingIndex(getNodeName(type), nodes),
            form: initializeOperatorParams(operator),
          },
          sourcePosition: Position.Right,
          targetPosition: Position.Left,
          dragHandle: getNodeDragHandle(type),
        };

        if (type !== Operator.Iteration) {
          // Regular operator: adopt an iteration node as parent when the drop
          // point falls inside one (position is then relative to that parent).
          const subNodeOfIteration = getRelativePositionToIterationNode(
            nodes,
            position,
          );
          if (subNodeOfIteration) {
            newNode.parentId = subNodeOfIteration.parentId;
            newNode.position = subNodeOfIteration.position;
            newNode.extent = 'parent';
          }
          addNode(newNode);
          return;
        }

        // Iteration operator: a sized container that always starts with an
        // IterationStart child confined to its bounds.
        newNode.style = { width: 500, height: 250 };
        const iterationStartNode: Node<any> = {
          id: `${Operator.IterationStart}:${humanId()}`,
          type: 'iterationStartNode',
          position: { x: 50, y: 100 },
          // draggable: false,
          data: {
            label: Operator.IterationStart,
            name: Operator.IterationStart,
            form: {},
          },
          parentId: newNode.id,
          extent: 'parent',
        };
        // The parent is added before its child; React Flow expects parent nodes
        // to precede their children in the nodes array (assuming addNode appends).
        addNode(newNode);
        addNode(iterationStartNode);
      },
      [reactFlowInstance, getNodeName, nodes, initializeOperatorParams, addNode],
    );

    return { onDrop, onDragOver, setReactFlowInstance };
  };
  /**
   * Returns an antd `onValuesChange` handler that stores the form values of
   * node `id` in the graph store via `updateNodeForm`.
   *
   * When the only changed field is `parameter` and its new value is a key of
   * `settledModelVariableMap` (a model "freedom" preset), the preset's values
   * are merged over `values` before storing. Without this, the related fields
   * would not follow the selected preset. Nothing is stored when `id` is absent.
   */
  export const useHandleFormValuesChange = (id?: string) => {
    const updateNodeForm = useGraphStore((state) => state.updateNodeForm);

    const handleValuesChange = useCallback(
      (changedValues: any, values: any) => {
        const onlyPresetChanged =
          Object.keys(changedValues).length === 1 &&
          'parameter' in changedValues &&
          changedValues['parameter'] in settledModelVariableMap;

        const nextValues: any = onlyPresetChanged
          ? {
              ...values,
              ...settledModelVariableMap[
                changedValues['parameter'] as keyof typeof settledModelVariableMap
              ],
            }
          : values;

        if (!id) {
          return;
        }
        updateNodeForm(id, nextValues);
      },
      [updateNodeForm, id],
    );

    return { handleValuesChange };
  };
  /**
   * Seeds the LLM-setting fields of an antd form from a node's stored form data.
   *
   * `variableEnabledFieldMap`'s keys are the boolean switch fields and its
   * values name the setting fields picked from `formData`. Each switch is set
   * as follows:
   * - to `true` when no setting is stored yet;
   * - otherwise to whether the stored value is truthy.
   *
   * The setting fields themselves are set to the stored values, or to the
   * `ModelVariableType.Precise` preset when nothing is stored.
   *
   * @param form - antd form instance to write into; nothing is written when undefined.
   * @param formData - the node's stored form values.
   */
  export const useSetLlmSetting = (
    form?: FormInstance,
    formData?: Record<string, any>,
  ) => {
    // Memoized so the effect below re-runs only when `formData` changes.
    // `pick` returns a new object on every call. Used directly as a dependency,
    // it made the effect (and `setFieldsValue`) run on every render, which can
    // reset fields the user is editing.
    const initialLlmSetting = useMemo(
      () => pick(formData, Object.values(variableEnabledFieldMap)),
      [formData],
    );

    useEffect(() => {
      const switchBoxValues = Object.keys(variableEnabledFieldMap).reduce<
        Record<string, boolean>
      >((pre, field) => {
        pre[field] = isEmpty(initialLlmSetting)
          ? true
          : !!initialLlmSetting[
              variableEnabledFieldMap[
                field as keyof typeof variableEnabledFieldMap
              ] as keyof Variable
            ];
        return pre;
      }, {});

      let otherValues = settledModelVariableMap[ModelVariableType.Precise];
      if (!isEmpty(initialLlmSetting)) {
        otherValues = initialLlmSetting;
      }

      form?.setFieldsValue({
        ...switchBoxValues,
        ...otherValues,
      });
    }, [form, initialLlmSetting]);
  };
  /**
   * Returns the `isValidConnection` predicate for the canvas.
   *
   * A connection is rejected when it:
   * - connects a node to itself;
   * - duplicates an existing edge with the same source and target;
   * - targets an operator type listed in `RestrictedUpstreamMap` for the
   *   source's operator type;
   * - crosses a container boundary. Source and target must either share the
   *   same parent id or both have none.
   *
   * NOTE(review): when the source's operator type has no entry in
   * `RestrictedUpstreamMap`, `?.every` evaluates to `undefined` and the
   * connection is rejected (falsy result). Confirm this is intended rather than
   * meaning "no restrictions".
   */
  export const useValidateConnection = () => {
    const { edges, getOperatorTypeFromId, getParentIdById } = useGraphStore(
      (state) => state,
    );

    const isSameNodeChild = useCallback(
      ({ source, target }: Connection) => {
        const sourceParentId = getParentIdById(source);
        const targetParentId = getParentIdById(target);
        // Two top-level nodes may always connect; otherwise parents must match.
        return sourceParentId || targetParentId
          ? sourceParentId === targetParentId
          : true;
      },
      [getParentIdById],
    );

    // restricted lines cannot be connected successfully.
    const isValidConnection = useCallback(
      (connection: Connection) => {
        const { source, target } = connection;

        if (source === target) {
          return false;
        }

        // At most one edge per direction between the same two nodes.
        const alreadyConnected = edges.some(
          (edge) => edge.source === source && edge.target === target,
        );
        if (alreadyConnected) {
          return false;
        }

        const targetAllowed = RestrictedUpstreamMap[
          getOperatorTypeFromId(source) as Operator
        ]?.every((restricted) => restricted !== getOperatorTypeFromId(target));

        return targetAllowed && isSameNodeChild(connection);
      },
      [edges, getOperatorTypeFromId, isSameNodeChild],
    );

    return isValidConnection;
  };
  /**
   * Local editing state for the name of node `id`.
   *
   * `name` mirrors `data.name` and is re-synced whenever that changes.
   * `handleNameChange` updates the local value. `handleNameBlur` commits the
   * trimmed name via `updateNodeName`. In these cases the input is reset to the
   * current name instead:
   * - the name is empty or unchanged;
   * - the name is already used by another node (an error message is shown).
   */
  export const useHandleNodeNameChange = ({
    id,
    data,
  }: {
    id?: string;
    data: any;
  }) => {
    const [name, setName] = useState<string>('');
    const { updateNodeName, nodes } = useGraphStore((state) => state);
    const previousName = data?.name;

    const handleNameBlur = useCallback(() => {
      // Validate and store the trimmed name so that names differing only by
      // surrounding whitespace cannot slip past the uniqueness check.
      const nextName = trim(name);

      if (nextName === '' || nextName === previousName) {
        setName(previousName);
        return;
      }

      // The node being renamed is excluded: it trivially "has" its own name.
      const existsSameName = nodes.some(
        (x) => x.id !== id && x.data.name === nextName,
      );
      if (existsSameName) {
        message.error('The name cannot be repeated');
        setName(previousName);
        return;
      }

      if (id) {
        updateNodeName(id, nextName);
      }
    }, [name, id, updateNodeName, previousName, nodes]);

    const handleNameChange = useCallback((e: ChangeEvent<any>) => {
      setName(e.target.value);
    }, []);

    useEffect(() => {
      setName(previousName);
    }, [previousName]);

    return { name, handleNameBlur, handleNameChange };
  };
  /**
   * Returns a resolver from node id to that node's `data.name`. It yields
   * `undefined` when `getNode` finds no node for the id.
   */
  export const useReplaceIdWithName = () => {
    const getNode = useGraphStore((state) => state.getNode);

    return useCallback(
      (id?: string) => {
        const node = getNode(id);
        return node?.data.name;
      },
      [getNode],
    );
  };
  /**
   * Rewrites node ids found in `output` into node names using
   * `replaceIdWithText`. It also exposes the id-to-name resolver it used.
   */
  export const useReplaceIdWithText = (output: unknown) => {
    const getNameById = useReplaceIdWithName();
    const replacedOutput = replaceIdWithText(output, getNameById);

    return { replacedOutput, getNameById };
  };
  /**
   * Keeps outgoing edges in sync with the routing fields of branching operators.
   *
   * Whenever `nodes` changes, the outgoing edges of every Categorize, Relevant
   * and Switch node are rebuilt from its `data.form` and written with
   * `setEdgesByNodeId`:
   * - Categorize: one edge per `category_description` entry that has a `to`
   *   target, using the category key as the source handle.
   * - Relevant: edges for the `yes` / `no` targets, using those words as handles.
   * - Switch: one edge per condition with a `to` target (handle from
   *   `generateSwitchHandleText(index)`), plus one for the `SwitchElseTo` target.
   *
   * A node without form data is treated as having no routes (no edges) rather
   * than throwing.
   */
  export const useWatchNodeFormDataChange = () => {
    const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state);

    const buildCategorizeEdgesByFormData = useCallback(
      (nodeId: string, form: ICategorizeForm) => {
        // `form` may be `{}` (see the fallback in the effect below), so a missing
        // category map must not reach Object.entries/keys as `undefined`.
        const downstreamEdges = Object.entries(
          form.category_description ?? {},
        ).reduce<Edge[]>((pre, [sourceHandle, category]) => {
          const target = category?.to;
          if (target) {
            pre.push({
              id: uuid(),
              source: nodeId,
              target,
              sourceHandle,
            });
          }
          return pre;
        }, []);
        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    const buildRelevantEdgesByFormData = useCallback(
      (nodeId: string, form: IRelevantForm) => {
        const downstreamEdges = ['yes', 'no'].reduce<Edge[]>((pre, cur) => {
          const target = form[cur as keyof IRelevantForm] as string;
          if (target) {
            pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur });
          }
          return pre;
        }, []);
        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    const buildSwitchEdgesByFormData = useCallback(
      (nodeId: string, form: ISwitchForm) => {
        // Same `{}` fallback concern as for Categorize: default to no conditions.
        const conditions: NonNullable<ISwitchForm['conditions']> =
          form.conditions ?? [];
        const downstreamEdges = conditions.reduce<Edge[]>(
          (pre, condition, idx) => {
            const target = condition?.to;
            if (target) {
              pre.push({
                id: uuid(),
                source: nodeId,
                target,
                sourceHandle: generateSwitchHandleText(idx),
              });
            }
            return pre;
          },
          [],
        );
        // Append the edge for the else branch of the conditional.
        const elseTo = form[SwitchElseTo];
        if (elseTo) {
          downstreamEdges.push({
            id: uuid(),
            source: nodeId,
            target: elseTo,
            sourceHandle: SwitchElseTo,
          });
        }
        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    useEffect(() => {
      nodes.forEach((node) => {
        const currentNode = getNode(node.id);
        const form = currentNode?.data.form ?? {};
        const operatorType = currentNode?.data.label;
        switch (operatorType) {
          case Operator.Relevant:
            buildRelevantEdgesByFormData(node.id, form as IRelevantForm);
            break;
          case Operator.Categorize:
            buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
            break;
          case Operator.Switch:
            buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
            break;
          default:
            break;
        }
      });
    }, [
      nodes,
      buildCategorizeEdgesByFormData,
      getNode,
      buildRelevantEdgesByFormData,
      buildSwitchEdgesByFormData,
    ]);
  };
  /**
   * Returns `(id, label) => void`, which duplicates node `id` through the
   * graph store. It passes the localized display name of operator type `label`,
   * as produced by `useGetNodeName`.
   */
  export const useDuplicateNode = () => {
    const duplicateNodeById = useGraphStore((store) => store.duplicateNode);
    const getNodeName = useGetNodeName();

    return useCallback(
      (id: string, label: string) => {
        const displayName = getNodeName(label);
        duplicateNodeById(id, displayName);
      },
      [duplicateNodeById, getNodeName],
    );
  };
  /**
   * Copy/paste of selected canvas nodes through the system clipboard, using
   * window-level `copy` / `paste` listeners.
   *
   * - copy: handled only when the event target is `<body>`, so copying from
   *   inputs and other focused elements keeps its default behaviour. The
   *   selected nodes (excluding the Begin node) are serialized under the custom
   *   `agent:nodes` clipboard type. When nothing is selected the event is left
   *   untouched, so the user's existing clipboard contents are not replaced by
   *   an empty list.
   * - paste: when the clipboard carries an `agent:nodes` array, each listed node
   *   is duplicated by id via `useDuplicateNode`. Data under that type that is
   *   not valid JSON is ignored instead of throwing inside the listener.
   */
  export const useCopyPaste = () => {
    const nodes = useGraphStore((state) => state.nodes);
    const duplicateNode = useDuplicateNode();

    const onCopyCapture = useCallback(
      (event: ClipboardEvent) => {
        // `target` is the standard equivalent of the deprecated `srcElement`.
        if (get(event, 'target.tagName') !== 'BODY') return;

        const selectedNodes = nodes.filter(
          (n) => n.selected && n.data.label !== Operator.Begin,
        );
        if (selectedNodes.length === 0) return;

        event.preventDefault();
        event.clipboardData?.setData(
          'agent:nodes',
          JSON.stringify(selectedNodes),
        );
      },
      [nodes],
    );

    const onPasteCapture = useCallback(
      (event: ClipboardEvent) => {
        const raw = event.clipboardData?.getData('agent:nodes');
        if (!raw) return;

        let pastedNodes: unknown;
        try {
          pastedNodes = JSON.parse(raw);
        } catch {
          // Clipboard contents come from outside the app; malformed data under
          // our type is treated as "nothing to paste".
          return;
        }

        if (Array.isArray(pastedNodes) && pastedNodes.length) {
          event.preventDefault();
          (pastedNodes as Node[]).forEach((n) => {
            duplicateNode(n.id, n.data.label);
          });
        }
      },
      [duplicateNode],
    );

    useEffect(() => {
      window.addEventListener('copy', onCopyCapture);
      return () => {
        window.removeEventListener('copy', onCopyCapture);
      };
    }, [onCopyCapture]);

    useEffect(() => {
      window.addEventListener('paste', onPasteCapture);
      return () => {
        window.removeEventListener('paste', onPasteCapture);
      };
    }, [onPasteCapture]);
  };