選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

use-add-node.ts 7.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import { useFetchModelId } from '@/hooks/logic-hooks';
  2. import { Node, Position, ReactFlowInstance } from '@xyflow/react';
  3. import humanId from 'human-id';
  4. import { lowerFirst } from 'lodash';
  5. import { useCallback, useMemo } from 'react';
  6. import { useTranslation } from 'react-i18next';
  7. import {
  8. NodeMap,
  9. Operator,
  10. initialAgentValues,
  11. initialAkShareValues,
  12. initialArXivValues,
  13. initialBaiduFanyiValues,
  14. initialBaiduValues,
  15. initialBeginValues,
  16. initialBingValues,
  17. initialCategorizeValues,
  18. initialCodeValues,
  19. initialConcentratorValues,
  20. initialCrawlerValues,
  21. initialDeepLValues,
  22. initialDuckValues,
  23. initialEmailValues,
  24. initialExeSqlValues,
  25. initialGenerateValues,
  26. initialGithubValues,
  27. initialGoogleScholarValues,
  28. initialGoogleValues,
  29. initialInvokeValues,
  30. initialIterationValues,
  31. initialJin10Values,
  32. initialKeywordExtractValues,
  33. initialMessageValues,
  34. initialNoteValues,
  35. initialPubMedValues,
  36. initialQWeatherValues,
  37. initialRelevantValues,
  38. initialRetrievalValues,
  39. initialRewriteQuestionValues,
  40. initialSwitchValues,
  41. initialTemplateValues,
  42. initialTuShareValues,
  43. initialWaitingDialogueValues,
  44. initialWenCaiValues,
  45. initialWikipediaValues,
  46. initialYahooFinanceValues,
  47. } from '../constant';
  48. import useGraphStore from '../store';
  49. import {
  50. generateNodeNamesWithIncreasingIndex,
  51. getNodeDragHandle,
  52. getRelativePositionToIterationNode,
  53. } from '../utils';
  54. export const useInitializeOperatorParams = () => {
  55. const llmId = useFetchModelId();
  56. const initialFormValuesMap = useMemo(() => {
  57. return {
  58. [Operator.Begin]: initialBeginValues,
  59. [Operator.Retrieval]: initialRetrievalValues,
  60. [Operator.Generate]: { ...initialGenerateValues, llm_id: llmId },
  61. [Operator.Answer]: {},
  62. [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId },
  63. [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId },
  64. [Operator.RewriteQuestion]: {
  65. ...initialRewriteQuestionValues,
  66. llm_id: llmId,
  67. },
  68. [Operator.Message]: initialMessageValues,
  69. [Operator.KeywordExtract]: {
  70. ...initialKeywordExtractValues,
  71. llm_id: llmId,
  72. },
  73. [Operator.DuckDuckGo]: initialDuckValues,
  74. [Operator.Baidu]: initialBaiduValues,
  75. [Operator.Wikipedia]: initialWikipediaValues,
  76. [Operator.PubMed]: initialPubMedValues,
  77. [Operator.ArXiv]: initialArXivValues,
  78. [Operator.Google]: initialGoogleValues,
  79. [Operator.Bing]: initialBingValues,
  80. [Operator.GoogleScholar]: initialGoogleScholarValues,
  81. [Operator.DeepL]: initialDeepLValues,
  82. [Operator.GitHub]: initialGithubValues,
  83. [Operator.BaiduFanyi]: initialBaiduFanyiValues,
  84. [Operator.QWeather]: initialQWeatherValues,
  85. [Operator.ExeSQL]: { ...initialExeSqlValues, llm_id: llmId },
  86. [Operator.Switch]: initialSwitchValues,
  87. [Operator.WenCai]: initialWenCaiValues,
  88. [Operator.AkShare]: initialAkShareValues,
  89. [Operator.YahooFinance]: initialYahooFinanceValues,
  90. [Operator.Jin10]: initialJin10Values,
  91. [Operator.Concentrator]: initialConcentratorValues,
  92. [Operator.TuShare]: initialTuShareValues,
  93. [Operator.Note]: initialNoteValues,
  94. [Operator.Crawler]: initialCrawlerValues,
  95. [Operator.Invoke]: initialInvokeValues,
  96. [Operator.Template]: initialTemplateValues,
  97. [Operator.Email]: initialEmailValues,
  98. [Operator.Iteration]: initialIterationValues,
  99. [Operator.IterationStart]: initialIterationValues,
  100. [Operator.Code]: initialCodeValues,
  101. [Operator.WaitingDialogue]: initialWaitingDialogueValues,
  102. [Operator.Agent]: { ...initialAgentValues, llm_id: llmId },
  103. };
  104. }, [llmId]);
  105. const initializeOperatorParams = useCallback(
  106. (operatorName: Operator) => {
  107. return initialFormValuesMap[operatorName];
  108. },
  109. [initialFormValuesMap],
  110. );
  111. return initializeOperatorParams;
  112. };
  113. export const useGetNodeName = () => {
  114. const { t } = useTranslation();
  115. return (type: string) => {
  116. const name = t(`flow.${lowerFirst(type)}`);
  117. return name;
  118. };
  119. };
  120. export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
  121. const addNode = useGraphStore((state) => state.addNode);
  122. const getNode = useGraphStore((state) => state.getNode);
  123. const addEdge = useGraphStore((state) => state.addEdge);
  124. const nodes = useGraphStore((state) => state.nodes);
  125. const edges = useGraphStore((state) => state.edges);
  126. const getNodeName = useGetNodeName();
  127. const initializeOperatorParams = useInitializeOperatorParams();
  128. // const [reactFlowInstance, setReactFlowInstance] =
  129. // useState<ReactFlowInstance<any, any>>();
  130. const addCanvasNode = useCallback(
  131. (type: string, id?: string) => (event: React.MouseEvent<HTMLElement>) => {
  132. // reactFlowInstance.project was renamed to reactFlowInstance.screenToFlowPosition
  133. // and you don't need to subtract the reactFlowBounds.left/top anymore
  134. // details: https://@xyflow/react.dev/whats-new/2023-11-10
  135. const position = reactFlowInstance?.screenToFlowPosition({
  136. x: event.clientX,
  137. y: event.clientY,
  138. });
  139. const newNode: Node<any> = {
  140. id: `${type}:${humanId()}`,
  141. type: NodeMap[type as Operator] || 'ragNode',
  142. position: position || {
  143. x: 0,
  144. y: 0,
  145. },
  146. data: {
  147. label: `${type}`,
  148. name: generateNodeNamesWithIncreasingIndex(getNodeName(type), nodes),
  149. form: initializeOperatorParams(type as Operator),
  150. },
  151. sourcePosition: Position.Right,
  152. targetPosition: Position.Left,
  153. dragHandle: getNodeDragHandle(type),
  154. };
  155. if (type === Operator.Iteration) {
  156. newNode.width = 500;
  157. newNode.height = 250;
  158. const iterationStartNode: Node<any> = {
  159. id: `${Operator.IterationStart}:${humanId()}`,
  160. type: 'iterationStartNode',
  161. position: { x: 50, y: 100 },
  162. // draggable: false,
  163. data: {
  164. label: Operator.IterationStart,
  165. name: Operator.IterationStart,
  166. form: {},
  167. },
  168. parentId: newNode.id,
  169. extent: 'parent',
  170. };
  171. addNode(newNode);
  172. addNode(iterationStartNode);
  173. } else if (type === Operator.Agent) {
  174. const agentNode = getNode(id);
  175. if (agentNode) {
  176. // Calculate the coordinates of child nodes to prevent newly added child nodes from covering other child nodes
  177. const allChildAgentNodeIds = edges
  178. .filter((x) => x.source === id && x.sourceHandle === 'e')
  179. .map((x) => x.target);
  180. const xAxises = nodes
  181. .filter((x) => allChildAgentNodeIds.some((y) => y === x.id))
  182. .map((x) => x.position.x);
  183. const maxX = Math.max(...xAxises);
  184. newNode.position = {
  185. x: xAxises.length > 0 ? maxX + 262 : agentNode.position.x + 82,
  186. y: agentNode.position.y + 140,
  187. };
  188. }
  189. addNode(newNode);
  190. if (id) {
  191. addEdge({
  192. source: id,
  193. target: newNode.id,
  194. sourceHandle: 'e',
  195. targetHandle: 'f',
  196. });
  197. }
  198. } else {
  199. const subNodeOfIteration = getRelativePositionToIterationNode(
  200. nodes,
  201. position,
  202. );
  203. if (subNodeOfIteration) {
  204. newNode.parentId = subNodeOfIteration.parentId;
  205. newNode.position = subNodeOfIteration.position;
  206. newNode.extent = 'parent';
  207. }
  208. addNode(newNode);
  209. }
  210. },
  211. [
  212. addEdge,
  213. addNode,
  214. edges,
  215. getNode,
  216. getNodeName,
  217. initializeOperatorParams,
  218. nodes,
  219. reactFlowInstance,
  220. ],
  221. );
  222. return { addCanvasNode };
  223. }