You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. import { useFetchModelId } from '@/hooks/logic-hooks';
  2. import { Connection, Node, Position, ReactFlowInstance } from '@xyflow/react';
  3. import humanId from 'human-id';
  4. import { lowerFirst } from 'lodash';
  5. import { useCallback, useMemo } from 'react';
  6. import { useTranslation } from 'react-i18next';
  7. import {
  8. NodeHandleId,
  9. NodeMap,
  10. Operator,
  11. initialAgentValues,
  12. initialAkShareValues,
  13. initialArXivValues,
  14. initialBaiduFanyiValues,
  15. initialBaiduValues,
  16. initialBeginValues,
  17. initialBingValues,
  18. initialCategorizeValues,
  19. initialCodeValues,
  20. initialConcentratorValues,
  21. initialCrawlerValues,
  22. initialDeepLValues,
  23. initialDuckValues,
  24. initialEmailValues,
  25. initialExeSqlValues,
  26. initialGenerateValues,
  27. initialGithubValues,
  28. initialGoogleScholarValues,
  29. initialGoogleValues,
  30. initialInvokeValues,
  31. initialIterationStartValues,
  32. initialIterationValues,
  33. initialJin10Values,
  34. initialKeywordExtractValues,
  35. initialMessageValues,
  36. initialNoteValues,
  37. initialPubMedValues,
  38. initialQWeatherValues,
  39. initialRelevantValues,
  40. initialRetrievalValues,
  41. initialRewriteQuestionValues,
  42. initialStringTransformValues,
  43. initialSwitchValues,
  44. initialTavilyValues,
  45. initialTemplateValues,
  46. initialTuShareValues,
  47. initialUserFillUpValues,
  48. initialWaitingDialogueValues,
  49. initialWenCaiValues,
  50. initialWikipediaValues,
  51. initialYahooFinanceValues,
  52. } from '../constant';
  53. import useGraphStore from '../store';
  54. import {
  55. generateNodeNamesWithIncreasingIndex,
  56. getNodeDragHandle,
  57. } from '../utils';
  /**
   * Tells whether a node being added hangs beneath an agent rather than
   * continuing the flow to the right.
   *
   * @param type - Operator name of the node being added.
   * @param position - Handle position the node is being added from.
   * @returns `true` for every `Operator.Tool` node, and for an
   *   `Operator.Agent` node when `position` is `Position.Bottom`.
   */
  function isBottomSubAgent(type: string, position: Position) {
    if (type === Operator.Tool) {
      return true;
    }
    return type === Operator.Agent && position === Position.Bottom;
  }
  /**
   * Provides the default form values for every operator type.
   *
   * The lookup table is memoized on the id returned by `useFetchModelId`;
   * that id is written into `llm_id` for the Generate, Categorize, Relevant,
   * RewriteQuestion, KeywordExtract, ExeSQL and Agent operators.
   *
   * @returns `initialFormValuesMap` (operator -> default form values) and
   *   `initializeOperatorParams`, which resolves the defaults for one operator
   *   and adds placeholder `description` / `user_prompt` texts when the node
   *   is a bottom sub-agent (see `isBottomSubAgent`).
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId();

    const initialFormValuesMap = useMemo(
      () => ({
        [Operator.Begin]: initialBeginValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
        [Operator.Answer]: {},
        [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
        [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
        [Operator.RewriteQuestion]: {
          ...initialRewriteQuestionValues,
          llm_id: llmId,
        },
        [Operator.Message]: initialMessageValues,
        [Operator.KeywordExtract]: {
          ...initialKeywordExtractValues,
          llm_id: llmId,
        },
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        [Operator.ArXiv]: initialArXivValues,
        [Operator.Google]: initialGoogleValues,
        [Operator.Bing]: initialBingValues,
        [Operator.GoogleScholar]: initialGoogleScholarValues,
        [Operator.DeepL]: initialDeepLValues,
        [Operator.GitHub]: initialGithubValues,
        [Operator.BaiduFanyi]: initialBaiduFanyiValues,
        [Operator.QWeather]: initialQWeatherValues,
        [Operator.ExeSQL]: { ...initialExeSqlValues, llm_id: llmId },
        [Operator.Switch]: initialSwitchValues,
        [Operator.WenCai]: initialWenCaiValues,
        [Operator.AkShare]: initialAkShareValues,
        [Operator.YahooFinance]: initialYahooFinanceValues,
        [Operator.Jin10]: initialJin10Values,
        [Operator.Concentrator]: initialConcentratorValues,
        [Operator.TuShare]: initialTuShareValues,
        [Operator.Note]: initialNoteValues,
        [Operator.Crawler]: initialCrawlerValues,
        [Operator.Invoke]: initialInvokeValues,
        [Operator.Template]: initialTemplateValues,
        [Operator.Email]: initialEmailValues,
        [Operator.Iteration]: initialIterationValues,
        [Operator.IterationStart]: initialIterationStartValues,
        [Operator.Code]: initialCodeValues,
        [Operator.WaitingDialogue]: initialWaitingDialogueValues,
        [Operator.Agent]: { ...initialAgentValues, llm_id: llmId },
        [Operator.Tool]: {},
        [Operator.TavilySearch]: initialTavilyValues,
        [Operator.UserFillUp]: initialUserFillUpValues,
        [Operator.StringTransform]: initialStringTransformValues,
      }),
      [llmId],
    );

    const initializeOperatorParams = useCallback(
      (operatorName: Operator, position: Position) => {
        const defaults = initialFormValuesMap[operatorName];

        // Ordinary nodes use the table entry as-is.
        if (!isBottomSubAgent(operatorName, position)) {
          return defaults;
        }

        // Bottom sub-agents and tools additionally get placeholder texts.
        return {
          ...defaults,
          description: 'This is an agent for a specific task.',
          user_prompt: 'This is the order you need to send to the agent.',
        };
      },
      [initialFormValuesMap],
    );

    return { initializeOperatorParams, initialFormValuesMap };
  };
  /**
   * Returns a function that maps an operator type to its translated display
   * name, looked up under the i18n key `flow.<type with first letter lowered>`.
   *
   * The returned function is memoized on `t`, so it keeps the same identity
   * between renders. Callers that list it in a hook dependency array (for
   * example a `useCallback`) are therefore not invalidated on every render.
   *
   * @returns `(type: string) => string`-like translator for node names.
   */
  export const useGetNodeName = () => {
    const { t } = useTranslation();

    return useCallback(
      (type: string) => t(`flow.${lowerFirst(type)}`),
      [t],
    );
  };
  /**
   * Supplies `calculateNewlyBackChildPosition`, which picks a position for a
   * new child node so that it does not cover the children already connected
   * to the same source handle.
   *
   * The new node is placed 300 to the right of the parent's x. Its y is 150
   * below the lowest existing child of that handle, or the parent's own y
   * when the handle has no children yet. A missing parent counts as (0, 0).
   */
  export function useCalculateNewlyChildPosition() {
    const getNode = useGraphStore((state) => state.getNode);
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);

    const calculateNewlyBackChildPosition = useCallback(
      (id?: string, sourceHandle?: string) => {
        const parentNode = getNode(id);

        // Targets of every edge leaving this node through the given handle.
        const siblingIds = new Set(
          edges
            .filter(
              (edge) => edge.source === id && edge.sourceHandle === sourceHandle,
            )
            .map((edge) => edge.target),
        );
        const siblingYs = nodes
          .filter((candidate) => siblingIds.has(candidate.id))
          .map((candidate) => candidate.position.y);

        const x = (parentNode?.position.x || 0) + 300;

        if (siblingYs.length === 0) {
          // First child of this handle: align with the parent.
          return { y: parentNode?.position.y || 0, x };
        }

        // Stack the new child below the lowest existing sibling.
        return { y: Math.max(...siblingYs) + 150, x };
      },
      [edges, getNode, nodes],
    );

    return { calculateNewlyBackChildPosition };
  }
  /**
   * Supplies `addChildEdge`, which connects a source handle to the `End`
   * handle of a newly added node.
   *
   * The edge is only created for `Position.Right` and only when `source`,
   * `target` and `sourceHandle` are all truthy; otherwise the call is a no-op.
   */
  function useAddChildEdge() {
    const addEdge = useGraphStore((state) => state.addEdge);

    const addChildEdge = useCallback(
      (position: Position = Position.Right, edge: Partial<Connection>) => {
        const { source, target, sourceHandle } = edge;

        if (position !== Position.Right || !source || !target || !sourceHandle) {
          return;
        }

        addEdge({
          source,
          target,
          sourceHandle,
          targetHandle: NodeHandleId.End,
        });
      },
      [addEdge],
    );

    return { addChildEdge };
  }
  /**
   * Supplies `addToolNode`, which attaches a tool node beneath an agent node.
   *
   * Nothing happens when the agent node cannot be found, or when the target
   * of the agent's first `Tool`-handle edge still exists in the node list.
   * Otherwise `newNode` is repositioned in place relative to the agent
   * (x - 82, y + 140), added to the store and linked from the agent's `Tool`
   * handle to its own `End` handle.
   */
  function useAddToolNode() {
    const { nodes, edges, addEdge, getNode, addNode } = useGraphStore(
      (state) => state,
    );

    const addToolNode = useCallback(
      (newNode: Node<any>, nodeId?: string) => {
        const agentNode = getNode(nodeId);
        if (!agentNode) {
          return;
        }

        // Only the first tool edge of the agent is inspected.
        const firstToolEdge = edges.find(
          (edge) =>
            edge.source === nodeId && edge.sourceHandle === NodeHandleId.Tool,
        );
        if (
          firstToolEdge !== undefined &&
          nodes.some((candidate) => candidate.id === firstToolEdge.target)
        ) {
          return;
        }

        newNode.position = {
          x: agentNode.position.x - 82,
          y: agentNode.position.y + 140,
        };
        addNode(newNode);

        if (nodeId) {
          addEdge({
            source: nodeId,
            target: newNode.id,
            sourceHandle: NodeHandleId.Tool,
            targetHandle: NodeHandleId.End,
          });
        }
      },
      [addEdge, addNode, edges, getNode, nodes],
    );

    return { addToolNode };
  }
  /**
   * Supplies `resizeIterationNode`, which widens a parent (container) node
   * when a regular child is added to it.
   *
   * The call is a no-op when the parent cannot be found or when the node
   * being added is a bottom sub-agent (see `isBottomSubAgent`). Otherwise,
   * if the largest child x plus 310 exceeds the parent's x, the parent's
   * width grows by 310 and its x moves right by 155.
   */
  function useResizeIterationNode() {
    const { getNode, nodes, updateNode } = useGraphStore((state) => state);

    const resizeIterationNode = useCallback(
      (type: string, position: Position, parentId?: string) => {
        const parentNode = getNode(parentId);
        if (!parentNode || isBottomSubAgent(type, position)) {
          return;
        }

        const growBy = 310;
        const childXs = nodes
          .filter((candidate) => candidate.parentId === parentId)
          .map((candidate) => candidate.position.x);
        const rightmostChildX = Math.max(...childXs);

        // NOTE(review): this compares child x values with the parent's own x;
        // confirm that is the intended threshold for growing the container.
        if (rightmostChildX + growBy > parentNode.position.x) {
          updateNode({
            ...parentNode,
            width: (parentNode.width || 0) + growBy,
            position: {
              x: parentNode.position.x + growBy / 2,
              y: parentNode.position.y,
            },
          });
        }
      },
      [getNode, nodes, updateNode],
    );

    return { resizeIterationNode };
  }
  /**
   * Minimal slice of a React mouse event: only the pointer's client
   * coordinates (`clientX`, `clientY`).
   */
  type CanvasMouseEvent = Pick<React.MouseEvent<HTMLElement>, 'clientX' | 'clientY'>;
  /**
   * Hook exposing the canvas "add node" actions.
   *
   * @param reactFlowInstance - Optional React Flow instance used to convert
   *   the mouse position into flow coordinates.
   * @returns
   *   - `addCanvasNode(type, params)`: returns an event handler that creates a
   *     node of `type` and wires it to `params.nodeId` according to the type:
   *       - Iteration: adds the container plus an IterationStart child and an
   *         edge from the source's `Start` handle;
   *       - Agent at `Position.Bottom`: places the node under the source agent
   *         and links `AgentBottom` -> `AgentTop`;
   *       - Tool: delegates to `addToolNode`;
   *       - anything else: adds the node and delegates to `addChildEdge`.
   *   - `addNoteNode(e)`: shortcut for `addCanvasNode(Operator.Note)(e)`.
   */
  export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
    const { edges, nodes, addEdge, addNode, getNode } = useGraphStore(
      (state) => state,
    );
    const getNodeName = useGetNodeName();
    const { initializeOperatorParams } = useInitializeOperatorParams();
    const { calculateNewlyBackChildPosition } = useCalculateNewlyChildPosition();
    const { addChildEdge } = useAddChildEdge();
    const { addToolNode } = useAddToolNode();
    const { resizeIterationNode } = useResizeIterationNode();

    const addCanvasNode = useCallback(
      (
        type: string,
        params: { nodeId?: string; position: Position; id?: string } = {
          position: Position.Right,
        },
      ) =>
        (event?: CanvasMouseEvent) => {
          const {
            nodeId,
            position: handlePosition,
            id: sourceHandleId,
          } = params;
          const sourceNode = getNode(nodeId);

          // `reactFlowInstance.project` was renamed to `screenToFlowPosition`,
          // and the reactFlowBounds left/top no longer need to be subtracted.
          // Details: https://reactflow.dev/whats-new/2023-11-10
          const pointerPosition = reactFlowInstance?.screenToFlowPosition({
            x: event?.clientX || 0,
            y: event?.clientY || 0,
          });

          // Right-hand children (except notes) are laid out relative to their
          // source node instead of the pointer.
          const isRightChild =
            handlePosition === Position.Right && type !== Operator.Note;
          const position = isRightChild
            ? calculateNewlyBackChildPosition(nodeId, sourceHandleId)
            : pointerPosition;

          const newNode: Node<any> = {
            id: `${type}:${humanId()}`,
            type: NodeMap[type as Operator] || 'ragNode',
            position: position || { x: 0, y: 0 },
            data: {
              label: `${type}`,
              name: generateNodeNamesWithIncreasingIndex(
                getNodeName(type),
                nodes,
              ),
              form: initializeOperatorParams(type as Operator, handlePosition),
            },
            sourcePosition: Position.Right,
            targetPosition: Position.Left,
            dragHandle: getNodeDragHandle(type),
          };

          // A node added from inside a container joins the same container.
          const containerId = sourceNode?.parentId;
          if (containerId) {
            newNode.parentId = containerId;
            newNode.extent = 'parent';

            if (getNode(containerId) && !isBottomSubAgent(type, handlePosition)) {
              resizeIterationNode(type, handlePosition, containerId);
            }
          }

          if (type === Operator.Iteration) {
            newNode.width = 500;
            newNode.height = 250;

            const iterationStartNode: Node<any> = {
              id: `${Operator.IterationStart}:${humanId()}`,
              type: 'iterationStartNode',
              position: { x: 50, y: 100 },
              data: {
                label: Operator.IterationStart,
                name: Operator.IterationStart,
                form: initialIterationStartValues,
              },
              parentId: newNode.id,
              extent: 'parent',
            };

            addNode(newNode);
            addNode(iterationStartNode);

            if (nodeId) {
              addEdge({
                source: nodeId,
                target: newNode.id,
                sourceHandle: NodeHandleId.Start,
                targetHandle: NodeHandleId.End,
              });
            }
            return;
          }

          if (type === Operator.Agent && handlePosition === Position.Bottom) {
            const agentNode = getNode(nodeId);

            if (agentNode) {
              // Place the new sub-agent to the right of the sub-agents that
              // already hang from this agent so that they do not overlap.
              const siblingAgentIds = edges
                .filter(
                  (edge) =>
                    edge.source === nodeId &&
                    edge.sourceHandle === NodeHandleId.AgentBottom,
                )
                .map((edge) => edge.target);
              const siblingXs = nodes
                .filter((candidate) => siblingAgentIds.includes(candidate.id))
                .map((candidate) => candidate.position.x);

              newNode.position = {
                x:
                  siblingXs.length > 0
                    ? Math.max(...siblingXs) + 262
                    : agentNode.position.x + 82,
                y: agentNode.position.y + 140,
              };
            }

            addNode(newNode);

            if (nodeId) {
              addEdge({
                source: nodeId,
                target: newNode.id,
                sourceHandle: NodeHandleId.AgentBottom,
                targetHandle: NodeHandleId.AgentTop,
              });
            }
            return;
          }

          if (type === Operator.Tool) {
            addToolNode(newNode, nodeId);
            return;
          }

          // Default: add the node and connect it from the source handle.
          addNode(newNode);
          addChildEdge(handlePosition, {
            source: nodeId,
            target: newNode.id,
            sourceHandle: sourceHandleId,
          });
        },
      [
        addChildEdge,
        addEdge,
        addNode,
        addToolNode,
        calculateNewlyBackChildPosition,
        edges,
        getNode,
        getNodeName,
        initializeOperatorParams,
        nodes,
        reactFlowInstance,
        resizeIterationNode,
      ],
    );

    const addNoteNode = useCallback(
      (e: CanvasMouseEvent) => {
        addCanvasNode(Operator.Note)(e);
      },
      [addCanvasNode],
    );

    return { addCanvasNode, addNoteNode };
  }