You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

hooks.ts 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import { useSetModalState } from '@/hooks/common-hooks';
  2. import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks';
  3. import { IGraph } from '@/interfaces/database/flow';
  4. import { useIsFetching } from '@tanstack/react-query';
  5. import React, {
  6. ChangeEvent,
  7. KeyboardEventHandler,
  8. useCallback,
  9. useEffect,
  10. useMemo,
  11. useState,
  12. } from 'react';
  13. import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
  14. // import { shallow } from 'zustand/shallow';
  15. import { variableEnabledFieldMap } from '@/constants/chat';
  16. import {
  17. ModelVariableType,
  18. settledModelVariableMap,
  19. } from '@/constants/knowledge';
  20. import { useFetchModelId, useSendMessageWithSse } from '@/hooks/logic-hooks';
  21. import { Variable } from '@/interfaces/database/chat';
  22. import api from '@/utils/api';
  23. import { useDebounceEffect } from 'ahooks';
  24. import { FormInstance, message } from 'antd';
  25. import { humanId } from 'human-id';
  26. import trim from 'lodash/trim';
  27. import { useParams } from 'umi';
  28. import { v4 as uuid } from 'uuid';
  29. import {
  30. NodeMap,
  31. Operator,
  32. RestrictedUpstreamMap,
  33. SwitchElseTo,
  34. initialAkShareValues,
  35. initialArXivValues,
  36. initialBaiduFanyiValues,
  37. initialBaiduValues,
  38. initialBeginValues,
  39. initialBingValues,
  40. initialCategorizeValues,
  41. initialDeepLValues,
  42. initialDuckValues,
  43. initialExeSqlValues,
  44. initialGenerateValues,
  45. initialGithubValues,
  46. initialGoogleScholarValues,
  47. initialGoogleValues,
  48. initialKeywordExtractValues,
  49. initialMessageValues,
  50. initialPubMedValues,
  51. initialQWeatherValues,
  52. initialRelevantValues,
  53. initialRetrievalValues,
  54. initialRewriteQuestionValues,
  55. initialSwitchValues,
  56. initialWenCaiValues,
  57. initialWikipediaValues,
  58. } from './constant';
  59. import { ICategorizeForm, IRelevantForm, ISwitchForm } from './interface';
  60. import useGraphStore, { RFState } from './store';
  61. import {
  62. buildDslComponentsByGraph,
  63. generateSwitchHandleText,
  64. receiveMessageError,
  65. replaceIdWithText,
  66. } from './utils';
  67. const selector = (state: RFState) => ({
  68. nodes: state.nodes,
  69. edges: state.edges,
  70. onNodesChange: state.onNodesChange,
  71. onEdgesChange: state.onEdgesChange,
  72. onConnect: state.onConnect,
  73. setNodes: state.setNodes,
  74. onSelectionChange: state.onSelectionChange,
  75. });
  76. export const useSelectCanvasData = () => {
  77. // return useStore(useShallow(selector)); // throw error
  78. // return useStore(selector, shallow);
  79. return useGraphStore(selector);
  80. };
  81. export const useInitializeOperatorParams = () => {
  82. const llmId = useFetchModelId();
  83. const initialFormValuesMap = useMemo(() => {
  84. return {
  85. [Operator.Begin]: initialBeginValues,
  86. [Operator.Retrieval]: initialRetrievalValues,
  87. [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
  88. [Operator.Answer]: {},
  89. [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
  90. [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
  91. [Operator.RewriteQuestion]: {
  92. ...initialRewriteQuestionValues,
  93. llm_id: llmId,
  94. },
  95. [Operator.Message]: initialMessageValues,
  96. [Operator.KeywordExtract]: {
  97. ...initialKeywordExtractValues,
  98. llm_id: llmId,
  99. },
  100. [Operator.DuckDuckGo]: initialDuckValues,
  101. [Operator.Baidu]: initialBaiduValues,
  102. [Operator.Wikipedia]: initialWikipediaValues,
  103. [Operator.PubMed]: initialPubMedValues,
  104. [Operator.ArXiv]: initialArXivValues,
  105. [Operator.Google]: initialGoogleValues,
  106. [Operator.Bing]: initialBingValues,
  107. [Operator.GoogleScholar]: initialGoogleScholarValues,
  108. [Operator.DeepL]: initialDeepLValues,
  109. [Operator.GitHub]: initialGithubValues,
  110. [Operator.BaiduFanyi]: initialBaiduFanyiValues,
  111. [Operator.QWeather]: initialQWeatherValues,
  112. [Operator.ExeSQL]: initialExeSqlValues,
  113. [Operator.Switch]: initialSwitchValues,
  114. [Operator.WenCai]: initialWenCaiValues,
  115. [Operator.AkShare]: initialAkShareValues,
  116. };
  117. }, [llmId]);
  118. const initializeOperatorParams = useCallback(
  119. (operatorName: Operator) => {
  120. return initialFormValuesMap[operatorName];
  121. },
  122. [initialFormValuesMap],
  123. );
  124. return initializeOperatorParams;
  125. };
  126. export const useHandleDrag = () => {
  127. const handleDragStart = useCallback(
  128. (operatorId: string) => (ev: React.DragEvent<HTMLDivElement>) => {
  129. ev.dataTransfer.setData('application/reactflow', operatorId);
  130. ev.dataTransfer.effectAllowed = 'move';
  131. },
  132. [],
  133. );
  134. return { handleDragStart };
  135. };
  136. export const useHandleDrop = () => {
  137. const addNode = useGraphStore((state) => state.addNode);
  138. const [reactFlowInstance, setReactFlowInstance] =
  139. useState<ReactFlowInstance<any, any>>();
  140. const initializeOperatorParams = useInitializeOperatorParams();
  141. const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
  142. event.preventDefault();
  143. event.dataTransfer.dropEffect = 'move';
  144. }, []);
  145. const onDrop = useCallback(
  146. (event: React.DragEvent<HTMLDivElement>) => {
  147. event.preventDefault();
  148. const type = event.dataTransfer.getData('application/reactflow');
  149. // check if the dropped element is valid
  150. if (typeof type === 'undefined' || !type) {
  151. return;
  152. }
  153. // reactFlowInstance.project was renamed to reactFlowInstance.screenToFlowPosition
  154. // and you don't need to subtract the reactFlowBounds.left/top anymore
  155. // details: https://reactflow.dev/whats-new/2023-11-10
  156. const position = reactFlowInstance?.screenToFlowPosition({
  157. x: event.clientX,
  158. y: event.clientY,
  159. });
  160. const newNode = {
  161. id: `${type}:${humanId()}`,
  162. type: NodeMap[type as Operator] || 'ragNode',
  163. position: position || {
  164. x: 0,
  165. y: 0,
  166. },
  167. data: {
  168. label: `${type}`,
  169. name: humanId(),
  170. form: initializeOperatorParams(type as Operator),
  171. },
  172. sourcePosition: Position.Right,
  173. targetPosition: Position.Left,
  174. };
  175. addNode(newNode);
  176. },
  177. [reactFlowInstance, addNode, initializeOperatorParams],
  178. );
  179. return { onDrop, onDragOver, setReactFlowInstance };
  180. };
  181. export const useShowDrawer = () => {
  182. const {
  183. clickedNodeId: clickNodeId,
  184. setClickedNodeId,
  185. getNode,
  186. } = useGraphStore((state) => state);
  187. const {
  188. visible: drawerVisible,
  189. hideModal: hideDrawer,
  190. showModal: showDrawer,
  191. } = useSetModalState();
  192. const handleShow = useCallback(
  193. (node: Node) => {
  194. setClickedNodeId(node.id);
  195. showDrawer();
  196. },
  197. [showDrawer, setClickedNodeId],
  198. );
  199. return {
  200. drawerVisible,
  201. hideDrawer,
  202. showDrawer: handleShow,
  203. clickedNode: getNode(clickNodeId),
  204. };
  205. };
  206. export const useHandleKeyUp = () => {
  207. const deleteEdge = useGraphStore((state) => state.deleteEdge);
  208. const handleKeyUp: KeyboardEventHandler = useCallback(
  209. (e) => {
  210. if (e.code === 'Delete') {
  211. deleteEdge();
  212. }
  213. },
  214. [deleteEdge],
  215. );
  216. return { handleKeyUp };
  217. };
  218. export const useSaveGraph = () => {
  219. const { data } = useFetchFlow();
  220. const { setFlow } = useSetFlow();
  221. const { id } = useParams();
  222. const { nodes, edges } = useGraphStore((state) => state);
  223. const saveGraph = useCallback(async () => {
  224. const dslComponents = buildDslComponentsByGraph(nodes, edges);
  225. return setFlow({
  226. id,
  227. title: data.title,
  228. dsl: { ...data.dsl, graph: { nodes, edges }, components: dslComponents },
  229. });
  230. }, [nodes, edges, setFlow, id, data]);
  231. return { saveGraph };
  232. };
  233. export const useWatchGraphChange = () => {
  234. const nodes = useGraphStore((state) => state.nodes);
  235. const edges = useGraphStore((state) => state.edges);
  236. useDebounceEffect(
  237. () => {
  238. // console.info('useDebounceEffect');
  239. },
  240. [nodes, edges],
  241. {
  242. wait: 1000,
  243. },
  244. );
  245. };
  246. export const useHandleFormValuesChange = (id?: string) => {
  247. const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
  248. const handleValuesChange = useCallback(
  249. (changedValues: any, values: any) => {
  250. let nextValues: any = values;
  251. // Fixed the issue that the related form value does not change after selecting the freedom field of the model
  252. if (
  253. Object.keys(changedValues).length === 1 &&
  254. 'parameter' in changedValues &&
  255. changedValues['parameter'] in settledModelVariableMap
  256. ) {
  257. nextValues = {
  258. ...values,
  259. ...settledModelVariableMap[
  260. changedValues['parameter'] as keyof typeof settledModelVariableMap
  261. ],
  262. };
  263. }
  264. if (id) {
  265. updateNodeForm(id, nextValues);
  266. }
  267. },
  268. [updateNodeForm, id],
  269. );
  270. return { handleValuesChange };
  271. };
  272. const useSetGraphInfo = () => {
  273. const { setEdges, setNodes } = useGraphStore((state) => state);
  274. const setGraphInfo = useCallback(
  275. ({ nodes = [], edges = [] }: IGraph) => {
  276. if (nodes.length || edges.length) {
  277. setNodes(nodes);
  278. setEdges(edges);
  279. }
  280. },
  281. [setEdges, setNodes],
  282. );
  283. return setGraphInfo;
  284. };
  285. export const useFetchDataOnMount = () => {
  286. const { loading, data, refetch } = useFetchFlow();
  287. const setGraphInfo = useSetGraphInfo();
  288. useEffect(() => {
  289. setGraphInfo(data?.dsl?.graph ?? ({} as IGraph));
  290. }, [setGraphInfo, data]);
  291. useWatchGraphChange();
  292. useEffect(() => {
  293. refetch();
  294. }, [refetch]);
  295. return { loading, flowDetail: data };
  296. };
  297. export const useFlowIsFetching = () => {
  298. return useIsFetching({ queryKey: ['flowDetail'] }) > 0;
  299. };
  300. export const useSetLlmSetting = (form?: FormInstance) => {
  301. const initialLlmSetting = undefined;
  302. useEffect(() => {
  303. const switchBoxValues = Object.keys(variableEnabledFieldMap).reduce<
  304. Record<string, boolean>
  305. >((pre, field) => {
  306. pre[field] =
  307. initialLlmSetting === undefined
  308. ? true
  309. : !!initialLlmSetting[
  310. variableEnabledFieldMap[
  311. field as keyof typeof variableEnabledFieldMap
  312. ] as keyof Variable
  313. ];
  314. return pre;
  315. }, {});
  316. const otherValues = settledModelVariableMap[ModelVariableType.Precise];
  317. form?.setFieldsValue({
  318. ...switchBoxValues,
  319. ...otherValues,
  320. });
  321. }, [form, initialLlmSetting]);
  322. };
  323. export const useValidateConnection = () => {
  324. const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);
  325. // restricted lines cannot be connected successfully.
  326. const isValidConnection = useCallback(
  327. (connection: Connection) => {
  328. // node cannot connect to itself
  329. const isSelfConnected = connection.target === connection.source;
  330. // limit the connection between two nodes to only one connection line in one direction
  331. const hasLine = edges.some(
  332. (x) => x.source === connection.source && x.target === connection.target,
  333. );
  334. const ret =
  335. !isSelfConnected &&
  336. !hasLine &&
  337. RestrictedUpstreamMap[
  338. getOperatorTypeFromId(connection.source) as Operator
  339. ]?.every((x) => x !== getOperatorTypeFromId(connection.target));
  340. return ret;
  341. },
  342. [edges, getOperatorTypeFromId],
  343. );
  344. return isValidConnection;
  345. };
  346. export const useHandleNodeNameChange = (node?: Node) => {
  347. const [name, setName] = useState<string>('');
  348. const { updateNodeName, nodes } = useGraphStore((state) => state);
  349. const previousName = node?.data.name;
  350. const id = node?.id;
  351. const handleNameBlur = useCallback(() => {
  352. const existsSameName = nodes.some((x) => x.data.name === name);
  353. if (trim(name) === '' || existsSameName) {
  354. if (existsSameName && previousName !== name) {
  355. message.error('The name cannot be repeated');
  356. }
  357. setName(previousName);
  358. return;
  359. }
  360. if (id) {
  361. updateNodeName(id, name);
  362. }
  363. }, [name, id, updateNodeName, previousName, nodes]);
  364. const handleNameChange = useCallback((e: ChangeEvent<any>) => {
  365. setName(e.target.value);
  366. }, []);
  367. useEffect(() => {
  368. setName(previousName);
  369. }, [previousName]);
  370. return { name, handleNameBlur, handleNameChange };
  371. };
  372. export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
  373. const { id } = useParams();
  374. const { saveGraph } = useSaveGraph();
  375. const { resetFlow } = useResetFlow();
  376. const { refetch } = useFetchFlow();
  377. const { send } = useSendMessageWithSse(api.runCanvas);
  378. const handleRun = useCallback(async () => {
  379. const saveRet = await saveGraph();
  380. if (saveRet?.retcode === 0) {
  381. // Call the reset api before opening the run drawer each time
  382. const resetRet = await resetFlow();
  383. // After resetting, all previous messages will be cleared.
  384. if (resetRet?.retcode === 0) {
  385. // fetch prologue
  386. const sendRet = await send({ id });
  387. if (receiveMessageError(sendRet)) {
  388. message.error(sendRet?.data?.retmsg);
  389. } else {
  390. refetch();
  391. show();
  392. }
  393. }
  394. }
  395. }, [saveGraph, resetFlow, id, send, show, refetch]);
  396. return handleRun;
  397. };
  398. export const useReplaceIdWithText = (output: unknown) => {
  399. const getNode = useGraphStore((state) => state.getNode);
  400. const getNameById = (id?: string) => {
  401. return getNode(id)?.data.name;
  402. };
  403. return replaceIdWithText(output, getNameById);
  404. };
  405. /**
  406. * monitor changes in the data.form field of the categorize and relevant operators
  407. * and then synchronize them to the edge
  408. */
  409. export const useWatchNodeFormDataChange = () => {
  410. const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state);
  411. const buildCategorizeEdgesByFormData = useCallback(
  412. (nodeId: string, form: ICategorizeForm) => {
  413. // add
  414. // delete
  415. // edit
  416. const categoryDescription = form.category_description;
  417. const downstreamEdges = Object.keys(categoryDescription).reduce<Edge[]>(
  418. (pre, sourceHandle) => {
  419. const target = categoryDescription[sourceHandle]?.to;
  420. if (target) {
  421. pre.push({
  422. id: uuid(),
  423. source: nodeId,
  424. target,
  425. sourceHandle,
  426. });
  427. }
  428. return pre;
  429. },
  430. [],
  431. );
  432. setEdgesByNodeId(nodeId, downstreamEdges);
  433. },
  434. [setEdgesByNodeId],
  435. );
  436. const buildRelevantEdgesByFormData = useCallback(
  437. (nodeId: string, form: IRelevantForm) => {
  438. const downstreamEdges = ['yes', 'no'].reduce<Edge[]>((pre, cur) => {
  439. const target = form[cur as keyof IRelevantForm] as string;
  440. if (target) {
  441. pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur });
  442. }
  443. return pre;
  444. }, []);
  445. setEdgesByNodeId(nodeId, downstreamEdges);
  446. },
  447. [setEdgesByNodeId],
  448. );
  449. const buildSwitchEdgesByFormData = useCallback(
  450. (nodeId: string, form: ISwitchForm) => {
  451. // add
  452. // delete
  453. // edit
  454. const conditions = form.conditions;
  455. const downstreamEdges = conditions.reduce<Edge[]>((pre, _, idx) => {
  456. const target = conditions[idx]?.to;
  457. if (target) {
  458. pre.push({
  459. id: uuid(),
  460. source: nodeId,
  461. target,
  462. sourceHandle: generateSwitchHandleText(idx),
  463. });
  464. }
  465. return pre;
  466. }, []);
  467. // Splice the else condition of the conditional judgment to the edge list
  468. const elseTo = form[SwitchElseTo];
  469. if (elseTo) {
  470. downstreamEdges.push({
  471. id: uuid(),
  472. source: nodeId,
  473. target: elseTo,
  474. sourceHandle: SwitchElseTo,
  475. });
  476. }
  477. setEdgesByNodeId(nodeId, downstreamEdges);
  478. },
  479. [setEdgesByNodeId],
  480. );
  481. useEffect(() => {
  482. nodes.forEach((node) => {
  483. const currentNode = getNode(node.id);
  484. const form = currentNode?.data.form ?? {};
  485. const operatorType = currentNode?.data.label;
  486. switch (operatorType) {
  487. case Operator.Relevant:
  488. buildRelevantEdgesByFormData(node.id, form as IRelevantForm);
  489. break;
  490. case Operator.Categorize:
  491. buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
  492. break;
  493. case Operator.Switch:
  494. buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
  495. break;
  496. default:
  497. break;
  498. }
  499. });
  500. }, [
  501. nodes,
  502. buildCategorizeEdgesByFormData,
  503. getNode,
  504. buildRelevantEdgesByFormData,
  505. buildSwitchEdgesByFormData,
  506. ]);
  507. };
  508. // exclude nodes with branches
  509. const ExcludedNodes = [
  510. Operator.Categorize,
  511. Operator.Relevant,
  512. Operator.Begin,
  513. Operator.Answer,
  514. ];
  515. export const useBuildComponentIdSelectOptions = (nodeId?: string) => {
  516. const nodes = useGraphStore((state) => state.nodes);
  517. const options = useMemo(() => {
  518. return nodes
  519. .filter(
  520. (x) =>
  521. x.id !== nodeId && !ExcludedNodes.some((y) => y === x.data.label),
  522. )
  523. .map((x) => ({ label: x.data.name, value: x.id }));
  524. }, [nodes, nodeId]);
  525. return options;
  526. };