Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

hooks.ts 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. import { useSetModalState } from '@/hooks/common-hooks';
  2. import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks';
  3. import { useFetchLlmList } from '@/hooks/llm-hooks';
  4. import { IGraph } from '@/interfaces/database/flow';
  5. import { useIsFetching } from '@tanstack/react-query';
  6. import React, {
  7. ChangeEvent,
  8. KeyboardEventHandler,
  9. useCallback,
  10. useEffect,
  11. useMemo,
  12. useState,
  13. } from 'react';
  14. import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
  15. // import { shallow } from 'zustand/shallow';
  16. import { variableEnabledFieldMap } from '@/constants/chat';
  17. import {
  18. ModelVariableType,
  19. settledModelVariableMap,
  20. } from '@/constants/knowledge';
  21. import { useFetchModelId, useSendMessageWithSse } from '@/hooks/logic-hooks';
  22. import { Variable } from '@/interfaces/database/chat';
  23. import api from '@/utils/api';
  24. import { useDebounceEffect } from 'ahooks';
  25. import { FormInstance, message } from 'antd';
  26. import { humanId } from 'human-id';
  27. import trim from 'lodash/trim';
  28. import { useParams } from 'umi';
  29. import { v4 as uuid } from 'uuid';
  30. import {
  31. NodeMap,
  32. Operator,
  33. RestrictedUpstreamMap,
  34. initialBaiduValues,
  35. initialBeginValues,
  36. initialCategorizeValues,
  37. initialDuckValues,
  38. initialGenerateValues,
  39. initialKeywordExtractValues,
  40. initialMessageValues,
  41. initialPubMedValues,
  42. initialRelevantValues,
  43. initialRetrievalValues,
  44. initialRewriteQuestionValues,
  45. initialWikipediaValues,
  46. } from './constant';
  47. import { ICategorizeForm, IRelevantForm } from './interface';
  48. import useGraphStore, { RFState } from './store';
  49. import {
  50. buildDslComponentsByGraph,
  51. receiveMessageError,
  52. replaceIdWithText,
  53. } from './utils';
  /**
   * Picks the slice of graph-store state the canvas renders from: the node and
   * edge arrays plus the React Flow change/connect/selection callbacks.
   * A new object is built on every call.
   */
  const selector = (state: RFState) => {
    const {
      nodes,
      edges,
      onNodesChange,
      onEdgesChange,
      onConnect,
      setNodes,
      onSelectionChange,
    } = state;
    return {
      nodes,
      edges,
      onNodesChange,
      onEdgesChange,
      onConnect,
      setNodes,
      onSelectionChange,
    };
  };
  /**
   * Subscribes to the canvas slice of the graph store (see `selector`).
   *
   * The selector is passed to the store hook as-is; a variant wrapped in
   * zustand's `useShallow` was noted to throw an error.
   */
  export const useSelectCanvasData = () => {
    const canvasData = useGraphStore(selector);
    return canvasData;
  };
  /**
   * Returns a lookup that yields the initial form values for a newly created
   * operator node.
   *
   * LLM-backed operators (Generate, Categorize, Relevant, RewriteQuestion,
   * KeywordExtract) get `llm_id` pre-filled with the id returned by
   * `useFetchModelId(true)`; Answer starts with an empty form. The lookup table
   * is rebuilt only when that model id changes.
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId(true);

    const initialFormValuesMap = useMemo(
      () => ({
        // Operators whose defaults are used unchanged.
        [Operator.Begin]: initialBeginValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Answer]: {},
        [Operator.Message]: initialMessageValues,
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        // LLM-backed operators: their defaults plus the current model id.
        [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
        [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
        [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
        [Operator.RewriteQuestion]: {
          ...initialRewriteQuestionValues,
          llm_id: llmId,
        },
        [Operator.KeywordExtract]: {
          ...initialKeywordExtractValues,
          llm_id: llmId,
        },
      }),
      [llmId],
    );

    return useCallback(
      (operatorName: Operator) => initialFormValuesMap[operatorName],
      [initialFormValuesMap],
    );
  };
  /**
   * Provides a factory of drag-start handlers for operator palette items.
   *
   * The operator id is stored under the `application/reactflow` data type,
   * which is the key the canvas drop handler reads back.
   */
  export const useHandleDrag = () => {
    const handleDragStart = useCallback(function createDragStartHandler(
      operatorId: string,
    ) {
      return (ev: React.DragEvent<HTMLDivElement>) => {
        const { dataTransfer } = ev;
        dataTransfer.setData('application/reactflow', operatorId);
        dataTransfer.effectAllowed = 'move';
      };
    }, []);

    return { handleDragStart };
  };
  /**
   * Wires HTML5 drag-and-drop onto the React Flow canvas.
   *
   * `onDrop` reads the operator type stored under `application/reactflow`.
   * It converts the cursor position to flow coordinates, falling back to the
   * origin while no React Flow instance has been registered. It then adds a
   * node whose form is seeded with that operator's initial values.
   *
   * @returns `onDrop` / `onDragOver` handlers and `setReactFlowInstance`
   *   (presumably wired to React Flow's `onInit` by the caller).
   */
  export const useHandleDrop = () => {
    const addNode = useGraphStore((state) => state.addNode);
    const [reactFlowInstance, setReactFlowInstance] =
      useState<ReactFlowInstance<any, any>>();
    const initializeOperatorParams = useInitializeOperatorParams();

    const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
      event.preventDefault();
      event.dataTransfer.dropEffect = 'move';
    }, []);

    const onDrop = useCallback(
      (event: React.DragEvent<HTMLDivElement>) => {
        event.preventDefault();

        const operatorType = event.dataTransfer.getData('application/reactflow');
        // Ignore drops that did not carry an operator type.
        if (!operatorType) {
          return;
        }

        // screenToFlowPosition (formerly `project`) already accounts for the
        // canvas bounds; details: https://reactflow.dev/whats-new/2023-11-10
        const dropPosition = reactFlowInstance?.screenToFlowPosition({
          x: event.clientX,
          y: event.clientY,
        });

        const node = {
          id: `${operatorType}:${humanId()}`,
          type: NodeMap[operatorType as Operator] || 'ragNode',
          position: dropPosition || { x: 0, y: 0 },
          data: {
            label: `${operatorType}`,
            name: humanId(),
            form: initializeOperatorParams(operatorType as Operator),
          },
          sourcePosition: Position.Right,
          targetPosition: Position.Left,
        };
        addNode(node);
      },
      [reactFlowInstance, addNode, initializeOperatorParams],
    );

    return { onDrop, onDragOver, setReactFlowInstance };
  };
  /**
   * Controls the operator settings drawer.
   *
   * `showDrawer(node)` records the node's id as the clicked node in the graph
   * store and opens the drawer. `clickedNode` is looked up from that stored id
   * on every render.
   */
  export const useShowDrawer = () => {
    const graphStore = useGraphStore((state) => state);
    const { clickedNodeId, setClickedNodeId, getNode } = graphStore;

    const drawer = useSetModalState();
    const { showModal } = drawer;

    const handleShow = useCallback(
      (node: Node) => {
        setClickedNodeId(node.id);
        showModal();
      },
      [showModal, setClickedNodeId],
    );

    return {
      drawerVisible: drawer.visible,
      hideDrawer: drawer.hideModal,
      showDrawer: handleShow,
      clickedNode: getNode(clickedNodeId),
    };
  };
  /**
   * Key-up handler that invokes the store's `deleteEdge` action when the key
   * with code `Delete` is released. Which edge is removed is decided by the
   * store (presumably the currently selected one; see the store).
   */
  export const useHandleKeyUp = () => {
    const deleteEdge = useGraphStore((state) => state.deleteEdge);

    const handleKeyUp: KeyboardEventHandler = useCallback(
      ({ code }) => {
        if (code !== 'Delete') {
          return;
        }
        deleteEdge();
      },
      [deleteEdge],
    );

    return { handleKeyUp };
  };
  /**
   * Persists the current canvas through `setFlow`.
   *
   * The fetched flow's DSL is kept, except that `graph` is replaced with the
   * live nodes/edges and `components` with the components built from them by
   * `buildDslComponentsByGraph`. The title is taken from the fetched flow.
   *
   * @returns `saveGraph`, an async function resolving to whatever `setFlow`
   *   resolves to.
   */
  export const useSaveGraph = () => {
    const { data } = useFetchFlow();
    const { setFlow } = useSetFlow();
    const { id } = useParams();
    const { nodes, edges } = useGraphStore((state) => state);

    const saveGraph = useCallback(async () => {
      const components = buildDslComponentsByGraph(nodes, edges);
      const graph = { nodes, edges };
      const dsl = { ...data.dsl, graph, components };
      return setFlow({ id, title: data.title, dsl });
    }, [nodes, edges, setFlow, id, data]);

    return { saveGraph };
  };
  /**
   * Placeholder watcher. It debounces (1000 ms) changes to the graph's nodes
   * and edges, but the debounced callback currently does nothing.
   */
  export const useWatchGraphChange = () => {
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);
    const debounceOptions = { wait: 1000 };

    useDebounceEffect(
      () => {
        // Intentionally empty for now.
      },
      [nodes, edges],
      debounceOptions,
    );
  };
  /**
   * Returns a `(changedValues, values)` handler, shaped like antd Form's
   * `onValuesChange`. It writes the complete `values` object into the form of
   * the node identified by `id`, and does nothing when `id` is absent.
   */
  export const useHandleFormValuesChange = (id?: string) => {
    const updateNodeForm = useGraphStore((state) => state.updateNodeForm);

    const handleValuesChange = useCallback(
      (_changedValues: any, values: any) => {
        if (!id) {
          return;
        }
        updateNodeForm(id, values);
      },
      [updateNodeForm, id],
    );

    return { handleValuesChange };
  };
  /**
   * Returns a setter that loads a graph's nodes and edges into the store.
   *
   * A graph with neither nodes nor edges is ignored, so the current canvas is
   * left untouched. Missing `nodes` / `edges` default to empty arrays.
   */
  const useSetGraphInfo = () => {
    const { setEdges, setNodes } = useGraphStore((state) => state);

    return useCallback(
      ({ nodes = [], edges = [] }: IGraph) => {
        const isEmptyGraph = nodes.length === 0 && edges.length === 0;
        if (isEmptyGraph) {
          return;
        }
        setNodes(nodes);
        setEdges(edges);
      },
      [setEdges, setNodes],
    );
  };
  /**
   * Bootstraps the flow editor. It:
   * - loads the flow and pushes its stored graph into the canvas store;
   * - registers the (currently no-op) graph-change watcher;
   * - fetches the LLM list;
   * - calls `refetch` on mount (and again if `refetch` changes identity).
   *
   * @returns the fetch `loading` flag and the fetched flow as `flowDetail`.
   */
  export const useFetchDataOnMount = () => {
    const { loading, data, refetch } = useFetchFlow();
    const setGraphInfo = useSetGraphInfo();

    // Sync the canvas store whenever new flow data arrives.
    useEffect(() => {
      const graph = data?.dsl?.graph ?? ({} as IGraph);
      setGraphInfo(graph);
    }, [setGraphInfo, data]);

    useWatchGraphChange();
    useFetchLlmList();

    useEffect(() => {
      refetch();
    }, [refetch]);

    return { loading, flowDetail: data };
  };
  /** True while at least one query matching the `['flowDetail']` key is fetching. */
  export const useFlowIsFetching = () => {
    const fetchingCount = useIsFetching({ queryKey: ['flowDetail'] });
    return fetchingCount > 0;
  };
  /**
   * Seeds a form with default LLM settings.
   *
   * Each switch field listed in `variableEnabledFieldMap` is set from
   * `initialLlmSetting`, or to `true` when no initial setting exists. The
   * remaining values come from the `Precise` entry of `settledModelVariableMap`.
   * Nothing is written while `form` is undefined.
   *
   * @param form - the form instance to populate.
   */
  export const useSetLlmSetting = (form?: FormInstance) => {
    // NOTE(review): hard-coded to `undefined`, so every switch currently
    // resolves to `true` and the lookup branch below is never taken.
    const initialLlmSetting = undefined;

    useEffect(() => {
      const switchBoxValues: Record<string, boolean> = {};
      for (const field of Object.keys(variableEnabledFieldMap)) {
        const settingKey = variableEnabledFieldMap[
          field as keyof typeof variableEnabledFieldMap
        ] as keyof Variable;
        switchBoxValues[field] =
          initialLlmSetting === undefined
            ? true
            : !!initialLlmSetting[settingKey];
      }

      form?.setFieldsValue({
        ...switchBoxValues,
        ...settledModelVariableMap[ModelVariableType.Precise],
      });
    }, [form, initialLlmSetting]);
  };
  /**
   * Returns the `isValidConnection` predicate for the canvas.
   *
   * A connection is rejected when any of these holds:
   * - it links a node to itself;
   * - an edge with the same source and target already exists (at most one
   *   line per direction between two nodes);
   * - the target's operator type appears in the `RestrictedUpstreamMap`
   *   entry for the source's operator type.
   *
   * A source operator with no entry in `RestrictedUpstreamMap` has no
   * restrictions. Without that fallback, `?.every` would yield `undefined`
   * and silently reject every connection from such a node. The predicate
   * always returns a real boolean.
   */
  export const useValidateConnection = () => {
    const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);

    const isValidConnection = useCallback(
      (connection: Connection): boolean => {
        const { source, target } = connection;

        // A node cannot connect to itself.
        if (source === target) {
          return false;
        }

        // Only one line per direction between the same two nodes.
        const hasLine = edges.some(
          (x) => x.source === source && x.target === target,
        );
        if (hasLine) {
          return false;
        }

        const restrictedTargets: ReadonlyArray<string> =
          RestrictedUpstreamMap[getOperatorTypeFromId(source) as Operator] ??
          [];
        // Loop-invariant: resolve the target's operator type once.
        const targetType = getOperatorTypeFromId(target);
        return restrictedTargets.every((x) => x !== targetType);
      },
      [edges, getOperatorTypeFromId],
    );

    return isValidConnection;
  };
  /**
   * Manages the editable name input of a node.
   *
   * The local `name` is reset to `node.data.name` whenever that value changes.
   * On blur, the edited name is committed via `updateNodeName`, unless it is
   * blank or already used by some node. In those cases the input reverts to
   * the node's current name. A duplicate that differs from the current name
   * also triggers an error message.
   */
  export const useHandleNodeNameChange = (node?: Node) => {
    const [name, setName] = useState<string>('');
    const { updateNodeName, nodes } = useGraphStore((state) => state);
    const previousName = node?.data.name;
    const id = node?.id;

    const handleNameBlur = useCallback(() => {
      const isDuplicate = nodes.some((x) => x.data.name === name);
      const isBlank = trim(name) === '';

      if (!isBlank && !isDuplicate) {
        if (id) {
          updateNodeName(id, name);
        }
        return;
      }

      // Unchanged names also count as "duplicates" (the node matches itself),
      // so only warn when the user actually typed another node's name.
      if (isDuplicate && previousName !== name) {
        message.error('The name cannot be repeated');
      }
      setName(previousName);
    }, [name, id, updateNodeName, previousName, nodes]);

    const handleNameChange = useCallback((e: ChangeEvent<any>) => {
      setName(e.target.value);
    }, []);

    useEffect(() => {
      setName(previousName);
    }, [previousName]);

    return { name, handleNameBlur, handleNameChange };
  };
  /**
   * Builds the "run" handler. It saves the graph, resets the flow, then sends
   * an initial request to `api.runCanvas` to fetch the prologue. Only after
   * that does it refetch the flow and call `show`.
   *
   * The chain stops without feedback when saving or resetting does not return
   * `retcode === 0`. A failed prologue request shows its `retmsg` through
   * `message.error`, and `show` is not called.
   */
  export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
    const { id } = useParams();
    const { saveGraph } = useSaveGraph();
    const { resetFlow } = useResetFlow();
    const { refetch } = useFetchFlow();
    const { send } = useSendMessageWithSse(api.runCanvas);

    const handleRun = useCallback(async () => {
      const saveRet = await saveGraph();
      if (saveRet?.retcode !== 0) {
        return;
      }

      // Reset before every run so that previous messages are cleared.
      const resetRet = await resetFlow();
      if (resetRet?.retcode !== 0) {
        return;
      }

      // Fetch the prologue.
      const sendRet = await send({ id });
      if (receiveMessageError(sendRet)) {
        message.error(sendRet?.data?.retmsg);
        return;
      }

      refetch();
      show();
    }, [saveGraph, resetFlow, id, send, show, refetch]);

    return handleRun;
  };
  /**
   * Delegates to `replaceIdWithText`, supplying a resolver that maps a node id
   * to that node's `data.name` in the graph store (or `undefined` when no node
   * matches).
   */
  export const useReplaceIdWithText = (output: unknown) => {
    const getNode = useGraphStore((state) => state.getNode);
    return replaceIdWithText(output, (id?: string) => getNode(id)?.data.name);
  };
  /**
   * Monitors the `data.form` field of Categorize and Relevant operators and
   * synchronizes their outgoing edges with it.
   *
   * - Categorize: one edge per `category_description` entry that has a `to`
   *   target. The entry's key becomes the edge's `sourceHandle`.
   * - Relevant: one edge for each non-empty `yes` / `no` target, with
   *   `'yes'` / `'no'` as the `sourceHandle`.
   *
   * Each rebuilt list, with freshly generated edge ids, is passed to the
   * store's `setEdgesByNodeId`. The sync runs every time `nodes` changes.
   */
  export const useWatchNodeFormDataChange = () => {
    const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state);

    const buildCategorizeEdgesByFormData = useCallback(
      (nodeId: string, form: ICategorizeForm) => {
        // Added, deleted and edited categories are all handled by recomputing
        // the node's edge list from the form.
        // `form` may be the `{}` fallback used for nodes without form data,
        // in which case `category_description` is missing. Treat that as "no
        // categories" instead of letting Object.keys(undefined) throw.
        const categoryDescription: ICategorizeForm['category_description'] =
          form.category_description ?? {};
        const downstreamEdges = Object.keys(categoryDescription).reduce<Edge[]>(
          (pre, sourceHandle) => {
            const target = categoryDescription[sourceHandle]?.to;
            if (target) {
              pre.push({
                id: uuid(),
                source: nodeId,
                target,
                sourceHandle,
              });
            }
            return pre;
          },
          [],
        );
        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    const buildRelevantEdgesByFormData = useCallback(
      (nodeId: string, form: IRelevantForm) => {
        const downstreamEdges = ['yes', 'no'].reduce<Edge[]>((pre, cur) => {
          const target = form[cur as keyof IRelevantForm] as string;
          if (target) {
            pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur });
          }
          return pre;
        }, []);
        setEdgesByNodeId(nodeId, downstreamEdges);
      },
      [setEdgesByNodeId],
    );

    useEffect(() => {
      nodes.forEach((node) => {
        const currentNode = getNode(node.id);
        const form = currentNode?.data.form ?? {};
        const operatorType = currentNode?.data.label;
        switch (operatorType) {
          case Operator.Relevant:
            buildRelevantEdgesByFormData(node.id, form as IRelevantForm);
            break;
          case Operator.Categorize:
            buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
            break;
          default:
            break;
        }
      });
    }, [
      nodes,
      buildCategorizeEdgesByFormData,
      getNode,
      buildRelevantEdgesByFormData,
    ]);
  };