您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

workflow.ts 13KB


  1. import {
  2. getConnectedEdges,
  3. getIncomers,
  4. getOutgoers,
  5. } from 'reactflow'
  6. import { v4 as uuid4 } from 'uuid'
  7. import {
  8. groupBy,
  9. isEqual,
  10. uniqBy,
  11. } from 'lodash-es'
  12. import type {
  13. ConversationVariable,
  14. Edge,
  15. EnvironmentVariable,
  16. Node,
  17. Var,
  18. } from '../types'
  19. import {
  20. BlockEnum,
  21. } from '../types'
  22. import type { IterationNodeType } from '../nodes/iteration/types'
  23. import type { LoopNodeType } from '../nodes/loop/types'
  24. import { VAR_REGEX_TEXT } from '@/config'
  25. import { formatItem } from '../nodes/_base/components/variable/utils'
  26. import type { StructuredOutput } from '../nodes/llm/types'
  /**
   * Returns true when `nodeType` is one of the block types allowed to run as a
   * single node (the set checked by `canRunBySingle`); false for any other type.
   */
  export const canRunBySingle = (nodeType: BlockEnum) => {
    const singleRunnableTypes: BlockEnum[] = [
      BlockEnum.LLM,
      BlockEnum.KnowledgeRetrieval,
      BlockEnum.Code,
      BlockEnum.TemplateTransform,
      BlockEnum.QuestionClassifier,
      BlockEnum.HttpRequest,
      BlockEnum.Tool,
      BlockEnum.ParameterExtractor,
      BlockEnum.Iteration,
      BlockEnum.Agent,
      BlockEnum.DocExtractor,
      BlockEnum.Loop,
    ]
    return singleRunnableTypes.includes(nodeType)
  }
  /**
   * A batch of edge changes: each entry pairs a change `type` with the edge it
   * affects. `getNodesConnectedSourceOrTargetHandleIdsMap` acts on the
   * 'add' and 'remove' types.
   */
  type ConnectedSourceOrTargetNodesChange = Array<{
    type: string
    edge: Edge
  }>
  /** Working copy of one node's connected handle id lists. */
  type ConnectedHandleIds = {
    _connectedSourceHandleIds: string[]
    _connectedTargetHandleIds: string[]
  }

  /**
   * Replays a batch of edge changes against the nodes' current
   * `_connectedSourceHandleIds` / `_connectedTargetHandleIds` lists and returns
   * the updated lists keyed by node id. Only nodes that are an endpoint of at
   * least one change appear in the result. The input nodes are not mutated:
   * each entry starts from a copy of the node's existing lists.
   *
   * A missing handle id is treated as the default handle (`'source'` on the
   * edge's source node, `'target'` on its target node) for both 'add' and
   * 'remove', so adding and then removing the same edge cancels out. Removing
   * a handle id that is not present is a no-op.
   *
   * @param changes edge changes applied in order; change types other than
   *   'add' and 'remove' only cause the endpoints to be included in the result.
   * @param nodes nodes to resolve edge endpoints against; an endpoint that is
   *   not found is skipped.
   * @returns map from node id to its updated connected handle id lists.
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const handleIdsMap: Record<string, ConnectedHandleIds> = {}

    // Lazily seed an entry from the node's current lists (copied so the node itself is untouched).
    const getEntry = (node: Node): ConnectedHandleIds => {
      let entry = handleIdsMap[node.id]
      if (!entry) {
        entry = {
          _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
          _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
        }
        handleIdsMap[node.id] = entry
      }
      return entry
    }

    // Add or remove a single handle id according to the change type.
    const applyChange = (handleIds: string[], handleId: string, type: string) => {
      if (type === 'remove') {
        const index = handleIds.indexOf(handleId)
        // indexOf yields -1 when absent; splice(-1, 1) would silently drop the last id instead.
        if (index > -1)
          handleIds.splice(index, 1)
      }
      if (type === 'add')
        handleIds.push(handleId)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)
      // Seed both endpoints before applying the change; a self-loop shares one entry.
      const sourceEntry = sourceNode && getEntry(sourceNode)
      const targetEntry = targetNode && getEntry(targetNode)
      // Use the same defaulted handle id for add and remove so they stay symmetric.
      if (sourceEntry)
        applyChange(sourceEntry._connectedSourceHandleIds, edge.sourceHandle || 'source', type)
      if (targetEntry)
        applyChange(targetEntry._connectedTargetHandleIds, edge.targetHandle || 'target', type)
    })

    return handleIdsMap
  }
  /**
   * Recursively records every nested child of `item` in `varMap`, keyed by its
   * dotted path (`${path}.${child.variable}`, deeper levels appended the same way).
   * Items with no children, an empty child list, or structured-output children
   * (an object carrying a `schema`) are not walked. `item` itself is not added;
   * `varMap` is mutated in place.
   */
  function getParentOutputVarMap(item: Var, path: string, varMap: Record<string, Var>) {
    const children = item.children
    const isLeaf = !children
      || (Array.isArray(children) && children.length === 0)
      || Boolean((children as StructuredOutput).schema)
    if (isLeaf)
      return
    for (const child of children as Var[]) {
      const childPath = `${path}.${child.variable}`
      varMap[childPath] = child
      getParentOutputVarMap(child, childPath, varMap)
    }
  }
  /**
   * Collects the nodes reachable from the Start node by following outgoing
   * edges (via `getOutgoers`), plus the child nodes (matched by `parentId`) of
   * every reached Iteration or Loop container, and measures the longest path
   * walked (the Start node counts as depth 1).
   *
   * When `isCollectVar` is true, every reached node is given a
   * `_parentOutputVarMap`: the map of the node it was reached from, extended
   * with that node's output variables as returned by `formatItem` (keyed
   * `${nodeId}.${variable}`, nested children flattened with dotted paths,
   * `sys.` variables skipped). The node objects passed in are mutated.
   *
   * There is no visited set: a node reachable along several paths is walked
   * once per path, and the map from the last path walked wins.
   * NOTE(review): a cycle in `edges` would recurse without bound.
   *
   * @returns `validNodes` (deduplicated by id, first occurrence kept) and
   *   `maxDepth`; `[]` and 0 when there is no Start node.
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[], isCollectVar?: boolean) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    let maxDepth = 1

    // Iteration and Loop containers bring their child nodes along with them.
    const pushContainerChildren = (container: Node) => {
      if (container.data.type === BlockEnum.Iteration || container.data.type === BlockEnum.Loop)
        list.push(...nodes.filter(node => node.parentId === container.id))
    }

    // Output variables of `root`, flattened into `${nodeId}.${path}` keys (sys.* skipped).
    const collectOutputVars = (root: Node) => {
      const nodeObj = formatItem(root, false, () => true)
      const varMap = {} as Record<string, Var>
      nodeObj.vars.forEach((item) => {
        if (item.variable.startsWith('sys.'))
          return
        const newPath = `${nodeObj.nodeId}.${item.variable}`
        varMap[newPath] = item
        getParentOutputVarMap(item, newPath, varMap)
      })
      return varMap
    }

    const traverse = (root: Node, depth: number) => {
      if (depth > maxDepth)
        maxDepth = depth
      const outgoers = getOutgoers(root, nodes, edges)
      // A leaf and its container children are already in `list` from when it was reached.
      if (!outgoers.length)
        return
      // Loop-invariant: root's outputs are identical for every outgoer, so build them once per visit.
      const rootVarMap = isCollectVar ? collectOutputVars(root) : undefined
      outgoers.forEach((outgoer) => {
        list.push(outgoer)
        if (rootVarMap)
          outgoer._parentOutputVarMap = { ...(root._parentOutputVarMap ?? {}), ...rootVarMap }
        pushContainerChildren(outgoer)
        traverse(outgoer, depth + 1)
      })
    }

    traverse(startNode, maxDepth)
    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Returns copies of `nodes` and `edges` in which every node receives a fresh
   * id from `uuid4()` and references to renamed nodes are rewritten: edge
   * `source`/`target`, a node's `parentId`, and an Iteration/Loop container's
   * `data.start_node_id`. A reference to a node that is not part of `nodes`
   * keeps its original id (instead of becoming `undefined`).
   * The input arrays and objects are not mutated.
   *
   * NOTE(review): edge `id`s are copied unchanged, and id references held in
   * other `data` fields (if any) are not rewritten — confirm callers don't need that.
   *
   * @returns a `[newNodes, newEdges]` tuple.
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    // Old id -> new id. With duplicate node ids the last generated id wins.
    const idMap = new Map<string, string>(nodes.map((node): [string, string] => [node.id, uuid4()]))
    const remapId = (id: string) => idMap.get(id) ?? id

    const newNodes = nodes.map((node) => {
      const newNode: Node = {
        ...node,
        id: remapId(node.id),
      }
      // Children of a copied container must point at the container's new id.
      if (node.parentId)
        newNode.parentId = remapId(node.parentId)
      // A container references its start node by id from its data.
      const startNodeId = (node.data as (IterationNodeType | LoopNodeType)).start_node_id
      if (startNodeId && idMap.has(startNodeId)) {
        const data = { ...node.data, start_node_id: remapId(startNodeId) }
        newNode.data = data
      }
      return newNode
    })

    const newEdges = edges.map(edge => ({
      ...edge,
      source: remapId(edge.source),
      target: remapId(edge.target),
    }))

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /**
   * Summary of one traversal segment in `getParallelInfo`: `parallelNodeId` is
   * the first node in the segment whose handle has more than one outgoing edge
   * ('' if none), `depth` the deepest parallel nesting reached.
   * `isBranch` is not assigned anywhere in this file.
   */
  type ParallelInfoItem = { parallelNodeId: string, depth: number, isBranch?: boolean }

  /**
   * Per-node bookkeeping inside one `getParallelInfo` segment; `depth` grows by
   * one across each fan-out edge. `parallelNodeId`/`edgeHandleId` are only ever
   * initialised to '' or copied in this file.
   */
  type NodeParallelInfo = { parallelNodeId: string, edgeHandleId: string, depth: number }

  /** A node together with the source handle to continue traversal from. */
  type NodeHandle = { node: Node, handle: string }

  /**
   * Fan-out bookkeeping for one node: ids of fan-out nodes upstream of it, and
   * ids of fan-out edges recorded downstream of it.
   */
  type NodeStreamInfo = { upstreamNodes: Set<string>, downstreamEdges: Set<string> }
  /**
   * Splits a workflow graph into sequential segments and reports, for each
   * segment, the first node where it fans out and how deeply parallel
   * branches nest inside it.
   *
   * Traversal starts at the Start node, or — when `parentNodeId` is given — at
   * the node named by that Iteration/Loop container's `data.start_node_id`.
   * Each segment is walked breadth-first, following only the edges that leave
   * a node through the handle being processed. A segment ends at a node with
   * more than one incomer once the set of fan-out edges that reached it equals
   * every fan-out edge seen in the segment minus those recorded downstream of
   * it; that node is queued as the entry point of the next segment.
   *
   * There is no visited set, so a node may be dequeued more than once within a
   * segment (once per edge that reaches it).
   *
   * @param nodes all nodes of the graph.
   * @param edges all edges of the graph.
   * @param parentNodeId id of an Iteration/Loop container whose sub-graph to analyse.
   * @returns `parallelList` (one item per segment) and `hasAbnormalEdges`, true
   *   when a node that is one of several targets of a single handle also has
   *   more than one incomer.
   * @throws Error('Parent node not found') when `parentNodeId` matches no node.
   * @throws Error('Start node not found') when no start node can be resolved.
   */
  export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => {
    // Resolve the traversal root: the container's declared start node, or the workflow Start node.
    let startNode
    if (parentNodeId) {
      const parentNode = nodes.find(node => node.id === parentNodeId)
      if (!parentNode)
        throw new Error('Parent node not found')
      startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id)
    }
    else {
      startNode = nodes.find(node => node.data.type === BlockEnum.Start)
    }
    if (!startNode)
      throw new Error('Start node not found')
    const parallelList = [] as ParallelInfoItem[]
    // Queue of segment entry points; `traverse` appends the node at which a segment closes.
    const nextNodeHandles = [{ node: startNode, handle: 'source' }]
    let hasAbnormalEdges = false
    // Walks one segment breadth-first from `firstNodeHandle` and appends its summary to `parallelList`.
    const traverse = (firstNodeHandle: NodeHandle) => {
      // Per node id: ids of the fan-out edges on the path(s) that reached it.
      const nodeEdgesSet = {} as Record<string, Set<string>>
      // Every fan-out edge id seen in this segment.
      const totalEdgesSet = new Set<string>()
      const nextHandles = [firstNodeHandle]
      const streamInfo = {} as Record<string, NodeStreamInfo>
      const parallelListItem = {
        parallelNodeId: '',
        depth: 0,
      } as ParallelInfoItem
      // Per node id: parallel nesting depth, 0 for the segment's first node.
      const nodeParallelInfoMap = {} as Record<string, NodeParallelInfo>
      nodeParallelInfoMap[firstNodeHandle.node.id] = {
        parallelNodeId: '',
        edgeHandleId: '',
        depth: 0,
      }
      while (nextHandles.length) {
        const currentNodeHandle = nextHandles.shift()!
        const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle
        // Keyed by node id only, despite the name; the handle is not part of the key.
        const currentNodeHandleKey = currentNode.id
        // Only the edges leaving through the handle being processed.
        const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle)
        const connectedEdgesLength = connectedEdges.length
        const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id))
        const incomers = getIncomers(currentNode, nodes, edges)
        if (!streamInfo[currentNodeHandleKey]) {
          streamInfo[currentNodeHandleKey] = {
            upstreamNodes: new Set<string>(),
            downstreamEdges: new Set<string>(),
          }
        }
        // Merge check: a node with several incomers closes the segment once the fan-out
        // edges that reached it equal all fan-out edges seen so far, excluding those
        // recorded as downstream of it.
        if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) {
          const newSet = new Set<string>()
          for (const item of totalEdgesSet) {
            if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item))
              newSet.add(item)
          }
          if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) {
            // The segment's depth becomes the merge node's depth (assigned, not maxed).
            parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
            // The merge node is the entry point of the next segment.
            nextNodeHandles.push({ node: currentNode, handle: currentHandle })
            break
          }
        }
        if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth)
          parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
        outgoers.forEach((outgoer) => {
          const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
          const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
          // Shadows the outer `incomers`: these are the outgoer's incomers.
          const incomers = getIncomers(outgoer, nodes, edges)
          // One of several targets of this handle that also has other incoming edges.
          if (outgoers.length > 1 && incomers.length > 1)
            hasAbnormalEdges = true
          // Continue from each distinct source handle of the outgoer, or 'source' if it has no outgoing edges.
          Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
            nextHandles.push({ node: outgoer, handle: sourceHandle })
          })
          if (!outgoerConnectedEdges.length)
            nextHandles.push({ node: outgoer, handle: 'source' })
          const outgoerKey = outgoer.id
          if (!nodeEdgesSet[outgoerKey])
            nodeEdgesSet[outgoerKey] = new Set<string>()
          // The outgoer inherits the fan-out edges that reached the current node.
          if (nodeEdgesSet[currentNodeHandleKey]) {
            for (const item of nodeEdgesSet[currentNodeHandleKey])
              nodeEdgesSet[outgoerKey].add(item)
          }
          if (!streamInfo[outgoerKey]) {
            streamInfo[outgoerKey] = {
              upstreamNodes: new Set<string>(),
              downstreamEdges: new Set<string>(),
            }
          }
          if (!nodeParallelInfoMap[outgoer.id]) {
            nodeParallelInfoMap[outgoer.id] = {
              ...nodeParallelInfoMap[currentNode.id],
            }
          }
          if (connectedEdgesLength > 1) {
            // Fan-out: record this branch edge and place the outgoer at least one level deeper.
            const edge = connectedEdges.find(edge => edge.target === outgoer.id)!
            nodeEdgesSet[outgoerKey].add(edge.id)
            totalEdgesSet.add(edge.id)
            streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id)
            streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey)
            // Fan-out nodes upstream of the current node also see this edge as downstream.
            for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
              streamInfo[item].downstreamEdges.add(edge.id)
            if (!parallelListItem.parallelNodeId)
              parallelListItem.parallelNodeId = currentNode.id
            const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1
            const currentDepth = nodeParallelInfoMap[outgoer.id].depth
            nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth)
          }
          else {
            // Single edge: pass the upstream fan-out nodes through and copy the current depth.
            for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
              streamInfo[outgoerKey].upstreamNodes.add(item)
            nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth
          }
        })
      }
      parallelList.push(parallelListItem)
    }
    while (nextNodeHandles.length) {
      const nodeHandle = nextNodeHandles.shift()!
      traverse(nodeHandle)
    }
    return {
      parallelList,
      hasAbnormalEdges,
    }
  }
  /**
   * Returns true for the block types this check treats as having error
   * handling (LLM, Tool, HTTP Request, Code); false for any other type or
   * when `nodeType` is omitted.
   */
  export const hasErrorHandleNode = (nodeType?: BlockEnum) => {
    switch (nodeType) {
      case BlockEnum.LLM:
      case BlockEnum.Tool:
      case BlockEnum.HttpRequest:
      case BlockEnum.Code:
        return true
      default:
        return false
    }
  }
  /**
   * Builds a lookup of conversation and environment variables keyed as
   * `conversation.<name>` and `env.<name>`. When two variables produce the same
   * key, the later one wins while the key keeps its first position.
   */
  export const transformStartNodeVariables = (chatVarList: ConversationVariable[], environmentVariables: EnvironmentVariable[]) => {
    const entries: [string, ConversationVariable | EnvironmentVariable][] = [
      ...chatVarList.map((chatVar): [string, ConversationVariable] => [`conversation.${chatVar.name}`, chatVar]),
      ...environmentVariables.map((envVar): [string, EnvironmentVariable] => [`env.${envVar.name}`, envVar]),
    ]
    const variablesMap: Record<string, ConversationVariable | EnvironmentVariable> = Object.fromEntries(entries)
    return variablesMap
  }
  /**
   * Scans `text` for references matched by `VAR_REGEX_TEXT` and returns, for
   * each one whose captured id is not a key of `varMap` and does not start
   * with `sys.`, that id with its first dot-separated segment removed.
   * One entry per missing occurrence, in text order; a nullish `text` yields [].
   */
  export const getNotExistVariablesByText = (text: string, varMap: Record<string, Var>) => {
    const missingVars: string[] = []
    // `replace` is used only to visit every match; its return value is discarded.
    text?.replace(VAR_REGEX_TEXT, (match, idName) => {
      const isKnown = idName.startsWith('sys.') || Boolean(varMap[idName])
      if (!isKnown)
        missingVars.push(idName.split('.').slice(1).join('.'))
      return match
    })
    return missingVars
  }
  /**
   * Selector-array counterpart of `getNotExistVariablesByText`: returns, for
   * each non-empty selector whose first segment is not `sys` and whose
   * dot-joined form is not a key of `varMap`, the selector without its first
   * segment, joined with dots. Empty selectors are ignored.
   */
  export const getNotExistVariablesByArray = (array: string[][], varMap: Record<string, Var>) => {
    return array
      .filter(selector => selector.length > 0 && selector[0] !== 'sys' && !varMap[selector.join('.')])
      .map(selector => selector.slice(1).join('.'))
  }