You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

workflow.ts 10KB


  1. import {
  2. getConnectedEdges,
  3. getIncomers,
  4. getOutgoers,
  5. } from 'reactflow'
  6. import { v4 as uuid4 } from 'uuid'
  7. import {
  8. groupBy,
  9. isEqual,
  10. uniqBy,
  11. } from 'lodash-es'
  12. import type {
  13. Edge,
  14. Node,
  15. } from '../types'
  16. import {
  17. BlockEnum,
  18. } from '../types'
  /**
   * Whether a node of the given type can be run on its own (single-node run).
   *
   * @param nodeType - Type of the node to check.
   * @param isChildNode - True when the node sits inside an iteration or loop.
   * @returns `true` for the node types listed below. An Assigner qualifies only when it is
   *   not a child node. Every other type returns `false`.
   */
  export const canRunBySingle = (nodeType: BlockEnum, isChildNode: boolean) => {
    switch (nodeType) {
      case BlockEnum.Assigner:
        // Inside an iteration or loop, assigning to the enclosing iteration/loop variable in
        // isolation may cause a "variable does not exist" problem in the backend.
        return !isChildNode
      case BlockEnum.LLM:
      case BlockEnum.KnowledgeRetrieval:
      case BlockEnum.Code:
      case BlockEnum.TemplateTransform:
      case BlockEnum.QuestionClassifier:
      case BlockEnum.HttpRequest:
      case BlockEnum.Tool:
      case BlockEnum.ParameterExtractor:
      case BlockEnum.Iteration:
      case BlockEnum.Agent:
      case BlockEnum.DocExtractor:
      case BlockEnum.Loop:
      case BlockEnum.Start:
      case BlockEnum.IfElse:
      case BlockEnum.VariableAggregator:
      case BlockEnum.DataSource:
        return true
      default:
        return false
    }
  }
  /**
   * Returns `true` only for DataSource nodes. Judging by the name, this is the set of node
   * types that use a custom form for a single run.
   */
  export const isSupportCustomRunForm = (nodeType: BlockEnum): boolean => BlockEnum.DataSource === nodeType
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]

  /**
   * Replays a batch of edge changes onto the connected-handle bookkeeping of the nodes those
   * edges touch, without mutating `nodes`.
   *
   * Every node that is the source or target of a changed edge gets an entry. The entry holds
   * copies of the node's `_connectedSourceHandleIds` / `_connectedTargetHandleIds`, with the
   * changes applied in order:
   * - `'add'` appends the edge's handle id.
   * - `'remove'` deletes one occurrence of it; a handle id that is not present is ignored.
   *
   * A missing `sourceHandle` / `targetHandle` is treated as `'source'` / `'target'` for both
   * operations. Changes of any other `type` only seed the entries. An edge endpoint that is
   * not found in `nodes` is skipped.
   *
   * @param changes - Edge changes to apply, in order.
   * @param nodes - Current nodes; read only.
   * @returns Map from node id to `{ _connectedSourceHandleIds, _connectedTargetHandleIds }`.
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>

    // Seed a node's entry on first touch with copies of its current handle ids.
    const ensureEntry = (node: Node) => {
      if (!nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
        nodesConnectedSourceOrTargetHandleIdsMap[node.id] = {
          _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
          _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
        }
      }
      return nodesConnectedSourceOrTargetHandleIdsMap[node.id]
    }

    // Delete one occurrence of `handleId`. If it is absent, leave the list alone: splicing at
    // index -1 would otherwise drop the last, unrelated handle id.
    const removeOne = (handleIds: string[], handleId: string) => {
      const index = handleIds.indexOf(handleId)
      if (index > -1)
        handleIds.splice(index, 1)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)
      const sourceEntry = sourceNode ? ensureEntry(sourceNode) : undefined
      const targetEntry = targetNode ? ensureEntry(targetNode) : undefined

      // Add and remove must agree on the default handle id. Otherwise a removal can never
      // find what an add with a missing handle pushed.
      const sourceHandle = edge.sourceHandle || 'source'
      const targetHandle = edge.targetHandle || 'target'

      if (sourceEntry) {
        if (type === 'remove')
          removeOne(sourceEntry._connectedSourceHandleIds, sourceHandle)
        if (type === 'add')
          sourceEntry._connectedSourceHandleIds.push(sourceHandle)
      }
      if (targetEntry) {
        if (type === 'remove')
          removeOne(targetEntry._connectedTargetHandleIds, targetHandle)
        if (type === 'add')
          targetEntry._connectedTargetHandleIds.push(targetHandle)
      }
    })

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Collects every node reachable from `startNode` by following edges forward. It also
   * collects the child nodes (matched by `parentId`) of any reachable Iteration or Loop
   * container, and measures the longest edge path from the start.
   *
   * @param startNode - Root of the walk. When falsy, `{ validNodes: [], maxDepth: 0 }` is returned.
   * @param nodes - All nodes of the graph.
   * @param edges - All edges of the graph.
   * @returns An object with two fields:
   *   - `validNodes`: reachable nodes plus container children, de-duplicated by `id` in
   *     first-visit order.
   *   - `maxDepth`: number of nodes on the longest path from `startNode`; the start node
   *     alone is depth 1. If the graph has a cycle, the walk still terminates, but
   *     `maxDepth` is then only a lower bound on the longest simple path.
   */
  export const getValidTreeNodes = (startNode: Node, nodes: Node[], edges: Edge[]) => {
    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    let maxDepth = 1
    // Greatest depth at which each node has been expanded so far. Re-expanding a node at a
    // depth that is not greater can find neither new nodes nor a longer path, so it is
    // skipped. Without this, converging branches were re-walked once per path (exponential).
    const expandedDepth = new Map<string, number>()
    // Nodes on the current recursion path. An edge back to one of them closes a cycle,
    // which previously recursed until the stack overflowed.
    const onPath = new Set<string>()

    // Iteration/Loop containers own their child nodes via `parentId`; those count as valid too.
    const pushWithChildren = (node: Node) => {
      list.push(node)
      if (node.data.type === BlockEnum.Iteration || node.data.type === BlockEnum.Loop)
        list.push(...nodes.filter(child => child.parentId === node.id))
    }

    const traverse = (root: Node, depth: number) => {
      if (depth > maxDepth)
        maxDepth = depth
      expandedDepth.set(root.id, depth)
      onPath.add(root.id)

      const outgoers = getOutgoers(root, nodes, edges)
      if (!outgoers.length) {
        pushWithChildren(root)
      }
      else {
        outgoers.forEach((outgoer) => {
          pushWithChildren(outgoer)
          const previousDepth = expandedDepth.get(outgoer.id)
          if (onPath.has(outgoer.id) || (previousDepth !== undefined && previousDepth >= depth + 1))
            return
          traverse(outgoer, depth + 1)
        })
      }

      onPath.delete(root.id)
    }

    traverse(startNode, maxDepth)

    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Returns copies of `nodes` and `edges` in which every node gets a freshly generated id
   * (`uuid4()`, one call per node in input order). Each edge's `source` / `target` is
   * rewritten to the new id of the node it referenced.
   *
   * Edge ids and all other fields are copied unchanged, and the inputs are not mutated.
   * An edge endpoint that does not match any node in `nodes` becomes `undefined`.
   *
   * NOTE(review): `parentId` on child nodes is not remapped here. Confirm callers do not
   * rely on it pointing at the new container id.
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap = {} as Record<string, string>
    for (const node of nodes)
      idMap[node.id] = uuid4()

    const renamedNodes = nodes.map(node => ({ ...node, id: idMap[node.id] }))
    const rewiredEdges = edges.map(edge => ({
      ...edge,
      source: idMap[edge.source],
      target: idMap[edge.target],
    }))

    return [renamedNodes, rewiredEdges] as [Node[], Edge[]]
  }
  interface ParallelInfoItem {
    parallelNodeId: string
    depth: number
    // Declared but never set in this file.
    isBranch?: boolean
  }

  interface NodeParallelInfo {
    // `parallelNodeId` and `edgeHandleId` are only ever initialised to '' here; `depth` is what is used.
    parallelNodeId: string
    edgeHandleId: string
    depth: number
  }

  interface NodeHandle {
    node: Node
    handle: string
  }

  interface NodeStreamInfo {
    // Ids of fan-out nodes whose branches lead to this node.
    upstreamNodes: Set<string>
    // Ids of fan-out edges recorded as leaving this node or one of its downstream branches.
    downstreamEdges: Set<string>
  }

  /**
   * Walks the graph from `startNode` in segments and summarises how branches fan out.
   *
   * Within a segment, nodes are visited in queue (FIFO) order, one entry per outgoing source
   * handle. A node "forks" when a single source handle has more than one edge; each forked
   * branch is one level deeper than the fork. A segment stops at a node that has several
   * incomers once the fork edges accumulated on the paths into it equal every fork edge
   * recorded so far in the segment, minus those already recorded as downstream of that node.
   * The walk then resumes from that node as a new segment.
   *
   * NOTE(review): nothing tracks visited nodes. A graph with a cycle that never meets the
   * stop condition above looks like it would keep the inner loop running indefinitely.
   * Confirm whether cycles are rejected before this is called.
   *
   * @param startNode - First node to walk from.
   * @param nodes - All nodes of the graph.
   * @param edges - All edges of the graph.
   * @returns An object with two fields:
   *   - `parallelList`: one `{ parallelNodeId, depth }` per segment. `parallelNodeId` is the
   *     first forking node seen ('' if none); `depth` is the deepest fork nesting reached.
   *   - `hasAbnormalEdges`: true when a node reached through a source handle that has
   *     several target nodes also has more than one incoming neighbour.
   * @throws Error `'Start node not found'` when `startNode` is falsy.
   */
  export const getParallelInfo = (startNode: Node, nodes: Node[], edges: Edge[]) => {
    if (!startNode)
      throw new Error('Start node not found')

    const parallelList: ParallelInfoItem[] = []
    const pendingSegments: NodeHandle[] = [{ node: startNode, handle: 'source' }]
    let hasAbnormalEdges = false

    const createStreamInfo = (): NodeStreamInfo => ({
      upstreamNodes: new Set<string>(),
      downstreamEdges: new Set<string>(),
    })

    const walkSegment = (entry: NodeHandle) => {
      // Fork edge ids accumulated along the paths into each node.
      const edgeIdsByNode: Record<string, Set<string>> = {}
      // Every fork edge id recorded in this segment.
      const allForkEdgeIds = new Set<string>()
      const queue: NodeHandle[] = [entry]
      const streams: Record<string, NodeStreamInfo> = {}
      const segment: ParallelInfoItem = { parallelNodeId: '', depth: 0 }
      const depthInfo: Record<string, NodeParallelInfo> = {}
      depthInfo[entry.node.id] = { parallelNodeId: '', edgeHandleId: '', depth: 0 }

      while (queue.length) {
        const { node: current, handle: sourceHandle = 'source' } = queue.shift()!
        const currentId = current.id
        const handleEdges = edges.filter(edge => edge.source === currentId && edge.sourceHandle === sourceHandle)
        const isFork = handleEdges.length > 1
        const targets = nodes.filter(node => handleEdges.some(edge => edge.target === node.id))
        const currentIncomers = getIncomers(current, nodes, edges)

        if (!streams[currentId])
          streams[currentId] = createStreamInfo()

        // Merge point: stop this segment once every outstanding fork edge has arrived here.
        const arrivedEdgeIds = edgeIdsByNode[currentId]
        if (arrivedEdgeIds !== undefined && arrivedEdgeIds.size > 0 && currentIncomers.length > 1) {
          const outstanding = new Set<string>()
          for (const edgeId of allForkEdgeIds) {
            if (!streams[currentId].downstreamEdges.has(edgeId))
              outstanding.add(edgeId)
          }
          if (isEqual(arrivedEdgeIds, outstanding)) {
            segment.depth = depthInfo[currentId].depth
            pendingSegments.push({ node: current, handle: sourceHandle })
            break
          }
        }

        if (depthInfo[currentId].depth > segment.depth)
          segment.depth = depthInfo[currentId].depth

        for (const target of targets) {
          const targetId = target.id
          const targetOutEdges = getConnectedEdges([target], edges).filter(edge => edge.source === targetId)
          const targetEdgesByHandle = groupBy(targetOutEdges, 'sourceHandle')
          const targetIncomers = getIncomers(target, nodes, edges)
          if (targets.length > 1 && targetIncomers.length > 1)
            hasAbnormalEdges = true

          for (const outHandle of Object.keys(targetEdgesByHandle))
            queue.push({ node: target, handle: outHandle })
          if (!targetOutEdges.length)
            queue.push({ node: target, handle: 'source' })

          if (!edgeIdsByNode[targetId])
            edgeIdsByNode[targetId] = new Set<string>()
          // Re-read on purpose: with a self-loop the current node's set may have just been created.
          const inheritedEdgeIds = edgeIdsByNode[currentId]
          if (inheritedEdgeIds) {
            for (const edgeId of inheritedEdgeIds)
              edgeIdsByNode[targetId].add(edgeId)
          }

          if (!streams[targetId])
            streams[targetId] = createStreamInfo()

          if (!depthInfo[targetId])
            depthInfo[targetId] = { ...depthInfo[currentId] }

          if (isFork) {
            const forkEdge = handleEdges.find(edge => edge.target === targetId)!
            edgeIdsByNode[targetId].add(forkEdge.id)
            allForkEdgeIds.add(forkEdge.id)
            streams[currentId].downstreamEdges.add(forkEdge.id)
            streams[targetId].upstreamNodes.add(currentId)
            for (const upstreamId of streams[currentId].upstreamNodes)
              streams[upstreamId].downstreamEdges.add(forkEdge.id)
            if (!segment.parallelNodeId)
              segment.parallelNodeId = currentId
            depthInfo[targetId].depth = Math.max(depthInfo[currentId].depth + 1, depthInfo[targetId].depth)
          }
          else {
            for (const upstreamId of streams[currentId].upstreamNodes)
              streams[targetId].upstreamNodes.add(upstreamId)
            depthInfo[targetId].depth = depthInfo[currentId].depth
          }
        }
      }

      parallelList.push(segment)
    }

    while (pendingSegments.length)
      walkSegment(pendingSegments.shift()!)

    return {
      parallelList,
      hasAbnormalEdges,
    }
  }
  /**
   * Returns `true` exactly for LLM, Tool, HttpRequest and Code nodes. Judging by the name,
   * these are the node types that expose error-handling options. `undefined` yields `false`.
   */
  export const hasErrorHandleNode = (nodeType?: BlockEnum) => {
    const errorHandlingTypes = [
      BlockEnum.LLM,
      BlockEnum.Tool,
      BlockEnum.HttpRequest,
      BlockEnum.Code,
    ]
    return errorHandlingTypes.some(candidate => candidate === nodeType)
  }