You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. import { useSetModalState } from '@/hooks/common-hooks';
  2. import { useFetchFlow, useResetFlow, useSetFlow } from '@/hooks/flow-hooks';
  3. import { useFetchLlmList } from '@/hooks/llm-hooks';
  4. import { IGraph } from '@/interfaces/database/flow';
  5. import { useIsFetching } from '@tanstack/react-query';
  6. import React, {
  7. ChangeEvent,
  8. KeyboardEventHandler,
  9. useCallback,
  10. useEffect,
  11. useMemo,
  12. useState,
  13. } from 'react';
  14. import { Connection, Edge, Node, Position, ReactFlowInstance } from 'reactflow';
  15. // import { shallow } from 'zustand/shallow';
  16. import { variableEnabledFieldMap } from '@/constants/chat';
  17. import {
  18. ModelVariableType,
  19. settledModelVariableMap,
  20. } from '@/constants/knowledge';
  21. import { useFetchModelId, useSendMessageWithSse } from '@/hooks/logic-hooks';
  22. import { Variable } from '@/interfaces/database/chat';
  23. import api from '@/utils/api';
  24. import { useDebounceEffect } from 'ahooks';
  25. import { FormInstance, message } from 'antd';
  26. import { humanId } from 'human-id';
  27. import trim from 'lodash/trim';
  28. import { useParams } from 'umi';
  29. import { v4 as uuid } from 'uuid';
  30. import {
  31. NodeMap,
  32. Operator,
  33. RestrictedUpstreamMap,
  34. initialArxivValues,
  35. initialBaiduValues,
  36. initialBeginValues,
  37. initialCategorizeValues,
  38. initialDuckValues,
  39. initialGenerateValues,
  40. initialKeywordExtractValues,
  41. initialMessageValues,
  42. initialPubMedValues,
  43. initialRelevantValues,
  44. initialRetrievalValues,
  45. initialRewriteQuestionValues,
  46. initialWikipediaValues,
  47. } from './constant';
  48. import { ICategorizeForm, IRelevantForm } from './interface';
  49. import useGraphStore, { RFState } from './store';
  50. import {
  51. buildDslComponentsByGraph,
  52. receiveMessageError,
  53. replaceIdWithText,
  54. } from './utils';
  55. const selector = (state: RFState) => ({
  56. nodes: state.nodes,
  57. edges: state.edges,
  58. onNodesChange: state.onNodesChange,
  59. onEdgesChange: state.onEdgesChange,
  60. onConnect: state.onConnect,
  61. setNodes: state.setNodes,
  62. onSelectionChange: state.onSelectionChange,
  63. });
  64. export const useSelectCanvasData = () => {
  65. // return useStore(useShallow(selector)); // throw error
  66. // return useStore(selector, shallow);
  67. return useGraphStore(selector);
  68. };
  69. export const useInitializeOperatorParams = () => {
  70. const llmId = useFetchModelId(true);
  71. const initialFormValuesMap = useMemo(() => {
  72. return {
  73. [Operator.Begin]: initialBeginValues,
  74. [Operator.Retrieval]: initialRetrievalValues,
  75. [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
  76. [Operator.Answer]: {},
  77. [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
  78. [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
  79. [Operator.RewriteQuestion]: {
  80. ...initialRewriteQuestionValues,
  81. llm_id: llmId,
  82. },
  83. [Operator.Message]: initialMessageValues,
  84. [Operator.KeywordExtract]: {
  85. ...initialKeywordExtractValues,
  86. llm_id: llmId,
  87. },
  88. [Operator.DuckDuckGo]: initialDuckValues,
  89. [Operator.Baidu]: initialBaiduValues,
  90. [Operator.Wikipedia]: initialWikipediaValues,
  91. [Operator.PubMed]: initialPubMedValues,
  92. [Operator.Arxiv]: initialArxivValues,
  93. };
  94. }, [llmId]);
  95. const initializeOperatorParams = useCallback(
  96. (operatorName: Operator) => {
  97. return initialFormValuesMap[operatorName];
  98. },
  99. [initialFormValuesMap],
  100. );
  101. return initializeOperatorParams;
  102. };
  103. export const useHandleDrag = () => {
  104. const handleDragStart = useCallback(
  105. (operatorId: string) => (ev: React.DragEvent<HTMLDivElement>) => {
  106. ev.dataTransfer.setData('application/reactflow', operatorId);
  107. ev.dataTransfer.effectAllowed = 'move';
  108. },
  109. [],
  110. );
  111. return { handleDragStart };
  112. };
  113. export const useHandleDrop = () => {
  114. const addNode = useGraphStore((state) => state.addNode);
  115. const [reactFlowInstance, setReactFlowInstance] =
  116. useState<ReactFlowInstance<any, any>>();
  117. const initializeOperatorParams = useInitializeOperatorParams();
  118. const onDragOver = useCallback((event: React.DragEvent<HTMLDivElement>) => {
  119. event.preventDefault();
  120. event.dataTransfer.dropEffect = 'move';
  121. }, []);
  122. const onDrop = useCallback(
  123. (event: React.DragEvent<HTMLDivElement>) => {
  124. event.preventDefault();
  125. const type = event.dataTransfer.getData('application/reactflow');
  126. // check if the dropped element is valid
  127. if (typeof type === 'undefined' || !type) {
  128. return;
  129. }
  130. // reactFlowInstance.project was renamed to reactFlowInstance.screenToFlowPosition
  131. // and you don't need to subtract the reactFlowBounds.left/top anymore
  132. // details: https://reactflow.dev/whats-new/2023-11-10
  133. const position = reactFlowInstance?.screenToFlowPosition({
  134. x: event.clientX,
  135. y: event.clientY,
  136. });
  137. const newNode = {
  138. id: `${type}:${humanId()}`,
  139. type: NodeMap[type as Operator] || 'ragNode',
  140. position: position || {
  141. x: 0,
  142. y: 0,
  143. },
  144. data: {
  145. label: `${type}`,
  146. name: humanId(),
  147. form: initializeOperatorParams(type as Operator),
  148. },
  149. sourcePosition: Position.Right,
  150. targetPosition: Position.Left,
  151. };
  152. addNode(newNode);
  153. },
  154. [reactFlowInstance, addNode, initializeOperatorParams],
  155. );
  156. return { onDrop, onDragOver, setReactFlowInstance };
  157. };
  158. export const useShowDrawer = () => {
  159. const {
  160. clickedNodeId: clickNodeId,
  161. setClickedNodeId,
  162. getNode,
  163. } = useGraphStore((state) => state);
  164. const {
  165. visible: drawerVisible,
  166. hideModal: hideDrawer,
  167. showModal: showDrawer,
  168. } = useSetModalState();
  169. const handleShow = useCallback(
  170. (node: Node) => {
  171. setClickedNodeId(node.id);
  172. showDrawer();
  173. },
  174. [showDrawer, setClickedNodeId],
  175. );
  176. return {
  177. drawerVisible,
  178. hideDrawer,
  179. showDrawer: handleShow,
  180. clickedNode: getNode(clickNodeId),
  181. };
  182. };
  183. export const useHandleKeyUp = () => {
  184. const deleteEdge = useGraphStore((state) => state.deleteEdge);
  185. const handleKeyUp: KeyboardEventHandler = useCallback(
  186. (e) => {
  187. if (e.code === 'Delete') {
  188. deleteEdge();
  189. }
  190. },
  191. [deleteEdge],
  192. );
  193. return { handleKeyUp };
  194. };
  195. export const useSaveGraph = () => {
  196. const { data } = useFetchFlow();
  197. const { setFlow } = useSetFlow();
  198. const { id } = useParams();
  199. const { nodes, edges } = useGraphStore((state) => state);
  200. const saveGraph = useCallback(async () => {
  201. const dslComponents = buildDslComponentsByGraph(nodes, edges);
  202. return setFlow({
  203. id,
  204. title: data.title,
  205. dsl: { ...data.dsl, graph: { nodes, edges }, components: dslComponents },
  206. });
  207. }, [nodes, edges, setFlow, id, data]);
  208. return { saveGraph };
  209. };
  210. export const useWatchGraphChange = () => {
  211. const nodes = useGraphStore((state) => state.nodes);
  212. const edges = useGraphStore((state) => state.edges);
  213. useDebounceEffect(
  214. () => {
  215. // console.info('useDebounceEffect');
  216. },
  217. [nodes, edges],
  218. {
  219. wait: 1000,
  220. },
  221. );
  222. };
  223. export const useHandleFormValuesChange = (id?: string) => {
  224. const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
  225. const handleValuesChange = useCallback(
  226. (changedValues: any, values: any) => {
  227. if (id) {
  228. updateNodeForm(id, values);
  229. }
  230. },
  231. [updateNodeForm, id],
  232. );
  233. return { handleValuesChange };
  234. };
  235. const useSetGraphInfo = () => {
  236. const { setEdges, setNodes } = useGraphStore((state) => state);
  237. const setGraphInfo = useCallback(
  238. ({ nodes = [], edges = [] }: IGraph) => {
  239. if (nodes.length || edges.length) {
  240. setNodes(nodes);
  241. setEdges(edges);
  242. }
  243. },
  244. [setEdges, setNodes],
  245. );
  246. return setGraphInfo;
  247. };
  248. export const useFetchDataOnMount = () => {
  249. const { loading, data, refetch } = useFetchFlow();
  250. const setGraphInfo = useSetGraphInfo();
  251. useEffect(() => {
  252. setGraphInfo(data?.dsl?.graph ?? ({} as IGraph));
  253. }, [setGraphInfo, data]);
  254. useWatchGraphChange();
  255. useFetchLlmList();
  256. useEffect(() => {
  257. refetch();
  258. }, [refetch]);
  259. return { loading, flowDetail: data };
  260. };
  261. export const useFlowIsFetching = () => {
  262. return useIsFetching({ queryKey: ['flowDetail'] }) > 0;
  263. };
  264. export const useSetLlmSetting = (form?: FormInstance) => {
  265. const initialLlmSetting = undefined;
  266. useEffect(() => {
  267. const switchBoxValues = Object.keys(variableEnabledFieldMap).reduce<
  268. Record<string, boolean>
  269. >((pre, field) => {
  270. pre[field] =
  271. initialLlmSetting === undefined
  272. ? true
  273. : !!initialLlmSetting[
  274. variableEnabledFieldMap[
  275. field as keyof typeof variableEnabledFieldMap
  276. ] as keyof Variable
  277. ];
  278. return pre;
  279. }, {});
  280. const otherValues = settledModelVariableMap[ModelVariableType.Precise];
  281. form?.setFieldsValue({
  282. ...switchBoxValues,
  283. ...otherValues,
  284. });
  285. }, [form, initialLlmSetting]);
  286. };
  287. export const useValidateConnection = () => {
  288. const { edges, getOperatorTypeFromId } = useGraphStore((state) => state);
  289. // restricted lines cannot be connected successfully.
  290. const isValidConnection = useCallback(
  291. (connection: Connection) => {
  292. // node cannot connect to itself
  293. const isSelfConnected = connection.target === connection.source;
  294. // limit the connection between two nodes to only one connection line in one direction
  295. const hasLine = edges.some(
  296. (x) => x.source === connection.source && x.target === connection.target,
  297. );
  298. const ret =
  299. !isSelfConnected &&
  300. !hasLine &&
  301. RestrictedUpstreamMap[
  302. getOperatorTypeFromId(connection.source) as Operator
  303. ]?.every((x) => x !== getOperatorTypeFromId(connection.target));
  304. return ret;
  305. },
  306. [edges, getOperatorTypeFromId],
  307. );
  308. return isValidConnection;
  309. };
  310. export const useHandleNodeNameChange = (node?: Node) => {
  311. const [name, setName] = useState<string>('');
  312. const { updateNodeName, nodes } = useGraphStore((state) => state);
  313. const previousName = node?.data.name;
  314. const id = node?.id;
  315. const handleNameBlur = useCallback(() => {
  316. const existsSameName = nodes.some((x) => x.data.name === name);
  317. if (trim(name) === '' || existsSameName) {
  318. if (existsSameName && previousName !== name) {
  319. message.error('The name cannot be repeated');
  320. }
  321. setName(previousName);
  322. return;
  323. }
  324. if (id) {
  325. updateNodeName(id, name);
  326. }
  327. }, [name, id, updateNodeName, previousName, nodes]);
  328. const handleNameChange = useCallback((e: ChangeEvent<any>) => {
  329. setName(e.target.value);
  330. }, []);
  331. useEffect(() => {
  332. setName(previousName);
  333. }, [previousName]);
  334. return { name, handleNameBlur, handleNameChange };
  335. };
  336. export const useSaveGraphBeforeOpeningDebugDrawer = (show: () => void) => {
  337. const { id } = useParams();
  338. const { saveGraph } = useSaveGraph();
  339. const { resetFlow } = useResetFlow();
  340. const { refetch } = useFetchFlow();
  341. const { send } = useSendMessageWithSse(api.runCanvas);
  342. const handleRun = useCallback(async () => {
  343. const saveRet = await saveGraph();
  344. if (saveRet?.retcode === 0) {
  345. // Call the reset api before opening the run drawer each time
  346. const resetRet = await resetFlow();
  347. // After resetting, all previous messages will be cleared.
  348. if (resetRet?.retcode === 0) {
  349. // fetch prologue
  350. const sendRet = await send({ id });
  351. if (receiveMessageError(sendRet)) {
  352. message.error(sendRet?.data?.retmsg);
  353. } else {
  354. refetch();
  355. show();
  356. }
  357. }
  358. }
  359. }, [saveGraph, resetFlow, id, send, show, refetch]);
  360. return handleRun;
  361. };
  362. export const useReplaceIdWithText = (output: unknown) => {
  363. const getNode = useGraphStore((state) => state.getNode);
  364. const getNameById = (id?: string) => {
  365. return getNode(id)?.data.name;
  366. };
  367. return replaceIdWithText(output, getNameById);
  368. };
  369. /**
  370. * monitor changes in the data.form field of the categorize and relevant operators
  371. * and then synchronize them to the edge
  372. */
  373. export const useWatchNodeFormDataChange = () => {
  374. const { getNode, nodes, setEdgesByNodeId } = useGraphStore((state) => state);
  375. const buildCategorizeEdgesByFormData = useCallback(
  376. (nodeId: string, form: ICategorizeForm) => {
  377. // add
  378. // delete
  379. // edit
  380. const categoryDescription = form.category_description;
  381. const downstreamEdges = Object.keys(categoryDescription).reduce<Edge[]>(
  382. (pre, sourceHandle) => {
  383. const target = categoryDescription[sourceHandle]?.to;
  384. if (target) {
  385. pre.push({
  386. id: uuid(),
  387. source: nodeId,
  388. target,
  389. sourceHandle,
  390. });
  391. }
  392. return pre;
  393. },
  394. [],
  395. );
  396. setEdgesByNodeId(nodeId, downstreamEdges);
  397. },
  398. [setEdgesByNodeId],
  399. );
  400. const buildRelevantEdgesByFormData = useCallback(
  401. (nodeId: string, form: IRelevantForm) => {
  402. const downstreamEdges = ['yes', 'no'].reduce<Edge[]>((pre, cur) => {
  403. const target = form[cur as keyof IRelevantForm] as string;
  404. if (target) {
  405. pre.push({ id: uuid(), source: nodeId, target, sourceHandle: cur });
  406. }
  407. return pre;
  408. }, []);
  409. setEdgesByNodeId(nodeId, downstreamEdges);
  410. },
  411. [setEdgesByNodeId],
  412. );
  413. useEffect(() => {
  414. nodes.forEach((node) => {
  415. const currentNode = getNode(node.id);
  416. const form = currentNode?.data.form ?? {};
  417. const operatorType = currentNode?.data.label;
  418. switch (operatorType) {
  419. case Operator.Relevant:
  420. buildRelevantEdgesByFormData(node.id, form as IRelevantForm);
  421. break;
  422. case Operator.Categorize:
  423. buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
  424. break;
  425. default:
  426. break;
  427. }
  428. });
  429. }, [
  430. nodes,
  431. buildCategorizeEdgesByFormData,
  432. getNode,
  433. buildRelevantEdgesByFormData,
  434. ]);
  435. };