選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

use-add-node.ts 9.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. import { useFetchModelId } from '@/hooks/logic-hooks';
  2. import { Node, Position, ReactFlowInstance } from '@xyflow/react';
  3. import humanId from 'human-id';
  4. import { lowerFirst } from 'lodash';
  5. import { useCallback, useMemo } from 'react';
  6. import { useTranslation } from 'react-i18next';
  7. import {
  8. NodeMap,
  9. Operator,
  10. initialAgentValues,
  11. initialAkShareValues,
  12. initialArXivValues,
  13. initialBaiduFanyiValues,
  14. initialBaiduValues,
  15. initialBeginValues,
  16. initialBingValues,
  17. initialCategorizeValues,
  18. initialCodeValues,
  19. initialConcentratorValues,
  20. initialCrawlerValues,
  21. initialDeepLValues,
  22. initialDuckValues,
  23. initialEmailValues,
  24. initialExeSqlValues,
  25. initialGenerateValues,
  26. initialGithubValues,
  27. initialGoogleScholarValues,
  28. initialGoogleValues,
  29. initialInvokeValues,
  30. initialIterationValues,
  31. initialJin10Values,
  32. initialKeywordExtractValues,
  33. initialMessageValues,
  34. initialNoteValues,
  35. initialPubMedValues,
  36. initialQWeatherValues,
  37. initialRelevantValues,
  38. initialRetrievalValues,
  39. initialRewriteQuestionValues,
  40. initialSwitchValues,
  41. initialTemplateValues,
  42. initialTuShareValues,
  43. initialWaitingDialogueValues,
  44. initialWenCaiValues,
  45. initialWikipediaValues,
  46. initialYahooFinanceValues,
  47. } from '../constant';
  48. import useGraphStore from '../store';
  49. import {
  50. generateNodeNamesWithIncreasingIndex,
  51. getNodeDragHandle,
  52. getRelativePositionToIterationNode,
  53. } from '../utils';
  /**
   * Builds a lookup from operator type to the default `form` values used when a
   * node of that type is first created.
   *
   * Operators listed in the first group get a shallow copy of their defaults with
   * `llm_id` set to the model id from `useFetchModelId`. All other operators
   * receive their module-level default object by reference (not a copy). The
   * lookup and the returned function are rebuilt only when the model id changes.
   */
  export const useInitializeOperatorParams = () => {
    const llmId = useFetchModelId();

    const initialFormValuesMap = useMemo(() => {
      // Spread last so it wins over any `llm_id` already present in the defaults.
      const modelOverride = { llm_id: llmId };

      return {
        // Operators that receive the current model id.
        [Operator.Generate]: { ...initialGenerateValues, ...modelOverride },
        [Operator.Categorize]: { ...initialCategorizeValues, ...modelOverride },
        [Operator.Relevant]: { ...initialRelevantValues, ...modelOverride },
        [Operator.RewriteQuestion]: {
          ...initialRewriteQuestionValues,
          ...modelOverride,
        },
        [Operator.KeywordExtract]: {
          ...initialKeywordExtractValues,
          ...modelOverride,
        },
        [Operator.ExeSQL]: { ...initialExeSqlValues, ...modelOverride },
        [Operator.Agent]: { ...initialAgentValues, ...modelOverride },

        // Operators whose defaults are used as-is.
        [Operator.Begin]: initialBeginValues,
        [Operator.Answer]: {},
        [Operator.Message]: initialMessageValues,
        [Operator.Retrieval]: initialRetrievalValues,
        [Operator.Switch]: initialSwitchValues,
        [Operator.Concentrator]: initialConcentratorValues,
        [Operator.Note]: initialNoteValues,
        [Operator.Iteration]: initialIterationValues,
        // NOTE(review): IterationStart reuses the Iteration defaults, while
        // useAddNode creates its IterationStart nodes with `form: {}` — confirm
        // which is intended.
        [Operator.IterationStart]: initialIterationValues,
        [Operator.WaitingDialogue]: initialWaitingDialogueValues,
        [Operator.Template]: initialTemplateValues,
        [Operator.Code]: initialCodeValues,
        [Operator.Invoke]: initialInvokeValues,
        [Operator.Email]: initialEmailValues,
        [Operator.Crawler]: initialCrawlerValues,
        [Operator.DuckDuckGo]: initialDuckValues,
        [Operator.Baidu]: initialBaiduValues,
        [Operator.Wikipedia]: initialWikipediaValues,
        [Operator.PubMed]: initialPubMedValues,
        [Operator.ArXiv]: initialArXivValues,
        [Operator.Google]: initialGoogleValues,
        [Operator.Bing]: initialBingValues,
        [Operator.GoogleScholar]: initialGoogleScholarValues,
        [Operator.DeepL]: initialDeepLValues,
        [Operator.GitHub]: initialGithubValues,
        [Operator.BaiduFanyi]: initialBaiduFanyiValues,
        [Operator.QWeather]: initialQWeatherValues,
        [Operator.WenCai]: initialWenCaiValues,
        [Operator.AkShare]: initialAkShareValues,
        [Operator.YahooFinance]: initialYahooFinanceValues,
        [Operator.Jin10]: initialJin10Values,
        [Operator.TuShare]: initialTuShareValues,
      };
    }, [llmId]);

    return useCallback(
      (operatorName: Operator) => initialFormValuesMap[operatorName],
      [initialFormValuesMap],
    );
  };
  /**
   * Returns a function that resolves the localized display name of an operator
   * type through the i18n key `flow.<type>`, with the first letter of `type`
   * lower-cased (e.g. `Retrieval` looks up `flow.retrieval`).
   *
   * The returned function is memoized on `t`. Callers such as `useAddNode` list
   * it as a hook dependency, and an unmemoized closure would invalidate their
   * memoization on every render.
   */
  export const useGetNodeName = () => {
    const { t } = useTranslation();

    return useCallback((type: string) => t(`flow.${lowerFirst(type)}`), [t]);
  };
  /**
   * Provides `calculateNewlyBackChildPosition(id, sourceHandle)`, which picks a
   * canvas position for a node about to be attached to `sourceHandle` of node
   * `id`, so that it does not cover the children already attached there.
   *
   * - If that handle already has children, the new node is placed 262 below the
   *   lowest of them (largest `y`).
   * - Otherwise it is placed 82 below the parent.
   * In both cases `x` is the parent's `x` plus 140. A missing parent (or a falsy
   * coordinate) is treated as 0.
   */
  export function useCalculateNewlyChildPosition() {
    const getNode = useGraphStore((state) => state.getNode);
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);

    const calculateNewlyBackChildPosition = useCallback(
      (id?: string, sourceHandle?: string) => {
        const parentNode = getNode(id);
        const parentX = parentNode?.position.x || 0;
        const parentY = parentNode?.position.y || 0;

        // Targets already wired to this exact source handle.
        const siblingIds = new Set(
          edges
            .filter(
              (edge) => edge.source === id && edge.sourceHandle === sourceHandle,
            )
            .map((edge) => edge.target),
        );
        const siblingYs = nodes
          .filter((node) => siblingIds.has(node.id))
          .map((node) => node.position.y);

        let y: number;
        if (siblingYs.length === 0) {
          y = parentY + 82;
        } else {
          y = Math.max(...siblingYs) + 262;
        }

        return { y, x: parentX + 140 };
      },
      [edges, getNode, nodes],
    );

    return { calculateNewlyBackChildPosition };
  }
  /**
   * Returns `addCanvasNode`, a factory for mouse-event handlers that add a new
   * operator node to the graph store.
   *
   * `addCanvasNode(type, params)` returns an `(event) => void` handler. When the
   * handler runs, the node is positioned as follows:
   * - at the event's screen coordinates, converted with
   *   `reactFlowInstance.screenToFlowPosition`;
   * - or, when `params.position === Position.Right` (the default when `params`
   *   is omitted), via `calculateNewlyBackChildPosition(params.id, params.sourceHandle)`;
   * - falling back to (0, 0) when neither yields a position.
   *
   * The node gets an id of the form `<type>:<humanId>`, a display name made
   * unique by `generateNodeNamesWithIncreasingIndex`, and the operator's default
   * form values.
   *
   * Type-specific handling:
   * - `Iteration`: sized 500x250 and added together with an `iterationStartNode`
   *   child that is confined to it (`extent: 'parent'`).
   * - `Agent`: when `params.id` resolves to a node, placed 140 below it and to
   *   the right of that node's existing `e`-handle children. When `params.id` is
   *   set, an `e` -> `f` edge is added from it to the new node.
   * - anything else: if `getRelativePositionToIterationNode` returns a result,
   *   the node adopts its `parentId`/`position` and is confined to that parent.
   *
   * @param reactFlowInstance - instance used for screen-to-flow coordinate
   *   conversion; optional.
   */
  export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
    const addNode = useGraphStore((state) => state.addNode);
    const getNode = useGraphStore((state) => state.getNode);
    const addEdge = useGraphStore((state) => state.addEdge);
    const nodes = useGraphStore((state) => state.nodes);
    const edges = useGraphStore((state) => state.edges);
    const getNodeName = useGetNodeName();
    const initializeOperatorParams = useInitializeOperatorParams();
    const { calculateNewlyBackChildPosition } = useCalculateNewlyChildPosition();

    const addCanvasNode = useCallback(
      (
        type: string,
        params: { id?: string; position?: Position; sourceHandle?: string } = {
          position: Position.Right,
        },
      ) =>
        (event: React.MouseEvent<HTMLElement>) => {
          const id = params.id;

          // `screenToFlowPosition` replaced the older `project` API and already
          // accounts for the canvas bounds, so no manual left/top offset is needed.
          // Details: https://reactflow.dev/whats-new/2023-11-10
          let position = reactFlowInstance?.screenToFlowPosition({
            x: event.clientX,
            y: event.clientY,
          });
          if (params.position === Position.Right) {
            position = calculateNewlyBackChildPosition(id, params.sourceHandle);
          }

          const newNode: Node<any> = {
            id: `${type}:${humanId()}`,
            type: NodeMap[type as Operator] || 'ragNode',
            position: position || {
              x: 0,
              y: 0,
            },
            data: {
              label: `${type}`,
              name: generateNodeNamesWithIncreasingIndex(
                getNodeName(type),
                nodes,
              ),
              form: initializeOperatorParams(type as Operator),
            },
            sourcePosition: Position.Right,
            targetPosition: Position.Left,
            dragHandle: getNodeDragHandle(type),
          };

          if (type === Operator.Iteration) {
            newNode.width = 500;
            newNode.height = 250;
            // Every iteration container is created with its start node inside it.
            const iterationStartNode: Node<any> = {
              id: `${Operator.IterationStart}:${humanId()}`,
              type: 'iterationStartNode',
              position: { x: 50, y: 100 },
              // draggable: false,
              data: {
                label: Operator.IterationStart,
                name: Operator.IterationStart,
                form: {},
              },
              parentId: newNode.id,
              extent: 'parent',
            };
            // The parent must be added before its child.
            addNode(newNode);
            addNode(iterationStartNode);
          } else if (type === Operator.Agent) {
            const agentNode = getNode(id);
            if (agentNode) {
              // Place the new agent to the right of the agents already attached
              // to the parent's `e` handle so it does not cover them.
              const childAgentIds = new Set(
                edges
                  .filter((x) => x.source === id && x.sourceHandle === 'e')
                  .map((x) => x.target),
              );
              const xAxises = nodes
                .filter((x) => childAgentIds.has(x.id))
                .map((x) => x.position.x);
              newNode.position = {
                x:
                  xAxises.length > 0
                    ? Math.max(...xAxises) + 262
                    : agentNode.position.x + 82,
                y: agentNode.position.y + 140,
              };
            }
            addNode(newNode);
            if (id) {
              addEdge({
                source: id,
                target: newNode.id,
                sourceHandle: 'e',
                targetHandle: 'f',
              });
            }
          } else {
            const subNodeOfIteration = getRelativePositionToIterationNode(
              nodes,
              position,
            );
            if (subNodeOfIteration) {
              newNode.parentId = subNodeOfIteration.parentId;
              newNode.position = subNodeOfIteration.position;
              newNode.extent = 'parent';
            }
            addNode(newNode);
          }
        },
      [
        addEdge,
        addNode,
        calculateNewlyBackChildPosition,
        edges,
        getNode,
        getNodeName,
        initializeOperatorParams,
        nodes,
        reactFlowInstance,
      ],
    );

    return { addCanvasNode };
  }