Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getIncomers,
  5. getOutgoers,
  6. } from 'reactflow'
  7. import dagre from '@dagrejs/dagre'
  8. import { v4 as uuid4 } from 'uuid'
  9. import {
  10. cloneDeep,
  11. groupBy,
  12. isEqual,
  13. uniqBy,
  14. } from 'lodash-es'
  15. import type {
  16. Edge,
  17. InputVar,
  18. Node,
  19. ToolWithProvider,
  20. ValueSelector,
  21. } from './types'
  22. import {
  23. BlockEnum,
  24. ErrorHandleMode,
  25. NodeRunningStatus,
  26. } from './types'
  27. import {
  28. CUSTOM_NODE,
  29. DEFAULT_RETRY_INTERVAL,
  30. DEFAULT_RETRY_MAX,
  31. ITERATION_CHILDREN_Z_INDEX,
  32. ITERATION_NODE_Z_INDEX,
  33. LOOP_CHILDREN_Z_INDEX,
  34. LOOP_NODE_Z_INDEX,
  35. NODE_LAYOUT_HORIZONTAL_PADDING,
  36. NODE_LAYOUT_MIN_DISTANCE,
  37. NODE_LAYOUT_VERTICAL_PADDING,
  38. NODE_WIDTH_X_OFFSET,
  39. START_INITIAL_POSITION,
  40. } from './constants'
  41. import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants'
  42. import { CUSTOM_LOOP_START_NODE } from './nodes/loop-start/constants'
  43. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  44. import type { IfElseNodeType } from './nodes/if-else/types'
  45. import { branchNameCorrect } from './nodes/if-else/utils'
  46. import type { ToolNodeType } from './nodes/tool/types'
  47. import type { IterationNodeType } from './nodes/iteration/types'
  48. import type { LoopNodeType } from './nodes/loop/types'
  49. import { CollectionType } from '@/app/components/tools/types'
  50. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  51. import { canFindTool, correctModelProvider } from '@/utils'
  52. import { CUSTOM_SIMPLE_NODE } from '@/app/components/workflow/simple-node/constants'
  // Visit marks for the depth-first cycle search:
  // WHITE = not visited yet, GRAY = currently being expanded, BLACK = fully expanded.
  const WHITE = 'WHITE'
  const GRAY = 'GRAY'
  const BLACK = 'BLACK'

  /**
   * One depth-first step of the cycle search.
   *
   * Marks `nodeId` GRAY, pushes it on `stack` and expands its children from `adjList`.
   * Reaching a GRAY child means a back edge: the child is pushed as well and `true`
   * is returned immediately, leaving the walked path on `stack`.
   * When no cycle is reachable the node is marked BLACK, removed from the top of
   * `stack` again, and `false` is returned.
   *
   * @param nodeId id of the node to expand; must be a key of `adjList`
   * @param color visit mark per node id (mutated)
   * @param adjList outgoing neighbour ids per node id
   * @param stack path accumulator (mutated)
   * @returns whether a cycle is reachable from `nodeId`
   */
  const isCyclicUtil = (nodeId: string, color: Record<string, string>, adjList: Record<string, string[]>, stack: string[]) => {
    color[nodeId] = GRAY
    stack.push(nodeId)

    for (const childId of adjList[nodeId]) {
      const childMark = color[childId]

      if (childMark === GRAY) {
        stack.push(childId)
        return true
      }

      // Children without a mark (ids that are not known nodes) are neither GRAY nor WHITE and are skipped.
      if (childMark === WHITE && isCyclicUtil(childId, color, adjList, stack))
        return true
    }

    color[nodeId] = BLACK
    if (stack[stack.length - 1] === nodeId)
      stack.pop()

    return false
  }
  /**
   * Collects the edges that connect nodes found on a cyclic path.
   *
   * Runs the depth-first cycle search from every still-unvisited node and then
   * returns every edge whose source and target both ended up on the search stack.
   * Edges whose source is not a known node are ignored when building the adjacency list.
   *
   * NOTE(review): the stack also keeps the nodes walked before the cycle was entered,
   * so edges on that lead-in path are reported too — confirm this is intended.
   *
   * @returns the matching edges, or an empty array when the stack stayed empty
   */
  const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
    const adjList: Record<string, string[]> = {}
    const color: Record<string, string> = {}
    const stack: string[] = []

    nodes.forEach(({ id }) => {
      color[id] = WHITE
      adjList[id] = []
    })

    edges.forEach(({ source, target }) => {
      adjList[source]?.push(target)
    })

    for (const { id } of nodes) {
      if (color[id] === WHITE)
        isCyclicUtil(id, color, adjList, stack)
    }

    if (stack.length === 0)
      return []

    const cycleNodes = new Set(stack)
    return edges.filter(edge => cycleNodes.has(edge.source) && cycleNodes.has(edge.target))
  }
  /**
   * Builds the start node that lives inside an iteration container.
   *
   * The node id is `${iterationId}start`, it is parented to the iteration node and
   * is neither selectable nor draggable.
   *
   * @param iterationId id of the owning iteration node
   * @returns the generated start node
   */
  export function getIterationStartNode(iterationId: string): Node {
    const { newNode } = generateNewNode({
      id: `${iterationId}start`,
      type: CUSTOM_ITERATION_START_NODE,
      data: {
        title: '',
        desc: '',
        type: BlockEnum.IterationStart,
        isInIteration: true,
      },
      position: { x: 24, y: 68 },
      zIndex: ITERATION_CHILDREN_Z_INDEX,
      parentId: iterationId,
      selectable: false,
      draggable: false,
    })

    return newNode
  }
  /**
   * Builds the start node that lives inside a loop container.
   *
   * The node id is `${loopId}start`, it is parented to the loop node and is
   * neither selectable nor draggable.
   *
   * @param loopId id of the owning loop node
   * @returns the generated start node
   */
  export function getLoopStartNode(loopId: string): Node {
    const { newNode } = generateNewNode({
      id: `${loopId}start`,
      type: CUSTOM_LOOP_START_NODE,
      data: {
        title: '',
        desc: '',
        type: BlockEnum.LoopStart,
        isInLoop: true,
      },
      position: { x: 24, y: 68 },
      zIndex: LOOP_CHILDREN_Z_INDEX,
      parentId: loopId,
      selectable: false,
      draggable: false,
    })

    return newNode
  }
  /**
   * Creates a workflow node from a partial description.
   *
   * Missing `id` falls back to the current timestamp and missing `type` to `CUSTOM_NODE`.
   * Iteration and loop nodes get their dedicated z-index (any other node keeps the
   * supplied `zIndex`) and are returned together with a freshly generated start node,
   * which is also recorded in the container's `start_node_id` and `_children`.
   * Properties passed in `rest` are spread last and therefore win over the defaults.
   *
   * @returns the new node plus, for containers, the matching start node
   */
  export function generateNewNode({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }): {
    newNode: Node
    newIterationStartNode?: Node
    newLoopStartNode?: Node
  } {
    const isIteration = data.type === BlockEnum.Iteration
    const isLoop = data.type === BlockEnum.Loop

    let resolvedZIndex = zIndex
    if (isIteration)
      resolvedZIndex = ITERATION_NODE_Z_INDEX
    else if (isLoop)
      resolvedZIndex = LOOP_NODE_Z_INDEX

    const newNode = {
      id: id || `${Date.now()}`,
      type: type || CUSTOM_NODE,
      data,
      position,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      zIndex: resolvedZIndex,
      ...rest,
    } as Node

    if (isIteration) {
      const newIterationStartNode = getIterationStartNode(newNode.id)
      const iterationData = newNode.data as IterationNodeType
      iterationData.start_node_id = newIterationStartNode.id
      iterationData._children = [{ nodeId: newIterationStartNode.id, nodeType: BlockEnum.IterationStart }]

      return { newNode, newIterationStartNode }
    }

    if (isLoop) {
      const newLoopStartNode = getLoopStartNode(newNode.id)
      const loopData = newNode.data as LoopNodeType
      loopData.start_node_id = newLoopStartNode.id
      loopData._children = [{ nodeId: newLoopStartNode.id, nodeType: BlockEnum.LoopStart }]

      return { newNode, newLoopStartNode }
    }

    return { newNode }
  }
  /**
   * Makes sure every iteration / loop container owns a dedicated start node.
   *
   * For each container whose `start_node_id` is empty, or points at a node that is
   * not a dedicated start node, a new start node is generated and the container's
   * `start_node_id` is updated in place on the passed-in node. When the previous
   * `start_node_id` referenced an existing node, an edge from the new start node to
   * that node is added so it stays connected.
   *
   * The graph is returned untouched only when it contains neither an iteration nor
   * a loop node (previously a graph had to contain both kinds before any migration
   * happened).
   *
   * @param nodes workflow nodes; container nodes may be mutated
   * @param edges workflow edges
   * @returns the nodes and edges including any generated start nodes / edges
   */
  export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
    const hasIterationNode = nodes.some(node => node.data.type === BlockEnum.Iteration)
    const hasLoopNode = nodes.some(node => node.data.type === BlockEnum.Loop)

    // Nothing to migrate unless at least one container node exists.
    if (!hasIterationNode && !hasLoopNode) {
      return {
        nodes,
        edges,
      }
    }

    const nodesMap = nodes.reduce((prev, next) => {
      prev[next.id] = next
      return prev
    }, {} as Record<string, Node>)

    const iterationNodesWithStartNode = []
    const iterationNodesWithoutStartNode = []
    const loopNodesWithStartNode = []
    const loopNodesWithoutStartNode = []

    for (let i = 0; i < nodes.length; i++) {
      const currentNode = nodes[i] as Node<IterationNodeType | LoopNodeType>

      if (currentNode.data.type === BlockEnum.Iteration) {
        if (currentNode.data.start_node_id) {
          // Containers already pointing at a dedicated start node need no migration.
          if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE)
            iterationNodesWithStartNode.push(currentNode)
        }
        else {
          iterationNodesWithoutStartNode.push(currentNode)
        }
      }

      if (currentNode.data.type === BlockEnum.Loop) {
        if (currentNode.data.start_node_id) {
          if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_LOOP_START_NODE)
            loopNodesWithStartNode.push(currentNode)
        }
        else {
          loopNodesWithoutStartNode.push(currentNode)
        }
      }
    }

    const newIterationStartNodesMap = {} as Record<string, Node>
    const newIterationStartNodes = [...iterationNodesWithStartNode, ...iterationNodesWithoutStartNode].map((iterationNode, index) => {
      const newNode = getIterationStartNode(iterationNode.id)
      newNode.id = newNode.id + index
      newIterationStartNodesMap[iterationNode.id] = newNode
      return newNode
    })

    const newLoopStartNodesMap = {} as Record<string, Node>
    const newLoopStartNodes = [...loopNodesWithStartNode, ...loopNodesWithoutStartNode].map((loopNode, index) => {
      const newNode = getLoopStartNode(loopNode.id)
      newNode.id = newNode.id + index
      newLoopStartNodesMap[loopNode.id] = newNode
      return newNode
    })

    const newEdges = [...iterationNodesWithStartNode, ...loopNodesWithStartNode]
      // A start_node_id that references a missing node cannot be linked; skip it instead of crashing.
      .filter(nodeItem => !!nodesMap[nodeItem.data.start_node_id])
      .map((nodeItem) => {
        const isIteration = nodeItem.data.type === BlockEnum.Iteration
        const newNode = (isIteration ? newIterationStartNodesMap : newLoopStartNodesMap)[nodeItem.id]
        const startNode = nodesMap[nodeItem.data.start_node_id]
        const source = newNode.id
        const sourceHandle = 'source'
        const target = startNode.id
        const targetHandle = 'target'

        const parentNode = nodes.find(node => node.id === startNode.parentId) || null
        const isInIteration = !!parentNode && parentNode.data.type === BlockEnum.Iteration
        const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop

        return {
          id: `${source}-${sourceHandle}-${target}-${targetHandle}`,
          type: 'custom',
          source,
          sourceHandle,
          target,
          targetHandle,
          data: {
            sourceType: newNode.data.type,
            targetType: startNode.data.type,
            isInIteration,
            iteration_id: isInIteration ? startNode.parentId : undefined,
            isInLoop,
            loop_id: isInLoop ? startNode.parentId : undefined,
            _connectedNodeIsSelected: true,
          },
          zIndex: isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX,
        }
      })

    // Point every migrated container at its freshly generated start node.
    nodes.forEach((node) => {
      if (node.data.type === BlockEnum.Iteration && newIterationStartNodesMap[node.id])
        (node.data as IterationNodeType).start_node_id = newIterationStartNodesMap[node.id].id

      if (node.data.type === BlockEnum.Loop && newLoopStartNodesMap[node.id])
        (node.data as LoopNodeType).start_node_id = newLoopStartNodesMap[node.id].id
    })

    return {
      nodes: [...nodes, ...newIterationStartNodes, ...newLoopStartNodes],
      edges: [...edges, ...newEdges],
    }
  }
  /**
   * Prepares raw workflow nodes for rendering.
   *
   * Works on deep copies of the inputs: runs the start-node preprocessing, lays the
   * nodes out in a row when the first node has no position, records which handles
   * of each node are connected, and fills in per-type defaults (if/else cases and
   * branches, classifier branches, container children and error handling,
   * corrected model provider names, HTTP retry configuration).
   *
   * @returns the prepared node list
   */
  export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
    const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))

    // No position on the first node: place all nodes in a horizontal row.
    if (!nodes[0]?.position) {
      nodes.forEach((node, index) => {
        node.position = {
          x: START_INITIAL_POSITION.x + index * NODE_WIDTH_X_OFFSET,
          y: START_INITIAL_POSITION.y,
        }
      })
    }

    // Children of each iteration / loop container, keyed by the container id.
    const childrenByParentId = {} as Record<string, { nodeId: string; nodeType: BlockEnum }[]>
    for (const node of nodes) {
      if (!node.parentId)
        continue

      const child = { nodeId: node.id, nodeType: node.data.type }
      if (childrenByParentId[node.parentId])
        childrenByParentId[node.parentId].push(child)
      else
        childrenByParentId[node.parentId] = [child]
    }

    return nodes.map((node) => {
      if (!node.type)
        node.type = CUSTOM_NODE

      const connectedEdges = getConnectedEdges([node], edges)
      node.data._connectedSourceHandleIds = connectedEdges
        .filter(edge => edge.source === node.id)
        .map(edge => edge.sourceHandle || 'source')
      node.data._connectedTargetHandleIds = connectedEdges
        .filter(edge => edge.target === node.id)
        .map(edge => edge.targetHandle || 'target')

      switch (node.data.type) {
        case BlockEnum.IfElse: {
          const ifElseData = node.data as IfElseNodeType

          // Legacy shape: a single condition group stored directly on the node.
          if (!ifElseData.cases && ifElseData.logical_operator && ifElseData.conditions) {
            ifElseData.cases = [
              {
                case_id: 'true',
                logical_operator: ifElseData.logical_operator,
                conditions: ifElseData.conditions,
              },
            ]
          }

          node.data._targetBranches = branchNameCorrect([
            ...ifElseData.cases.map(item => ({ id: item.case_id, name: '' })),
            { id: 'false', name: '' },
          ])
          break
        }

        case BlockEnum.QuestionClassifier: {
          node.data._targetBranches = (node.data as QuestionClassifierNodeType).classes.map(topic => topic)

          // legacy provider handle
          const classifierModel = (node as any).data.model
          classifierModel.provider = correctModelProvider(classifierModel.provider)
          break
        }

        case BlockEnum.Iteration: {
          const iterationData = node.data as IterationNodeType
          iterationData._children = childrenByParentId[node.id] || []
          iterationData.is_parallel = iterationData.is_parallel || false
          iterationData.parallel_nums = iterationData.parallel_nums || 10
          iterationData.error_handle_mode = iterationData.error_handle_mode || ErrorHandleMode.Terminated
          break
        }

        // TODO: loop error handle mode
        case BlockEnum.Loop: {
          const loopData = node.data as LoopNodeType
          loopData._children = childrenByParentId[node.id] || []
          loopData.error_handle_mode = loopData.error_handle_mode || ErrorHandleMode.Terminated
          break
        }

        // legacy provider handle
        case BlockEnum.LLM:
        case BlockEnum.ParameterExtractor: {
          const model = (node as any).data.model
          model.provider = correctModelProvider(model.provider)
          break
        }

        case BlockEnum.KnowledgeRetrieval: {
          const rerankingModel = (node as any).data.multiple_retrieval_config?.reranking_model
          if (rerankingModel)
            rerankingModel.provider = correctModelProvider(rerankingModel.provider)
          break
        }

        case BlockEnum.HttpRequest: {
          if (!node.data.retry_config) {
            node.data.retry_config = {
              retry_enabled: true,
              max_retries: DEFAULT_RETRY_MAX,
              retry_interval: DEFAULT_RETRY_INTERVAL,
            }
          }
          break
        }
      }

      return node
    })
  }
  /**
   * Prepares raw workflow edges for rendering.
   *
   * Works on deep copies of the inputs: runs the start-node preprocessing, drops
   * every edge reported by the cycle search, forces the `custom` edge type, fills
   * in default handle ids and missing source / target node types, and — when a
   * node is flagged as selected — marks the edges touching that node.
   *
   * @returns the prepared edge list
   */
  export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
    const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))

    let selectedNode: Node | null = null
    const nodesMap = {} as Record<string, Node>
    nodes.forEach((node) => {
      nodesMap[node.id] = node
      // When several nodes are flagged, the last one wins.
      if (node.data?.selected)
        selectedNode = node
    })

    const cycleEdges = getCycleEdges(nodes, edges)
    const acyclicEdges = edges.filter((edge) => {
      const isCycleEdge = cycleEdges.some(cycEdge => cycEdge.source === edge.source && cycEdge.target === edge.target)
      return !isCycleEdge
    })

    return acyclicEdges.map((edge) => {
      edge.type = 'custom'
      edge.sourceHandle = edge.sourceHandle || 'source'
      edge.targetHandle = edge.targetHandle || 'target'

      const sourceNode = edge.source && nodesMap[edge.source]
      if (!edge.data?.sourceType && sourceNode) {
        edge.data = {
          ...edge.data,
          sourceType: sourceNode.data.type!,
        } as any
      }

      const targetNode = edge.target && nodesMap[edge.target]
      if (!edge.data?.targetType && targetNode) {
        edge.data = {
          ...edge.data,
          targetType: targetNode.data.type!,
        } as any
      }

      if (selectedNode) {
        const touchesSelectedNode = edge.source === selectedNode.id || edge.target === selectedNode.id
        edge.data = {
          ...edge.data,
          _connectedNodeIsSelected: touchesSelectedNode,
        } as any
      }

      return edge
    })
  }
  /**
   * Computes a left-to-right dagre layout for the top-level workflow graph.
   *
   * Only nodes of type `CUSTOM_NODE` without a parent and edges that are not
   * flagged as being inside an iteration or loop take part in the layout.
   *
   * @returns the dagre graph after `dagre.layout` has run
   */
  export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
    const topLevelNodes = cloneDeep(originNodes).filter(node => !node.parentId && node.type === CUSTOM_NODE)
    const topLevelEdges = cloneDeep(originEdges).filter((edge) => {
      const insideContainer = edge.data?.isInIteration || edge.data?.isInLoop
      return !insideContainer
    })

    const dagreGraph = new dagre.graphlib.Graph()
    dagreGraph.setDefaultEdgeLabel(() => ({}))
    dagreGraph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      ranker: 'tight-tree',
      marginx: 30,
      marginy: 200,
    })

    for (const { id, width, height } of topLevelNodes)
      dagreGraph.setNode(id, { width: width!, height: height! })

    for (const { source, target } of topLevelEdges)
      dagreGraph.setEdge(source, target)

    dagre.layout(dagreGraph)

    return dagreGraph
  }
  /**
   * Computes a left-to-right dagre layout for the children of one iteration / loop node.
   *
   * Without a start node among the children, all children are laid out together.
   * Otherwise the start node is kept out of the dagre run: the remaining children
   * are laid out first, then the start node is placed at a fixed x and vertically
   * centred on the nodes it connects to (y = 100 when it connects to nothing laid
   * out). If those first nodes would sit closer than `NODE_LAYOUT_MIN_DISTANCE`
   * to the start node, every other child is shifted right.
   *
   * @param parentNodeId id of the container whose children are laid out
   * @returns the dagre graph holding the computed positions
   */
  export const getLayoutForChildNodes = (parentNodeId: string, originNodes: Node[], originEdges: Edge[]) => {
    const dagreGraph = new dagre.graphlib.Graph()
    dagreGraph.setDefaultEdgeLabel(() => ({}))

    const nodes = cloneDeep(originNodes).filter(node => node.parentId === parentNodeId)
    const edges = cloneDeep(originEdges).filter((edge) => {
      const belongsToIteration = edge.data?.isInIteration && edge.data?.iteration_id === parentNodeId
      const belongsToLoop = edge.data?.isInLoop && edge.data?.loop_id === parentNodeId
      return belongsToIteration || belongsToLoop
    })

    const startNode = nodes.find((node) => {
      const hasStartNodeType = node.type === CUSTOM_ITERATION_START_NODE || node.type === CUSTOM_LOOP_START_NODE
      const hasStartBlockType = node.data?.type === BlockEnum.LoopStart || node.data?.type === BlockEnum.IterationStart
      return hasStartNodeType || hasStartBlockType
    })

    const baseGraphOptions = {
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
    }
    // Registers nodes with fallback dimensions for nodes that have not been measured.
    const addNodes = (list: Node[]) => {
      for (const node of list) {
        dagreGraph.setNode(node.id, {
          width: node.width || 244,
          height: node.height || 100,
        })
      }
    }
    const addEdges = (list: Edge[]) => {
      for (const edge of list)
        dagreGraph.setEdge(edge.source, edge.target)
    }

    if (!startNode) {
      dagreGraph.setGraph({
        ...baseGraphOptions,
        marginx: NODE_LAYOUT_HORIZONTAL_PADDING,
        marginy: NODE_LAYOUT_VERTICAL_PADDING,
      })
      addNodes(nodes)
      addEdges(edges)
      dagre.layout(dagreGraph)

      return dagreGraph
    }

    const startNodeOutEdges = edges.filter(edge => edge.source === startNode.id)
    const firstConnectedNodes = startNodeOutEdges
      .map(edge => nodes.find(node => node.id === edge.target))
      .filter(Boolean) as Node[]
    const nonStartNodes = nodes.filter(node => node.id !== startNode.id)
    const nonStartEdges = edges.filter(edge => edge.source !== startNode.id && edge.target !== startNode.id)

    dagreGraph.setGraph({
      ...baseGraphOptions,
      marginx: NODE_LAYOUT_HORIZONTAL_PADDING / 2,
      marginy: NODE_LAYOUT_VERTICAL_PADDING / 2,
    })
    addNodes(nonStartNodes)
    addEdges(nonStartEdges)
    dagre.layout(dagreGraph)

    const startNodeSize = {
      width: startNode.width || 44,
      height: startNode.height || 48,
    }
    const startNodeX = NODE_LAYOUT_HORIZONTAL_PADDING / 1.5
    let startNodeY = 100

    if (firstConnectedNodes.length > 0) {
      // Only nodes that actually received a layout position are considered.
      const firstLayerLabels = firstConnectedNodes
        .map(node => dagreGraph.node(node.id))
        .filter(Boolean)

      if (firstLayerLabels.length > 0) {
        const ySum = firstLayerLabels.reduce((sum, label) => sum + label.y, 0)
        startNodeY = ySum / firstLayerLabels.length
      }

      // dagre positions are node centres, so subtract half the width to get the left edge.
      const minFirstLayerX = firstLayerLabels.reduce(
        (minX, label) => Math.min(minX, label.x - label.width / 2),
        Infinity,
      )
      const minRequiredX = startNodeX + startNodeSize.width + NODE_LAYOUT_MIN_DISTANCE

      if (minFirstLayerX < minRequiredX) {
        const shiftX = minRequiredX - minFirstLayerX

        for (const node of nonStartNodes) {
          const nodePos = dagreGraph.node(node.id)
          if (!nodePos)
            continue

          dagreGraph.setNode(node.id, {
            x: nodePos.x + shiftX,
            y: nodePos.y,
            width: nodePos.width,
            height: nodePos.height,
          })
        }
      }
    }

    dagreGraph.setNode(startNode.id, {
      x: startNodeX + startNodeSize.width / 2,
      y: startNodeY,
      width: startNodeSize.width,
      height: startNodeSize.height,
    })
    addEdges(startNodeOutEdges)

    return dagreGraph
  }
  /**
   * Tells whether `nodeType` is one of the block types listed below.
   *
   * @param nodeType block type to check
   * @returns `true` for the listed block types, `false` otherwise
   */
  export const canRunBySingle = (nodeType: BlockEnum) => {
    switch (nodeType) {
      case BlockEnum.LLM:
      case BlockEnum.KnowledgeRetrieval:
      case BlockEnum.Code:
      case BlockEnum.TemplateTransform:
      case BlockEnum.QuestionClassifier:
      case BlockEnum.HttpRequest:
      case BlockEnum.Tool:
      case BlockEnum.ParameterExtractor:
      case BlockEnum.Iteration:
      case BlockEnum.Agent:
      case BlockEnum.DocExtractor:
      case BlockEnum.Loop:
        return true
      default:
        return false
    }
  }
  // A batch of edge changes; `type` is expected to be 'add' or 'remove' (other values are ignored).
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]

  /**
   * Computes the connected-handle lists of every node touched by a batch of edge changes.
   *
   * For each change the source and target nodes are looked up in `nodes`; each
   * affected node gets an entry seeded with copies of its current
   * `_connectedSourceHandleIds` / `_connectedTargetHandleIds`, so the nodes
   * themselves are not mutated. 'add' appends the edge's handle id and 'remove'
   * deletes one occurrence of it.
   *
   * A missing handle id is treated as `'source'` / `'target'` for both operations,
   * and removing a handle that is not in the list is a no-op (previously the
   * last, unrelated handle id was dropped).
   *
   * @param changes edge changes to apply, in order
   * @param nodes nodes used to resolve the edges' endpoints
   * @returns map from node id to its updated handle id lists; untouched nodes are absent
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>

    // Seeds the entry of a node with copies of its current handle lists (first touch only).
    const ensureEntry = (node: Node) => {
      nodesConnectedSourceOrTargetHandleIdsMap[node.id] = nodesConnectedSourceOrTargetHandleIdsMap[node.id] || {
        _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
        _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
      }
    }

    // Applies one add / remove operation to a handle id list.
    const applyChange = (handleIds: string[], handleId: string, type: string) => {
      if (type === 'remove') {
        const index = handleIds.indexOf(handleId)
        // splice(-1, 1) would remove the last element, so only splice on a real match.
        if (index !== -1)
          handleIds.splice(index, 1)
      }

      if (type === 'add')
        handleIds.push(handleId)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)

      if (sourceNode)
        ensureEntry(sourceNode)
      if (targetNode)
        ensureEntry(targetNode)

      if (sourceNode) {
        applyChange(
          nodesConnectedSourceOrTargetHandleIdsMap[sourceNode.id]._connectedSourceHandleIds,
          edge.sourceHandle || 'source',
          type,
        )
      }

      if (targetNode) {
        applyChange(
          nodesConnectedSourceOrTargetHandleIdsMap[targetNode.id]._connectedTargetHandleIds,
          edge.targetHandle || 'target',
          type,
        )
      }
    })

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Derives the title for a copy of a node.
   *
   * A title ending in a parenthesised number gets that number incremented
   * (`"LLM (2)"` becomes `"LLM (3)"`); any other title gets `" (1)"` appended.
   *
   * @param oldTitle title of the node being copied
   * @returns the title for the new node
   */
  export const genNewNodeTitleFromOld = (oldTitle: string) => {
    const numberedTitle = /^(.+?)\s*\((\d+)\)\s*$/.exec(oldTitle)

    if (!numberedTitle)
      return `${oldTitle} (1)`

    const [, baseTitle, counter] = numberedTitle
    const nextCounter = Number.parseInt(counter, 10) + 1

    return `${baseTitle} (${nextCounter})`
  }
  /**
   * Collects every node reachable from the workflow's Start node.
   *
   * Follows outgoing edges from the Start node. Reached iteration / loop nodes
   * also contribute the nodes parented to them. `maxDepth` is the number of nodes
   * on the longest path found, counting the Start node as depth 1.
   *
   * The traversal skips a node that is already on the current path, so it
   * terminates on cyclic graphs, and it does not re-expand a node that was already
   * expanded at the same or a greater depth, which avoids re-walking shared
   * sub-graphs. Results for acyclic graphs are unchanged.
   *
   * @returns the reachable nodes (unique by id) and the maximum depth;
   *          an empty list and depth 0 when there is no Start node
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)

    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    let maxDepth = 1

    // Greatest depth at which each node id has been expanded so far.
    const expandedDepth = new Map<string, number>()
    // Node ids on the path currently being walked; reaching one again means a cycle.
    const activePath = new Set<string>()

    // Records a node and, for containers, the nodes parented to it.
    const collect = (node: Node) => {
      list.push(node)
      if (node.data.type === BlockEnum.Iteration || node.data.type === BlockEnum.Loop)
        list.push(...nodes.filter(child => child.parentId === node.id))
    }

    const traverse = (root: Node, depth: number) => {
      if (activePath.has(root.id))
        return
      // An earlier visit at the same or a greater depth already covered this sub-graph.
      if ((expandedDepth.get(root.id) ?? 0) >= depth)
        return
      expandedDepth.set(root.id, depth)

      if (depth > maxDepth)
        maxDepth = depth

      const outgoers = getOutgoers(root, nodes, edges)

      if (outgoers.length) {
        activePath.add(root.id)
        outgoers.forEach((outgoer) => {
          collect(outgoer)
          traverse(outgoer, depth + 1)
        })
        activePath.delete(root.id)
      }
      else {
        collect(root)
      }
    }

    traverse(startNode, maxDepth)

    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Gathers the data needed to validate a tool node.
   *
   * Picks the tool list matching the node's provider type (built-in, custom, or
   * workflow otherwise), finds the provider and tool, converts the tool parameters
   * to form schemas and splits them into LLM-filled inputs and settings.
   *
   * @param toolData data of the tool node
   * @param buildInTools built-in tool providers
   * @param customTools custom tool providers
   * @param workflowTools workflow tool providers
   * @param language locale key used to pick input labels, falling back to `en_US`
   * @returns input variables, setting schemas, the `notAuthed` flag and the language
   */
  export const getToolCheckParams = (
    toolData: ToolNodeType,
    buildInTools: ToolWithProvider[],
    customTools: ToolWithProvider[],
    workflowTools: ToolWithProvider[],
    language: string,
  ) => {
    const { provider_id, provider_type, tool_name } = toolData
    const isBuiltIn = provider_type === CollectionType.builtIn

    let currentTools = workflowTools
    if (isBuiltIn)
      currentTools = buildInTools
    else if (provider_type === CollectionType.custom)
      currentTools = customTools

    const currCollection = currentTools.find(item => canFindTool(item.id, provider_id))
    const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
    const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []
    const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm')
    const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')

    const toolInputsSchema: InputVar[] = toolInputVarSchema.map((item: any) => ({
      label: item.label[language] || item.label.en_US,
      variable: item.variable,
      type: item.type,
      required: item.required,
    }))

    return {
      toolInputsSchema,
      notAuthed: isBuiltIn && !!currCollection?.allow_delete && !currCollection?.is_team_authorization,
      toolSettingSchema,
      language,
    }
  }
  /**
   * Returns copies of the nodes and edges with freshly generated node ids.
   *
   * Every node gets a new UUID and each edge's `source` / `target` is rewritten
   * through the same old-id to new-id mapping. The copies are shallow; all other
   * properties are carried over as they are.
   *
   * @returns a `[nodes, edges]` tuple of the re-identified copies
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap = {} as Record<string, string>
    for (const { id } of nodes)
      idMap[id] = uuid4()

    const newNodes = nodes.map(node => ({
      ...node,
      id: idMap[node.id],
    }))

    const newEdges = edges.map(edge => ({
      ...edge,
      source: idMap[edge.source],
      target: idMap[edge.target],
    }))

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /**
   * Detects a Mac by looking for "MAC" (case-insensitive) in the user agent string.
   *
   * @returns `true` when the user agent contains "MAC"; `false` otherwise, including
   *          when no `navigator` or user agent is available (e.g. outside a browser)
   */
  export const isMac = () => {
    // `navigator` does not exist in every runtime; treat its absence as "not a Mac" instead of throwing.
    if (typeof navigator === 'undefined' || !navigator)
      return false

    return (navigator.userAgent || '').toUpperCase().includes('MAC')
  }
  // Symbols shown for modifier keys when running on a Mac.
  const specialKeysNameMap: Record<string, string | undefined> = {
    ctrl: '⌘',
    alt: '⌥',
    shift: '⇧',
  }

  /**
   * Returns the display name of a key for the current system.
   *
   * On a Mac, `ctrl`, `alt` and `shift` are replaced by their symbols; every other
   * key, and every key on other systems, is returned unchanged.
   */
  export const getKeyboardKeyNameBySystem = (key: string) => {
    if (!isMac())
      return key

    return specialKeysNameMap[key] || key
  }
  // Key codes that differ on a Mac.
  const specialKeysCodeMap: Record<string, string | undefined> = {
    ctrl: 'meta',
  }

  /**
   * Returns the key code to use for a key on the current system.
   *
   * On a Mac `ctrl` maps to `meta`; every other key, and every key on other
   * systems, is returned unchanged.
   */
  export const getKeyboardKeyCodeBySystem = (key: string) => {
    if (!isMac())
      return key

    return specialKeysCodeMap[key] || key
  }
  /**
   * Finds the smallest x and the smallest y among the node positions.
   *
   * The two minima are independent and may come from different nodes.
   *
   * @returns `{ x, y }` of the minima; both are `Infinity` for an empty list
   */
  export const getTopLeftNodePosition = (nodes: Node[]) => {
    const topLeft = { x: Infinity, y: Infinity }

    for (const { position } of nodes) {
      if (position.x < topLeft.x)
        topLeft.x = position.x

      if (position.y < topLeft.y)
        topLeft.y = position.y
    }

    return topLeft
  }
  /**
   * Tells whether an event target is a text entry area.
   *
   * @param target element the event was dispatched on
   * @returns `true` for `INPUT` / `TEXTAREA` elements and for elements whose
   *          `contentEditable` is `'true'`; `undefined` otherwise
   */
  export const isEventTargetInputArea = (target: HTMLElement) => {
    const isFormField = target.tagName === 'INPUT' || target.tagName === 'TEXTAREA'

    if (isFormField || target.contentEditable === 'true')
      return true
  }
  /**
   * Converts between a value selector and its `{{#a.b#}}` text form.
   *
   * @param v a selector (array of path segments) or its text form
   * @returns for a string: the segments obtained by stripping a leading `{{#` and
   *          a trailing `#}}` and splitting on `.`; for a selector: the segments
   *          joined with `.` and wrapped in `{{#` … `#}}`
   */
  export const variableTransformer = (v: ValueSelector | string) => {
    if (typeof v !== 'string')
      return `{{#${v.join('.')}#}}`

    const path = v.replace(/^{{#|#}}$/g, '')
    return path.split('.')
  }
  // Helper shapes used by getParallelInfo below.

  /** Entry of the parallel list collected by getParallelInfo. */
  type ParallelInfoItem = {
    parallelNodeId: string
    depth: number
    isBranch?: boolean
  }

  /** Per-node record kept while traversing, keyed by node id in getParallelInfo. */
  type NodeParallelInfo = {
    parallelNodeId: string
    edgeHandleId: string
    depth: number
  }

  /** A node paired with the id of one of its source handles. */
  type NodeHandle = {
    node: Node
    handle: string
  }

  /** Sets of upstream node ids and downstream edge ids tracked per node. */
  type NodeStreamInfo = {
    upstreamNodes: Set<string>
    downstreamEdges: Set<string>
  }
/**
 * Walks the graph from a start node and summarises where it fans out into
 * several edges from a single source handle.
 *
 * The walk is done in passes. Each pass (`traverse`) is a breadth-first walk
 * from one root and appends exactly one `ParallelInfoItem` to the result. A
 * pass may stop early at a node and queue that node as the root of the next pass.
 *
 * @param nodes - All nodes of the graph.
 * @param edges - All edges of the graph.
 * @param parentNodeId - Optional id of a container node. When given, the walk
 * starts at the node whose id equals the container's `data.start_node_id`
 * (the container's data is treated as `IterationNodeType | LoopNodeType`).
 * When omitted, the walk starts at the node whose `data.type` is `BlockEnum.Start`.
 * @returns `parallelList` (one item per pass) and `hasAbnormalEdges` (see the
 * comment where the flag is set).
 * @throws Error `'Parent node not found'` when `parentNodeId` matches no node.
 * @throws Error `'Start node not found'` when no start node can be resolved.
 */
export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => {
  let startNode

  if (parentNodeId) {
    const parentNode = nodes.find(node => node.id === parentNodeId)
    if (!parentNode)
      throw new Error('Parent node not found')

    startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id)
  }
  else {
    startNode = nodes.find(node => node.data.type === BlockEnum.Start)
  }
  if (!startNode)
    throw new Error('Start node not found')

  const parallelList = [] as ParallelInfoItem[]
  // Roots still waiting for a pass; traverse() appends to this when it stops early.
  const nextNodeHandles = [{ node: startNode, handle: 'source' }]
  let hasAbnormalEdges = false

  const traverse = (firstNodeHandle: NodeHandle) => {
    // node id -> ids of fan-out edges carried along the paths that reached the node.
    const nodeEdgesSet = {} as Record<string, Set<string>>
    // Every fan-out edge id recorded during this pass.
    const totalEdgesSet = new Set<string>()
    // Breadth-first work queue for this pass.
    const nextHandles = [firstNodeHandle]
    // node id -> upstream fan-out nodes / downstream fan-out edges recorded so far.
    const streamInfo = {} as Record<string, NodeStreamInfo>
    // Result of this pass; parallelNodeId stays '' if no fan-out is encountered.
    const parallelListItem = {
      parallelNodeId: '',
      depth: 0,
    } as ParallelInfoItem
    // node id -> depth bookkeeping; the root of the pass starts at depth 0.
    const nodeParallelInfoMap = {} as Record<string, NodeParallelInfo>
    nodeParallelInfoMap[firstNodeHandle.node.id] = {
      parallelNodeId: '',
      edgeHandleId: '',
      depth: 0,
    }

    while (nextHandles.length) {
      const currentNodeHandle = nextHandles.shift()!
      const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle
      // Bookkeeping is keyed by node id only; the handle is not part of the key.
      const currentNodeHandleKey = currentNode.id
      // Edges leaving the current node through the current handle, and their targets.
      const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle)
      const connectedEdgesLength = connectedEdges.length
      const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id))
      // NOTE(review): getIncomers is defined outside this block; presumably it
      // returns the nodes that have an edge into currentNode.
      const incomers = getIncomers(currentNode, nodes, edges)

      if (!streamInfo[currentNodeHandleKey]) {
        streamInfo[currentNodeHandleKey] = {
          upstreamNodes: new Set<string>(),
          downstreamEdges: new Set<string>(),
        }
      }

      // Early-stop check: the node already carries fan-out edges and has more
      // than one incomer. Compare the edges it carries with all fan-out edges
      // seen in this pass minus those recorded as downstream of this node.
      // NOTE(review): presumably equality means every open branch has joined here.
      if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) {
        const newSet = new Set<string>()
        for (const item of totalEdgesSet) {
          if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item))
            newSet.add(item)
        }
        if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) {
          // End this pass at the current node and make it the root of a later pass.
          parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
          nextNodeHandles.push({ node: currentNode, handle: currentHandle })
          break
        }
      }

      // Track the largest depth seen among visited nodes.
      if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth)
        parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth

      outgoers.forEach((outgoer) => {
        // Edges whose source is the outgoer, grouped by the handle they leave from.
        const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
        const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
        // Shadows the outer `incomers`; this one is for the outgoer.
        const incomers = getIncomers(outgoer, nodes, edges)

        // Flag the case where the current handle has several targets and this
        // target also has more than one incomer.
        if (outgoers.length > 1 && incomers.length > 1)
          hasAbnormalEdges = true

        // Queue the outgoer once per source handle that has outgoing edges.
        Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
          nextHandles.push({ node: outgoer, handle: sourceHandle })
        })
        // No outgoing edges: still queue it once under the default handle id.
        if (!outgoerConnectedEdges.length)
          nextHandles.push({ node: outgoer, handle: 'source' })

        const outgoerKey = outgoer.id
        if (!nodeEdgesSet[outgoerKey])
          nodeEdgesSet[outgoerKey] = new Set<string>()

        // The outgoer inherits every fan-out edge carried by the current node.
        if (nodeEdgesSet[currentNodeHandleKey]) {
          for (const item of nodeEdgesSet[currentNodeHandleKey])
            nodeEdgesSet[outgoerKey].add(item)
        }

        if (!streamInfo[outgoerKey]) {
          streamInfo[outgoerKey] = {
            upstreamNodes: new Set<string>(),
            downstreamEdges: new Set<string>(),
          }
        }

        // First visit: start from a copy of the current node's bookkeeping.
        if (!nodeParallelInfoMap[outgoer.id]) {
          nodeParallelInfoMap[outgoer.id] = {
            ...nodeParallelInfoMap[currentNode.id],
          }
        }

        if (connectedEdgesLength > 1) {
          // Fan-out: the current handle has more than one edge. Record the
          // (first matching) edge to this outgoer in all bookkeeping sets.
          const edge = connectedEdges.find(edge => edge.target === outgoer.id)!
          nodeEdgesSet[outgoerKey].add(edge.id)
          totalEdgesSet.add(edge.id)

          streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id)
          streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey)

          // The edge is also downstream of every node recorded upstream of the current one.
          for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
            streamInfo[item].downstreamEdges.add(edge.id)

          // The first fan-out node met in this pass names the item.
          if (!parallelListItem.parallelNodeId)
            parallelListItem.parallelNodeId = currentNode.id

          // Crossing a fan-out edge increases depth by one; keep the larger value
          // if the outgoer was already reached another way.
          const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1
          const currentDepth = nodeParallelInfoMap[outgoer.id].depth

          nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth)
        }
        else {
          // Single edge: pass the upstream nodes along and keep the same depth.
          for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
            streamInfo[outgoerKey].upstreamNodes.add(item)

          nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth
        }
      })
    }

    // One item per pass, pushed whether or not a fan-out was found.
    parallelList.push(parallelListItem)
  }

  // Run passes until no pass queues a further root.
  while (nextNodeHandles.length) {
    const nodeHandle = nextNodeHandles.shift()!
    traverse(nodeHandle)
  }

  return {
    parallelList,
    hasAbnormalEdges,
  }
}
  870. export const hasErrorHandleNode = (nodeType?: BlockEnum) => {
  871. return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code
  872. }
  873. export const getEdgeColor = (nodeRunningStatus?: NodeRunningStatus, isFailBranch?: boolean) => {
  874. if (nodeRunningStatus === NodeRunningStatus.Succeeded)
  875. return 'var(--color-workflow-link-line-success-handle)'
  876. if (nodeRunningStatus === NodeRunningStatus.Failed)
  877. return 'var(--color-workflow-link-line-error-handle)'
  878. if (nodeRunningStatus === NodeRunningStatus.Exception)
  879. return 'var(--color-workflow-link-line-failure-handle)'
  880. if (nodeRunningStatus === NodeRunningStatus.Running) {
  881. if (isFailBranch)
  882. return 'var(--color-workflow-link-line-failure-handle)'
  883. return 'var(--color-workflow-link-line-handle)'
  884. }
  885. return 'var(--color-workflow-link-line-normal)'
  886. }
  887. export const isExceptionVariable = (variable: string, nodeType?: BlockEnum) => {
  888. if ((variable === 'error_message' || variable === 'error_type') && hasErrorHandleNode(nodeType))
  889. return true
  890. return false
  891. }
  892. export const hasRetryNode = (nodeType?: BlockEnum) => {
  893. return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code
  894. }
  895. export const getNodeCustomTypeByNodeDataType = (nodeType: BlockEnum) => {
  896. if (nodeType === BlockEnum.LoopEnd)
  897. return CUSTOM_SIMPLE_NODE
  898. }