You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hooks.ts 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. import { useSetModalState } from '@/hooks/common-hooks';
  2. import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks';
  3. import { IGraph } from '@/interfaces/database/flow';
  4. import { useIsFetching } from '@tanstack/react-query';
  5. import React, {
  6. ChangeEvent,
  7. KeyboardEventHandler,
  8. useCallback,
  9. useEffect,
  10. useMemo,
  11. useState,
  12. } from 'react';
  13. import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
  14. // import { shallow } from 'zustand/shallow';
  15. import { variableEnabledFieldMap } from '@/constants/chat';
  16. import {
  17. ModelVariableType,
  18. settledModelVariableMap,
  19. } from '@/constants/knowledge';
  20. import { useFetchModelId, useSendMessageWithSse } from '@/hooks/logic-hooks';
  21. import { Variable } from '@/interfaces/database/chat';
  22. import api from '@/utils/api';
  23. import { useDebounceEffect } from 'ahooks';
  24. import { FormInstance, message } from 'antd';
  25. import { humanId } from 'human-id';
  26. import trim from 'lodash/trim';
  27. import { useParams } from 'umi';
  28. import { v4 as uuid } from 'uuid';
  29. import {
  30. NodeMap,
  31. Operator,
  32. RestrictedUpstreamMap,
  33. SwitchElseTo,
  34. initialAkShareValues,
  35. initialArXivValues,
  36. initialBaiduFanyiValues,
  37. initialBaiduValues,
  38. initialBeginValues,
  39. initialBingValues,
  40. initialCategorizeValues,
  41. initialConcentratorValues,
  42. initialDeepLValues,
  43. initialDuckValues,
  44. initialExeSqlValues,
  45. initialGenerateValues,
  46. initialGithubValues,
  47. initialGoogleScholarValues,
  48. initialGoogleValues,
  49. initialJin10Values,
  50. initialKeywordExtractValues,
  51. initialMessageValues,
  52. initialPubMedValues,
  53. initialQWeatherValues,
  54. initialRelevantValues,
  55. initialRetrievalValues,
  56. initialRewriteQuestionValues,
  57. initialSwitchValues,
  58. initialWenCaiValues,
  59. initialWikipediaValues,
  60. initialYahooFinanceValues,
  61. } from './constant';
  62. import { ICategorizeForm, IRelevantForm, ISwitchForm } from './interface';
  63. import useGraphStore, { RFState } from './store';
  64. import {
  65. buildDslComponentsByGraph,
  66. generateSwitchHandleText,
  67. receiveMessageError,
  68. replaceIdWithText,
  69. } from './utils';
  /**
   * Picks the canvas-related slice out of the graph store state: the node and
   * edge lists, the change/connect/selection handlers and the `setNodes` setter.
   */
  const selector = ({
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
  }: RFState) => ({
    nodes,
    edges,
    onNodesChange,
    onEdgesChange,
    onConnect,
    setNodes,
    onSelectionChange,
  });
  /**
   * Reads the canvas slice (see `selector`) from the graph store.
   *
   * Alternatives that were tried and left disabled:
   * - `useStore(useShallow(selector))` (noted as throwing an error)
   * - `useStore(selector, shallow)`
   */
  export const useSelectCanvasData = () => useGraphStore(selector);
  /**
   * Builds a lookup from operator type to the initial form values of a freshly
   * created node and returns a memoized getter over it.
   *
   * @returns a function mapping an `Operator` to its initial form values.
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId();

    const initialFormValuesMap = useMemo(() => {
      // Some operators additionally carry the id returned by `useFetchModelId`
      // as `llm_id` on top of their static defaults.
      const withLlmId = <T extends object>(values: T) => ({
        ...values,
        llm_id: llmId,
      });

      return {
        [Operator.Begin]: initialBeginValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Generate]: withLlmId(initialGenerateValues),
        [Operator.Answer]: {},
        [Operator.Categorize]: withLlmId(initialCategorizeValues),
        [Operator.Relevant]: withLlmId(initialRelevantValues),
        [Operator.RewriteQuestion]: withLlmId(initialRewriteQuestionValues),
        [Operator.Message]: initialMessageValues,
        [Operator.KeywordExtract]: withLlmId(initialKeywordExtractValues),
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        [Operator.ArXiv]: initialArXivValues,
        [Operator.Google]: initialGoogleValues,
        [Operator.Bing]: initialBingValues,
        [Operator.GoogleScholar]: initialGoogleScholarValues,
        [Operator.DeepL]: initialDeepLValues,
        [Operator.GitHub]: initialGithubValues,
        [Operator.BaiduFanyi]: initialBaiduFanyiValues,
        [Operator.QWeather]: initialQWeatherValues,
        [Operator.ExeSQL]: initialExeSqlValues,
        [Operator.Switch]: initialSwitchValues,
        [Operator.WenCai]: initialWenCaiValues,
        [Operator.AkShare]: initialAkShareValues,
        [Operator.YahooFinance]: initialYahooFinanceValues,
        [Operator.Jin10]: initialJin10Values,
        [Operator.Concentrator]: initialConcentratorValues,
      };
    }, [llmId]);

    const initializeOperatorParams = useCallback(
      (operatorName: Operator) => initialFormValuesMap[operatorName],
      [initialFormValuesMap],
    );

    return initializeOperatorParams;
  };
  /**
   * Provides the drag-start handler factory for operator items.
   *
   * `handleDragStart(operatorId)` returns an event handler that stores the
   * operator id on the drag's data transfer under `application/reactflow` and
   * marks the drag as a move.
   */
  export const useHandleDrag = () => {
    const handleDragStart = useCallback((operatorId: string) => {
      return (ev: React.DragEvent<HTMLDivElement>) => {
        const { dataTransfer } = ev;
        dataTransfer.setData('application/reactflow', operatorId);
        dataTransfer.effectAllowed = 'move';
      };
    }, []);

    return { handleDragStart };
  };
  /**
   * Drop-side handlers that turn a dragged operator into a new canvas node.
   *
   * @returns `onDragOver` / `onDrop` handlers for the canvas wrapper and
   * `setReactFlowInstance`, which must receive the React Flow instance so the
   * drop position can be converted to flow coordinates.
   */
  export const useHandleDrop = () => {
    const addNode = useGraphStore((state) => state.addNode);
    const [reactFlowInstance, setReactFlowInstance] =
      useState<ReactFlowInstance<any, any>>();
    const initializeOperatorParams = useInitializeOperatorParams();

    const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
      event.preventDefault();
      event.dataTransfer.dropEffect = 'move';
    }, []);

    const onDrop = useCallback(
      (event: React.DragEvent<HTMLDivElement>) => {
        event.preventDefault();

        const operatorName = event.dataTransfer.getData('application/reactflow');
        // Ignore drops that carry no operator payload.
        if (!operatorName) {
          return;
        }

        // reactFlowInstance.project was renamed to screenToFlowPosition and the
        // reactFlowBounds.left/top no longer need to be subtracted.
        // details: https://reactflow.dev/whats-new/2023-11-10
        const dropPosition = reactFlowInstance?.screenToFlowPosition({
          x: event.clientX,
          y: event.clientY,
        });

        const operator = operatorName as Operator;
        const id = `${operatorName}:${humanId()}`;
        const data = {
          label: `${operatorName}`,
          name: humanId(),
          form: initializeOperatorParams(operator),
        };

        addNode({
          id,
          type: NodeMap[operator] || 'ragNode',
          // Without a flow instance the node is placed at the origin.
          position: dropPosition || { x: 0, y: 0 },
          data,
          sourcePosition: Position.Right,
          targetPosition: Position.Left,
        });
      },
      [reactFlowInstance, addNode, initializeOperatorParams],
    );

    return { onDrop, onDragOver, setReactFlowInstance };
  };
  /**
   * Couples the node drawer's visibility with the "clicked node" kept in the
   * graph store.
   *
   * @returns the drawer visibility flag, a hide function, a `showDrawer(node)`
   * function that records the node's id before opening the drawer, and the
   * node currently looked up from the stored clicked id.
   */
  export const useShowDrawer = () => {
    const { clickedNodeId, setClickedNodeId, getNode } = useGraphStore(
      (state) => state,
    );
    const { visible, hideModal, showModal } = useSetModalState();

    const handleShow = useCallback(
      (node: Node) => {
        setClickedNodeId(node.id);
        showModal();
      },
      [showModal, setClickedNodeId],
    );

    const clickedNode = getNode(clickedNodeId);

    return {
      drawerVisible: visible,
      hideDrawer: hideModal,
      showDrawer: handleShow,
      clickedNode,
    };
  };
  /**
   * Key-up handler for the canvas: releasing the `Delete` key invokes the
   * store's `deleteEdge` action; every other key is ignored.
   */
  export const useHandleKeyUp = () => {
    const deleteEdge = useGraphStore((state) => state.deleteEdge);

    const handleKeyUp: KeyboardEventHandler = useCallback(
      ({ code }) => {
        if (code !== 'Delete') {
          return;
        }
        deleteEdge();
      },
      [deleteEdge],
    );

    return { handleKeyUp };
  };
  /**
   * Provides `saveGraph`, which persists the current canvas through `setFlow`.
   *
   * The saved DSL keeps the fetched flow's existing `dsl` fields and overrides
   * `graph` (current nodes and edges) and `components` (derived from the graph
   * via `buildDslComponentsByGraph`).
   *
   * @returns `{ saveGraph }`; `saveGraph` resolves to whatever `setFlow` returns.
   */
  export const useSaveGraph = () => {
    const { data } = useFetchFlow();
    const { setFlow } = useSetFlow();
    const { id } = useParams();
    const { nodes, edges } = useGraphStore((state) => state);

    const saveGraph = useCallback(async () => {
      const components = buildDslComponentsByGraph(nodes, edges);
      const graph = { nodes, edges };
      const dsl = { ...data.dsl, graph, components };

      return setFlow({ id, title: data.title, dsl });
    }, [nodes, edges, setFlow, id, data]);

    return { saveGraph };
  };
  /**
   * Registers a debounced (1000 ms) effect keyed on the store's nodes and edges.
   *
   * The effect body is currently empty: it only holds a commented-out log
   * statement, so this hook has no observable behavior beyond subscribing to
   * `nodes` and `edges`.
   */
  export const useWatchGraphChange = () => {
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);

    const onGraphSettled = () => {
      // console.info('useDebounceEffect');
    };

    useDebounceEffect(onGraphSettled, [nodes, edges], { wait: 1000 });
  };
  /**
   * Provides a form `onValuesChange` handler that mirrors the form values into
   * the node identified by `id`.
   *
   * @param id id of the node whose form is being edited; when omitted the
   * handler does not write anything to the store.
   */
  export const useHandleFormValuesChange = (id?: string) => {
    const updateNodeForm = useGraphStore((state) => state.updateNodeForm);

    const handleValuesChange = useCallback(
      (changedValues: any, values: any) => {
        // Fixed the issue that the related form value does not change after
        // selecting the freedom field of the model: when `parameter` is the
        // only changed field and names a known preset, the preset's values are
        // merged over the form values.
        const onlyParameterChanged =
          Object.keys(changedValues).length === 1 &&
          'parameter' in changedValues;
        const presetSelected =
          onlyParameterChanged &&
          changedValues['parameter'] in settledModelVariableMap;

        const nextValues: any = presetSelected
          ? {
              ...values,
              ...settledModelVariableMap[
                changedValues['parameter'] as keyof typeof settledModelVariableMap
              ],
            }
          : values;

        if (!id) {
          return;
        }
        updateNodeForm(id, nextValues);
      },
      [updateNodeForm, id],
    );

    return { handleValuesChange };
  };
  /**
   * Returns a setter that loads a graph into the store.
   *
   * A graph with neither nodes nor edges is ignored, so the store is left
   * untouched in that case.
   */
  const useSetGraphInfo = () => {
    const { setEdges, setNodes } = useGraphStore((state) => state);

    const setGraphInfo = useCallback(
      (graph: IGraph) => {
        const { nodes = [], edges = [] } = graph;
        if (nodes.length === 0 && edges.length === 0) {
          return;
        }
        setNodes(nodes);
        setEdges(edges);
      },
      [setEdges, setNodes],
    );

    return setGraphInfo;
  };
  /**
   * Loads the flow when the component mounts and keeps the graph store in sync
   * with the fetched DSL graph.
   *
   * @returns the loading flag and the fetched flow as `flowDetail`.
   */
  export const useFetchDataOnMount = () => {
    const { loading, data, refetch } = useFetchFlow();
    const setGraphInfo = useSetGraphInfo();

    // Push the fetched graph (or an empty one) into the store whenever the
    // flow data changes.
    useEffect(() => {
      const graph = data?.dsl?.graph ?? ({} as IGraph);
      setGraphInfo(graph);
    }, [setGraphInfo, data]);

    useWatchGraphChange();

    // Trigger a fetch on mount (and whenever `refetch` changes identity).
    useEffect(() => {
      refetch();
    }, [refetch]);

    return { loading, flowDetail: data };
  };
  /**
   * Tells whether at least one query under the `flowDetail` key is fetching.
   */
  export const useFlowIsFetching = () => {
    const fetchingCount = useIsFetching({ queryKey: ['flowDetail'] });
    return fetchingCount > 0;
  };
  /**
   * Seeds the given form with LLM setting defaults: one boolean switch per key
   * of `variableEnabledFieldMap`, plus the `Precise` preset from
   * `settledModelVariableMap`.
   *
   * NOTE(review): `initialLlmSetting` is hard-coded to `undefined`, so every
   * switch is currently set to `true` and the per-field lookup below is never
   * reached. It only matters if this becomes a real input.
   *
   * @param form antd form instance to populate; nothing happens without it.
   */
  export const useSetLlmSetting = (form?: FormInstance) => {
    const initialLlmSetting = undefined;

    useEffect(() => {
      const switchBoxValues: Record<string, boolean> = {};
      for (const field of Object.keys(variableEnabledFieldMap)) {
        switchBoxValues[field] =
          initialLlmSetting === undefined
            ? true
            : !!initialLlmSetting[
                variableEnabledFieldMap[
                  field as keyof typeof variableEnabledFieldMap
                ] as keyof Variable
              ];
      }

      const presetValues = settledModelVariableMap[ModelVariableType.Precise];

      form?.setFieldsValue({
        ...switchBoxValues,
        ...presetValues,
      });
    }, [form, initialLlmSetting]);
  };
  /**
   * Returns the predicate deciding whether a proposed connection may be made.
   * Restricted lines cannot be connected successfully.
   *
   * A connection is rejected when it links a node to itself, when an edge with
   * the same source and target already exists, or when the target's operator
   * type is listed in `RestrictedUpstreamMap` for the source's operator type.
   *
   * NOTE(review): if the source's operator type has no entry in
   * `RestrictedUpstreamMap`, the result is `undefined` (falsy), i.e. the
   * connection is treated as invalid — confirm that this is intended.
   */
  export const useValidateConnection = () => {
    const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);

    const isValidConnection = useCallback(
      (connection: Connection) => {
        const { source, target } = connection;

        // node cannot connect to itself
        if (target === source) {
          return false;
        }

        // limit the connection between two nodes to only one connection line
        // in one direction
        const alreadyLinked = edges.some(
          (edge) => edge.source === source && edge.target === target,
        );
        if (alreadyLinked) {
          return false;
        }

        const restrictedTargets =
          RestrictedUpstreamMap[getOperatorTypeFromId(source) as Operator];
        return restrictedTargets?.every(
          (restricted) => restricted !== getOperatorTypeFromId(target),
        );
      },
      [edges, getOperatorTypeFromId],
    );

    return isValidConnection;
  };
  /**
   * Manages the editable name of a node: local input state, change handler and
   * a blur handler that validates and commits the name to the store.
   *
   * On blur the edit is discarded (the input is reset to the node's stored
   * name) when the value is blank or when some node already uses that name;
   * an error message is shown only if the duplicate differs from the node's
   * own stored name. Otherwise the name is written via `updateNodeName`.
   *
   * @param node the node being edited; may be absent, in which case the name
   * state stays an empty string and nothing is written to the store.
   * @returns the current input value and the blur/change handlers.
   */
  export const useHandleNodeNameChange = (node?: Node) => {
    const [name, setName] = useState<string>('');
    const { updateNodeName, nodes } = useGraphStore((state) => state);
    const previousName = node?.data.name;
    const id = node?.id;

    const handleNameBlur = useCallback(() => {
      const existsSameName = nodes.some((x) => x.data.name === name);
      if (trim(name) === '' || existsSameName) {
        if (existsSameName && previousName !== name) {
          message.error('The name cannot be repeated');
        }
        // `previousName` is undefined when there is no node; keep the state a
        // string so the bound input never receives `undefined`.
        setName(previousName ?? '');
        return;
      }

      if (id) {
        updateNodeName(id, name);
      }
    }, [name, id, updateNodeName, previousName, nodes]);

    const handleNameChange = useCallback((e: ChangeEvent<any>) => {
      setName(e.target.value);
    }, []);

    // Re-seed the input whenever the node's stored name changes (including
    // switching to another node or to no node at all).
    useEffect(() => {
      setName(previousName ?? '');
    }, [previousName]);

    return { name, handleNameBlur, handleNameChange };
  };
  /**
   * Returns the "run" handler used before opening the debug drawer.
   *
   * The handler saves the graph, resets the flow, then sends a first request
   * to `api.runCanvas` (fetch prologue). Each step only runs if the previous
   * one reported `retcode === 0`. If the send result is an error, its message
   * is shown; otherwise the flow is refetched and `show` is called.
   *
   * @param show callback that opens the drawer.
   */
  export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
    const { id } = useParams();
    const { saveGraph } = useSaveGraph();
    const { resetFlow } = useResetFlow();
    const { refetch } = useFetchFlow();
    const { send } = useSendMessageWithSse(api.runCanvas);

    const handleRun = useCallback(async () => {
      const saveRet = await saveGraph();
      if (saveRet?.retcode !== 0) {
        return;
      }

      // Call the reset api before opening the run drawer each time.
      // After resetting, all previous messages will be cleared.
      const resetRet = await resetFlow();
      if (resetRet?.retcode !== 0) {
        return;
      }

      // fetch prologue
      const sendRet = await send({ id });
      if (receiveMessageError(sendRet)) {
        message.error(sendRet?.data?.retmsg);
        return;
      }

      refetch();
      show();
    }, [saveGraph, resetFlow, id, send, show, refetch]);

    return handleRun;
  };
  /**
   * Passes `output` through `replaceIdWithText`, supplying a resolver that
   * maps a node id to that node's `data.name` (or `undefined` when the store
   * has no such node).
   *
   * @param output value handed to `replaceIdWithText` unchanged.
   */
  export const useReplaceIdWithText = (output: unknown) => {
    const getNode = useGraphStore((state) => state.getNode);

    const resolveName = (nodeId?: string) => {
      const node = getNode(nodeId);
      return node?.data.name;
    };

    return replaceIdWithText(output, resolveName);
  };
  /**
   * Monitors changes in the `data.form` field of the categorize, relevant and
   * switch operators and synchronizes them to the edges.
   *
   * Whenever the node list changes, the outgoing edges of every such node are
   * rebuilt from the targets stored in its form and handed to
   * `setEdgesByNodeId`. Nodes of any other operator type are left alone.
   *
   * A form that lacks the expected field (the effect falls back to `{}` when a
   * node has no form) is treated as having no targets instead of throwing.
   */
  export const useWatchNodeFormDataChange = () => {
    const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state);

    // Categorize: one edge per category that has a `to` target; the category
    // key is used as the source handle.
    const buildCategorizeEdgesByFormData = useCallback(
      (nodeId: string, form: ICategorizeForm) => {
        const categoryDescription = form.category_description;
        // `category_description` can be missing on an empty form; Object.keys
        // would throw on undefined, so fall back to an empty object.
        const downstreamEdges = Object.keys(categoryDescription ?? {}).reduce<
          Edge[]
        >((pre, sourceHandle) => {
          const target = categoryDescription[sourceHandle]?.to;
          if (target) {
            pre.push({
              id: uuid(),
              source: nodeId,
              target,
              sourceHandle,
            });
          }
          return pre;
        }, []);

        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    // Relevant: at most two edges, taken from the form's `yes` and `no` fields.
    const buildRelevantEdgesByFormData = useCallback(
      (nodeId: string, form: IRelevantForm) => {
        const downstreamEdges = ['yes', 'no'].reduce<Edge[]>((pre, cur) => {
          const target = form[cur as keyof IRelevantForm] as string;
          if (target) {
            pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur });
          }
          return pre;
        }, []);

        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    // Switch: one edge per condition that has a `to` target, plus the else
    // branch when it is set.
    const buildSwitchEdgesByFormData = useCallback(
      (nodeId: string, form: ISwitchForm) => {
        // `conditions` can be missing on an empty form; default to no
        // conditions rather than calling reduce on undefined.
        const conditions = form.conditions ?? [];
        const downstreamEdges = conditions.reduce<Edge[]>(
          (pre, condition, idx) => {
            const target = condition?.to;
            if (target) {
              pre.push({
                id: uuid(),
                source: nodeId,
                target,
                sourceHandle: generateSwitchHandleText(idx),
              });
            }
            return pre;
          },
          [],
        );

        // Splice the else condition of the conditional judgment to the edge list
        const elseTo = form[SwitchElseTo];
        if (elseTo) {
          downstreamEdges.push({
            id: uuid(),
            source: nodeId,
            target: elseTo,
            sourceHandle: SwitchElseTo,
          });
        }

        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    useEffect(() => {
      nodes.forEach((node) => {
        const currentNode = getNode(node.id);
        const form = currentNode?.data.form ?? {};
        const operatorType = currentNode?.data.label;

        switch (operatorType) {
          case Operator.Relevant:
            buildRelevantEdgesByFormData(node.id, form as IRelevantForm);
            break;
          case Operator.Categorize:
            buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
            break;
          case Operator.Switch:
            buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
            break;
          default:
            break;
        }
      });
    }, [
      nodes,
      buildCategorizeEdgesByFormData,
      getNode,
      buildRelevantEdgesByFormData,
      buildSwitchEdgesByFormData,
    ]);
  };
  // Operator types that are filtered out when building the component-id
  // select options (originally noted as "exclude nodes with branches").
  const ExcludedNodes: Operator[] = [
    Operator.Categorize,
    Operator.Relevant,
    Operator.Begin,
    Operator.Answer,
  ];
  /**
   * Builds `{ label, value }` select options from the nodes on the canvas.
   *
   * The node identified by `nodeId` and every node whose `data.label` is in
   * `ExcludedNodes` are left out. Each remaining node contributes its
   * `data.name` as the label and its id as the value.
   *
   * @param nodeId id of the node to omit from the options (typically the node
   * being edited).
   */
  export const useBuildComponentIdSelectOptions = (nodeId?: string) => {
    const nodes = useGraphStore((state) => state.nodes);

    const options = useMemo(() => {
      const selectable = nodes.filter((node) => {
        if (node.id === nodeId) {
          return false;
        }
        return !ExcludedNodes.includes(node.data.label);
      });

      return selectable.map((node) => ({
        label: node.data.name,
        value: node.id,
      }));
    }, [nodes, nodeId]);

    return options;
  };