Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

use-add-node.ts 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. import { useFetchModelId } from '@/hooks/logic-hooks';
  2. import { Connection, Node, Position, ReactFlowInstance } from '@xyflow/react';
  3. import humanId from 'human-id';
  4. import { lowerFirst } from 'lodash';
  5. import { useCallback, useMemo } from 'react';
  6. import { useTranslation } from 'react-i18next';
  7. import {
  8. NodeHandleId,
  9. NodeMap,
  10. Operator,
  11. initialAgentValues,
  12. initialAkShareValues,
  13. initialArXivValues,
  14. initialBaiduFanyiValues,
  15. initialBaiduValues,
  16. initialBeginValues,
  17. initialBingValues,
  18. initialCategorizeValues,
  19. initialCodeValues,
  20. initialConcentratorValues,
  21. initialCrawlerValues,
  22. initialDeepLValues,
  23. initialDuckValues,
  24. initialEmailValues,
  25. initialExeSqlValues,
  26. initialGenerateValues,
  27. initialGithubValues,
  28. initialGoogleScholarValues,
  29. initialGoogleValues,
  30. initialInvokeValues,
  31. initialIterationStartValues,
  32. initialIterationValues,
  33. initialJin10Values,
  34. initialKeywordExtractValues,
  35. initialMessageValues,
  36. initialNoteValues,
  37. initialPubMedValues,
  38. initialQWeatherValues,
  39. initialRelevantValues,
  40. initialRetrievalValues,
  41. initialRewriteQuestionValues,
  42. initialStringTransformValues,
  43. initialSwitchValues,
  44. initialTavilyValues,
  45. initialTemplateValues,
  46. initialTuShareValues,
  47. initialUserFillUpValues,
  48. initialWaitingDialogueValues,
  49. initialWenCaiValues,
  50. initialWikipediaValues,
  51. initialYahooFinanceValues,
  52. } from '../constant';
  53. import useGraphStore from '../store';
  54. import {
  55. generateNodeNamesWithIncreasingIndex,
  56. getNodeDragHandle,
  57. } from '../utils';
  /**
   * Hook yielding a lookup from operator type to the default `form` values a
   * freshly created node of that type starts with.
   *
   * LLM-backed operators (Generate, Categorize, Relevant, RewriteQuestion,
   * KeywordExtract, ExeSQL, Agent) receive a copy of their defaults with
   * `llm_id` set to the id returned by `useFetchModelId()`. The lookup is
   * rebuilt whenever that id changes.
   *
   * @returns a memoized `(operatorName) => defaultFormValues` function.
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId();

    const initialFormValuesMap = useMemo(() => {
      // Shallow-copies an operator's defaults and attaches the current model id.
      const withLlmId = <T extends object>(values: T) => ({
        ...values,
        llm_id: llmId,
      });

      return {
        [Operator.Begin]: initialBeginValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Generate]: withLlmId(initialGenerateValues),
        [Operator.Answer]: {},
        [Operator.Categorize]: withLlmId(initialCategorizeValues),
        [Operator.Relevant]: withLlmId(initialRelevantValues),
        [Operator.RewriteQuestion]: withLlmId(initialRewriteQuestionValues),
        [Operator.Message]: initialMessageValues,
        [Operator.KeywordExtract]: withLlmId(initialKeywordExtractValues),
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        [Operator.ArXiv]: initialArXivValues,
        [Operator.Google]: initialGoogleValues,
        [Operator.Bing]: initialBingValues,
        [Operator.GoogleScholar]: initialGoogleScholarValues,
        [Operator.DeepL]: initialDeepLValues,
        [Operator.GitHub]: initialGithubValues,
        [Operator.BaiduFanyi]: initialBaiduFanyiValues,
        [Operator.QWeather]: initialQWeatherValues,
        [Operator.ExeSQL]: withLlmId(initialExeSqlValues),
        [Operator.Switch]: initialSwitchValues,
        [Operator.WenCai]: initialWenCaiValues,
        [Operator.AkShare]: initialAkShareValues,
        [Operator.YahooFinance]: initialYahooFinanceValues,
        [Operator.Jin10]: initialJin10Values,
        [Operator.Concentrator]: initialConcentratorValues,
        [Operator.TuShare]: initialTuShareValues,
        [Operator.Note]: initialNoteValues,
        [Operator.Crawler]: initialCrawlerValues,
        [Operator.Invoke]: initialInvokeValues,
        [Operator.Template]: initialTemplateValues,
        [Operator.Email]: initialEmailValues,
        [Operator.Iteration]: initialIterationValues,
        [Operator.IterationStart]: initialIterationStartValues,
        [Operator.Code]: initialCodeValues,
        [Operator.WaitingDialogue]: initialWaitingDialogueValues,
        [Operator.Agent]: withLlmId(initialAgentValues),
        [Operator.Tool]: {},
        [Operator.TavilySearch]: initialTavilyValues,
        [Operator.UserFillUp]: initialUserFillUpValues,
        [Operator.StringTransform]: initialStringTransformValues,
      };
    }, [llmId]);

    return useCallback(
      (operatorName: Operator) => initialFormValuesMap[operatorName],
      [initialFormValuesMap],
    );
  };
  /**
   * Hook returning a translator from an operator type to its display name.
   *
   * The type is lower-first-cased and looked up under the `flow.` i18n key
   * prefix, e.g. `'Retrieval'` -> `t('flow.retrieval')`.
   *
   * The returned function is memoized on `t`. Callers that list it as a hook
   * dependency (such as `useAddNode`'s `addCanvasNode`) are therefore not
   * invalidated on every render.
   */
  export const useGetNodeName = () => {
    const { t } = useTranslation();

    return useCallback(
      (type: string) => t(`flow.${lowerFirst(type)}`),
      [t],
    );
  };
  /**
   * Hook providing `calculateNewlyBackChildPosition(id, sourceHandle)`.
   *
   * It picks where a new node connected from handle `sourceHandle` of node
   * `id` should go, so that it does not cover the children already attached
   * to that handle:
   * - x: 300 to the right of the parent (parent x treated as 0 if missing).
   * - y: 150 below the largest y among the handle's existing children, or the
   *   parent's y (0 if missing) when the handle has no children yet.
   */
  export function useCalculateNewlyChildPosition() {
    const getNode = useGraphStore((state) => state.getNode);
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);

    const calculateNewlyBackChildPosition = useCallback(
      (id?: string, sourceHandle?: string) => {
        const parent = getNode(id);

        // Targets of every edge leaving this exact handle of the parent.
        const siblingIds = new Set<string>();
        for (const edge of edges) {
          if (edge.source === id && edge.sourceHandle === sourceHandle) {
            siblingIds.add(edge.target);
          }
        }

        const siblingYs = nodes
          .filter((candidate) => siblingIds.has(candidate.id))
          .map((candidate) => candidate.position.y);

        let y: number;
        if (siblingYs.length > 0) {
          y = Math.max(...siblingYs) + 150;
        } else {
          y = parent?.position.y || 0;
        }

        return {
          y,
          x: (parent?.position.x || 0) + 300,
        };
      },
      [edges, getNode, nodes],
    );

    return { calculateNewlyBackChildPosition };
  }
  /**
   * Hook providing `addChildEdge(position, edge)`.
   *
   * It connects a source node to a newly added child through the child's
   * `End` handle. Only `Position.Right` placements are wired, and only when
   * `source`, `target` and `sourceHandle` are all present; otherwise nothing
   * is added.
   */
  function useAddChildEdge() {
    const addEdge = useGraphStore((state) => state.addEdge);

    const addChildEdge = useCallback(
      (position: Position = Position.Right, edge: Partial<Connection>) => {
        const { source, target, sourceHandle } = edge;
        if (position !== Position.Right || !source || !target || !sourceHandle) {
          return;
        }
        addEdge({
          source,
          target,
          sourceHandle,
          targetHandle: NodeHandleId.End,
        });
      },
      [addEdge],
    );

    return { addChildEdge };
  }
  /**
   * Hook providing `addToolNode(newNode, nodeId)`, which attaches a tool node
   * underneath the agent node `nodeId`.
   *
   * - Nothing happens when the agent node cannot be found.
   * - Nothing happens when an edge already leaves the agent's `Tool` handle
   *   and its first target still exists.
   * - Otherwise a copy of `newNode` is placed 82 left of and 140 below the
   *   agent, added to the store, and connected from the agent's `Tool` handle
   *   to the tool's `End` handle.
   */
  function useAddToolNode() {
    // Per-field selectors, as in useCalculateNewlyChildPosition, rather than
    // subscribing to (and re-rendering on) every change of the whole store.
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);
    const addEdge = useGraphStore((state) => state.addEdge);
    const getNode = useGraphStore((state) => state.getNode);
    const addNode = useGraphStore((state) => state.addNode);

    const addToolNode = useCallback(
      (newNode: Node<any>, nodeId?: string) => {
        const agentNode = getNode(nodeId);
        if (!agentNode) {
          return;
        }

        const childToolNodeIds = edges
          .filter(
            (x) => x.source === nodeId && x.sourceHandle === NodeHandleId.Tool,
          )
          .map((x) => x.target);
        // NOTE(review): only the first tool edge is checked. If that edge is
        // dangling while a later one is valid, a second tool node is added.
        if (
          childToolNodeIds.length > 0 &&
          nodes.some((x) => x.id === childToolNodeIds[0])
        ) {
          return;
        }

        // Add a positioned copy instead of mutating the caller's node object.
        addNode({
          ...newNode,
          position: {
            x: agentNode.position.x - 82,
            y: agentNode.position.y + 140,
          },
        });
        if (nodeId) {
          addEdge({
            source: nodeId,
            target: newNode.id,
            sourceHandle: NodeHandleId.Tool,
            targetHandle: NodeHandleId.End,
          });
        }
      },
      [addEdge, addNode, edges, getNode, nodes],
    );

    return { addToolNode };
  }
  /**
   * True for tool nodes, and for agent nodes added at `Position.Bottom`.
   * Nodes for which this holds are skipped when resizing the parent node.
   */
  function isBottomSubAgent(type: string, position: Position) {
    if (type === Operator.Tool) {
      return true;
    }
    return type === Operator.Agent && position === Position.Bottom;
  }
  /**
   * Hook providing `resizeIterationNode(type, position, parentId)`, which
   * widens the parent node `parentId` when a node is added inside it.
   *
   * Nothing happens when the parent cannot be found, or when
   * `isBottomSubAgent(type, position)` holds. Otherwise, if the rightmost
   * child's x plus 310 exceeds the parent's x, the parent's width grows by
   * 310 and its x shifts right by 155.
   */
  function useResizeIterationNode() {
    // Per-field selectors, as in useCalculateNewlyChildPosition, rather than
    // subscribing to the whole store.
    const getNode = useGraphStore((state) => state.getNode);
    const nodes = useGraphStore((state) => state.nodes);
    const updateNode = useGraphStore((state) => state.updateNode);

    const resizeIterationNode = useCallback(
      (type: string, position: Position, parentId?: string) => {
        const parentNode = getNode(parentId);
        if (!parentNode || isBottomSubAgent(type, position)) {
          return;
        }

        // Presumably the horizontal room one more right-hand child needs.
        const MoveRightDistance = 310;
        const childXs = nodes
          .filter((x) => x.parentId === parentId)
          .map((x) => x.position.x);
        // With no children Math.max() yields -Infinity, so no resize happens.
        const maxX = Math.max(...childXs);

        // NOTE(review): React Flow stores child positions relative to their
        // parent, but this compares against the parent's absolute x. Confirm
        // whether `parentNode.width` was the intended threshold.
        if (maxX + MoveRightDistance > parentNode.position.x) {
          updateNode({
            ...parentNode,
            width: (parentNode.width || 0) + MoveRightDistance,
            position: {
              x: parentNode.position.x + MoveRightDistance / 2,
              y: parentNode.position.y,
            },
          });
        }
      },
      [getNode, nodes, updateNode],
    );

    return { resizeIterationNode };
  }
  /** The subset of a mouse event that node placement reads: its viewport coordinates. */
  type CanvasMouseEvent = Pick<React.MouseEvent<HTMLElement>, 'clientX' | 'clientY'>;
  /**
   * Hook exposing helpers that insert operator nodes into the agent canvas.
   *
   * `addCanvasNode(type, params)` returns an event handler. When invoked, it
   * builds a node of `type` with an id of the form `${type}:${humanId()}` and
   * default form values, positions it, and adds it to the graph store. When
   * `params.nodeId` is given, it also connects the new node to that node:
   * - `Iteration`: also creates its inner `IterationStart` child node and links
   *   the source's `Start` handle to the iteration's `End` handle.
   * - `Agent` at `Position.Bottom`: placed right of the existing bottom
   *   sub-agents and linked `AgentBottom` -> `AgentTop`.
   * - `Tool`: delegated to `addToolNode` from `useAddToolNode`.
   * - anything else: linked from handle `params.id` via `addChildEdge`.
   *
   * `addNoteNode(event)` adds a note at the mouse position.
   *
   * @param reactFlowInstance converts the mouse position to flow coordinates.
   *   Without it, nodes that are not laid out relative to a source node fall
   *   back to (0, 0).
   */
  export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
    // Per-field selectors (as in useCalculateNewlyChildPosition) instead of
    // subscribing to the whole store.
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);
    const addEdge = useGraphStore((state) => state.addEdge);
    const addNode = useGraphStore((state) => state.addNode);
    const getNode = useGraphStore((state) => state.getNode);
    const getNodeName = useGetNodeName();
    const initializeOperatorParams = useInitializeOperatorParams();
    const { calculateNewlyBackChildPosition } = useCalculateNewlyChildPosition();
    const { addChildEdge } = useAddChildEdge();
    const { addToolNode } = useAddToolNode();
    const { resizeIterationNode } = useResizeIterationNode();

    const addCanvasNode = useCallback(
      (
        type: string,
        params: { nodeId?: string; position: Position; id?: string } = {
          position: Position.Right,
        },
      ) =>
        (event?: CanvasMouseEvent) => {
          const nodeId = params.nodeId;
          const node = getNode(nodeId);

          // reactFlowInstance.project was renamed to screenToFlowPosition, and
          // reactFlowBounds.left/top no longer need to be subtracted.
          // details: https://reactflow.dev/whats-new/2023-11-10
          let position = reactFlowInstance?.screenToFlowPosition({
            x: event?.clientX || 0,
            y: event?.clientY || 0,
          });
          // Right-hand children go to the right of the source and below its
          // existing children on that handle; notes always stay at the cursor.
          if (params.position === Position.Right && type !== Operator.Note) {
            position = calculateNewlyBackChildPosition(nodeId, params.id);
          }

          const newNode: Node<any> = {
            id: `${type}:${humanId()}`,
            type: NodeMap[type as Operator] || 'ragNode',
            position: position || {
              x: 0,
              y: 0,
            },
            data: {
              label: `${type}`,
              name: generateNodeNamesWithIncreasingIndex(
                getNodeName(type),
                nodes,
              ),
              form: initializeOperatorParams(type as Operator),
            },
            sourcePosition: Position.Right,
            targetPosition: Position.Left,
            dragHandle: getNodeDragHandle(type),
          };

          // A node added from a node that lives inside a parent node joins the
          // same parent, which may need to grow to fit it.
          if (node && node.parentId) {
            newNode.parentId = node.parentId;
            newNode.extent = 'parent';
            // resizeIterationNode itself skips a missing parent and bottom
            // sub-agents / tools, so no pre-check is needed here.
            resizeIterationNode(type, params.position, node.parentId);
          }

          if (type === Operator.Iteration) {
            newNode.width = 500;
            newNode.height = 250;
            // Each iteration is created together with its start node.
            const iterationStartNode: Node<any> = {
              id: `${Operator.IterationStart}:${humanId()}`,
              type: 'iterationStartNode',
              position: { x: 50, y: 100 },
              // draggable: false,
              data: {
                label: Operator.IterationStart,
                name: Operator.IterationStart,
                form: initialIterationStartValues,
              },
              parentId: newNode.id,
              extent: 'parent',
            };
            addNode(newNode);
            addNode(iterationStartNode);
            if (nodeId) {
              addEdge({
                source: nodeId,
                target: newNode.id,
                sourceHandle: NodeHandleId.Start,
                targetHandle: NodeHandleId.End,
              });
            }
          } else if (
            type === Operator.Agent &&
            params.position === Position.Bottom
          ) {
            // `node` is the source agent. No store update happens before this
            // point on this path (resizing is skipped for bottom sub-agents).
            if (node) {
              // Place the new sub-agent right of the existing bottom sub-agents
              // so it does not cover them.
              const bottomChildIds = new Set(
                edges
                  .filter(
                    (x) =>
                      x.source === nodeId &&
                      x.sourceHandle === NodeHandleId.AgentBottom,
                  )
                  .map((x) => x.target),
              );
              const xAxises = nodes
                .filter((x) => bottomChildIds.has(x.id))
                .map((x) => x.position.x);
              newNode.position = {
                x:
                  xAxises.length > 0
                    ? Math.max(...xAxises) + 262
                    : node.position.x + 82,
                y: node.position.y + 140,
              };
            }
            addNode(newNode);
            if (nodeId) {
              addEdge({
                source: nodeId,
                target: newNode.id,
                sourceHandle: NodeHandleId.AgentBottom,
                targetHandle: NodeHandleId.AgentTop,
              });
            }
          } else if (type === Operator.Tool) {
            addToolNode(newNode, params.nodeId);
          } else {
            addNode(newNode);
            addChildEdge(params.position, {
              source: params.nodeId,
              target: newNode.id,
              sourceHandle: params.id,
            });
          }
        },
      [
        addChildEdge,
        addEdge,
        addNode,
        addToolNode,
        calculateNewlyBackChildPosition,
        edges,
        getNode,
        getNodeName,
        initializeOperatorParams,
        nodes,
        reactFlowInstance,
        resizeIterationNode,
      ],
    );

    const addNoteNode = useCallback(
      (e: CanvasMouseEvent) => {
        addCanvasNode(Operator.Note)(e);
      },
      [addCanvasNode],
    );

    return { addCanvasNode, addNoteNode };
  }