You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.ts 32KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054
  1. import {
  2. Position,
  3. getConnectedEdges,
  4. getIncomers,
  5. getOutgoers,
  6. } from 'reactflow'
  7. import dagre from '@dagrejs/dagre'
  8. import { v4 as uuid4 } from 'uuid'
  9. import {
  10. cloneDeep,
  11. groupBy,
  12. isEqual,
  13. uniqBy,
  14. } from 'lodash-es'
  15. import type {
  16. Edge,
  17. InputVar,
  18. Node,
  19. ToolWithProvider,
  20. ValueSelector,
  21. } from './types'
  22. import {
  23. BlockEnum,
  24. ErrorHandleMode,
  25. NodeRunningStatus,
  26. } from './types'
  27. import {
  28. CUSTOM_NODE,
  29. DEFAULT_RETRY_INTERVAL,
  30. DEFAULT_RETRY_MAX,
  31. ITERATION_CHILDREN_Z_INDEX,
  32. ITERATION_NODE_Z_INDEX,
  33. LOOP_CHILDREN_Z_INDEX,
  34. LOOP_NODE_Z_INDEX,
  35. NODE_LAYOUT_HORIZONTAL_PADDING,
  36. NODE_LAYOUT_MIN_DISTANCE,
  37. NODE_LAYOUT_VERTICAL_PADDING,
  38. NODE_WIDTH_X_OFFSET,
  39. START_INITIAL_POSITION,
  40. } from './constants'
  41. import { CUSTOM_ITERATION_START_NODE } from './nodes/iteration-start/constants'
  42. import { CUSTOM_LOOP_START_NODE } from './nodes/loop-start/constants'
  43. import type { QuestionClassifierNodeType } from './nodes/question-classifier/types'
  44. import type { IfElseNodeType } from './nodes/if-else/types'
  45. import { branchNameCorrect } from './nodes/if-else/utils'
  46. import type { ToolNodeType } from './nodes/tool/types'
  47. import type { IterationNodeType } from './nodes/iteration/types'
  48. import type { LoopNodeType } from './nodes/loop/types'
  49. import { CollectionType } from '@/app/components/tools/types'
  50. import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
  51. import { canFindTool, correctModelProvider } from '@/utils'
  const WHITE = 'WHITE'
  const GRAY = 'GRAY'
  const BLACK = 'BLACK'
  /**
   * Depth-first search helper that reports whether a cycle is reachable from `nodeId`.
   *
   * Colours: WHITE = unvisited, GRAY = on the current DFS path, BLACK = finished.
   * `stack` holds the current DFS path. When a GRAY child is met, that child is
   * pushed once more and the search stops, leaving the path on `stack` and the
   * visited path nodes GRAY. On a clean finish the node turns BLACK and pops itself.
   *
   * @param nodeId node to expand; `adjList[nodeId]` must exist
   * @param color mutable colour table, keyed by node id
   * @param adjList adjacency list, keyed by node id
   * @param stack mutable DFS path / cycle trace
   * @returns true as soon as a back edge to a GRAY node is found
   */
  const isCyclicUtil = (nodeId: string, color: Record<string, string>, adjList: Record<string, string[]>, stack: string[]) => {
    color[nodeId] = GRAY
    stack.push(nodeId)

    for (const childId of adjList[nodeId]) {
      const childColor = color[childId]
      if (childColor === GRAY) {
        // Back edge: record the re-entered node so the trace closes the loop.
        stack.push(childId)
        return true
      }
      if (childColor === WHITE && isCyclicUtil(childId, color, adjList, stack))
        return true
    }

    color[nodeId] = BLACK
    // Only pop if this node is still on top of the path.
    if (stack[stack.length - 1] === nodeId)
      stack.pop()
    return false
  }
  /**
   * Returns every edge of the graph that lies on a directed cycle.
   *
   * An edge `u -> v` is on a cycle exactly when `u` and `v` belong to the same
   * strongly connected component (a self-loop `u -> u` counts). Components are
   * computed with an iterative Tarjan traversal, so long chains cannot overflow
   * the call stack. Edges whose source or target is not in `nodes` are never
   * reported. The returned edges keep their order from `edges`.
   *
   * This replaces a DFS that collected the whole DFS path as "cycle nodes". That
   * approach wrongly reported edges leading into a cycle (e.g. `A -> B` for
   * `A -> B -> C -> B`) and left aborted nodes GRAY, so later searches flagged
   * acyclic edges.
   *
   * @param nodes graph nodes (only `id` is read)
   * @param edges graph edges (only `source` / `target` are read)
   * @returns the subset of `edges` that lie on some cycle
   */
  const getCycleEdges = (nodes: Node[], edges: Edge[]) => {
    // Maps (not plain objects) so ids such as "constructor" cannot hit the prototype.
    const adjList = new Map<string, string[]>()
    for (const node of nodes)
      adjList.set(node.id, [])
    for (const edge of edges) {
      if (adjList.has(edge.target))
        adjList.get(edge.source)?.push(edge.target)
    }

    const discovery = new Map<string, number>()
    const lowLink = new Map<string, number>()
    const componentOf = new Map<string, number>()
    const sccStack: string[] = []
    const onStack = new Set<string>()
    let counter = 0
    let componentCount = 0

    const open = (nodeId: string) => {
      discovery.set(nodeId, counter)
      lowLink.set(nodeId, counter)
      counter++
      sccStack.push(nodeId)
      onStack.add(nodeId)
    }

    for (const root of nodes) {
      if (discovery.has(root.id))
        continue
      open(root.id)
      // Explicit DFS frames: the node being expanded and the next child index to visit.
      const frames: { id: string, next: number }[] = [{ id: root.id, next: 0 }]
      while (frames.length > 0) {
        const frame = frames[frames.length - 1]
        const children = adjList.get(frame.id)!
        if (frame.next < children.length) {
          const childId = children[frame.next++]
          if (!discovery.has(childId)) {
            open(childId)
            frames.push({ id: childId, next: 0 })
          }
          else if (onStack.has(childId)) {
            lowLink.set(frame.id, Math.min(lowLink.get(frame.id)!, discovery.get(childId)!))
          }
          continue
        }

        // All children done: propagate low-link to the parent frame.
        frames.pop()
        const low = lowLink.get(frame.id)!
        if (frames.length > 0) {
          const parentId = frames[frames.length - 1].id
          lowLink.set(parentId, Math.min(lowLink.get(parentId)!, low))
        }
        // Component root: pop every member of this strongly connected component.
        if (low === discovery.get(frame.id)) {
          let memberId: string
          do {
            memberId = sccStack.pop()!
            onStack.delete(memberId)
            componentOf.set(memberId, componentCount)
          } while (memberId !== frame.id)
          componentCount++
        }
      }
    }

    return edges.filter((edge) => {
      const component = componentOf.get(edge.source)
      return component !== undefined && component === componentOf.get(edge.target)
    })
  }
  /**
   * Builds the fixed start node that lives inside an iteration container.
   *
   * @param iterationId id of the parent iteration node; the start node gets id `${iterationId}start`
   * @returns a non-selectable, non-draggable CUSTOM_ITERATION_START_NODE parented to the iteration
   */
  export function getIterationStartNode(iterationId: string): Node {
    const startData = {
      title: '',
      desc: '',
      type: BlockEnum.IterationStart,
      isInIteration: true,
    }
    const { newNode } = generateNewNode({
      id: `${iterationId}start`,
      type: CUSTOM_ITERATION_START_NODE,
      data: startData,
      position: { x: 24, y: 68 },
      zIndex: ITERATION_CHILDREN_Z_INDEX,
      parentId: iterationId,
      selectable: false,
      draggable: false,
    })
    return newNode
  }
  /**
   * Builds the fixed start node that lives inside a loop container.
   *
   * @param loopId id of the parent loop node; the start node gets id `${loopId}start`
   * @returns a non-selectable, non-draggable CUSTOM_LOOP_START_NODE parented to the loop
   */
  export function getLoopStartNode(loopId: string): Node {
    const startData = {
      title: '',
      desc: '',
      type: BlockEnum.LoopStart,
      isInLoop: true,
    }
    const { newNode } = generateNewNode({
      id: `${loopId}start`,
      type: CUSTOM_LOOP_START_NODE,
      data: startData,
      position: { x: 24, y: 68 },
      zIndex: LOOP_CHILDREN_Z_INDEX,
      parentId: loopId,
      selectable: false,
      draggable: false,
    })
    return newNode
  }
  /**
   * Creates a workflow node with default handle positions and type.
   *
   * - `id` defaults to the current timestamp and `type` to CUSTOM_NODE.
   * - Iteration and loop containers get their fixed container z-index. Any other
   *   node keeps the `zIndex` passed in.
   * - Iteration/loop containers also get a matching start node. The container's
   *   `data` (the same object the caller passed) is updated with `start_node_id`
   *   and `_children`.
   *
   * @returns `{ newNode }`, plus `newIterationStartNode` or `newLoopStartNode` for containers
   */
  export function generateNewNode({ data, position, id, zIndex, type, ...rest }: Omit<Node, 'id'> & { id?: string }): {
    newNode: Node
    newIterationStartNode?: Node
    newLoopStartNode?: Node
  } {
    let resolvedZIndex = zIndex
    if (data.type === BlockEnum.Iteration)
      resolvedZIndex = ITERATION_NODE_Z_INDEX
    else if (data.type === BlockEnum.Loop)
      resolvedZIndex = LOOP_NODE_Z_INDEX

    const newNode = {
      id: id || `${Date.now()}`,
      type: type || CUSTOM_NODE,
      data,
      position,
      targetPosition: Position.Left,
      sourcePosition: Position.Right,
      zIndex: resolvedZIndex,
      ...rest,
    } as Node

    switch (data.type) {
      case BlockEnum.Iteration: {
        const newIterationStartNode = getIterationStartNode(newNode.id)
        const iterationData = newNode.data as IterationNodeType
        iterationData.start_node_id = newIterationStartNode.id
        iterationData._children = [newIterationStartNode.id]
        return { newNode, newIterationStartNode }
      }
      case BlockEnum.Loop: {
        const newLoopStartNode = getLoopStartNode(newNode.id)
        const loopData = newNode.data as LoopNodeType
        loopData.start_node_id = newLoopStartNode.id
        loopData._children = [newLoopStartNode.id]
        return { newNode, newLoopStartNode }
      }
      default:
        return { newNode }
    }
  }
  /**
   * Back-fills dedicated start nodes for iteration and loop containers.
   *
   * A container is repaired when:
   * - its `start_node_id` is empty or points at a node that does not exist. A
   *   fresh start node is created and no connecting edge is added.
   * - its `start_node_id` points at a node that is not of the matching custom
   *   start type (legacy data where a regular child acted as the start). A fresh
   *   start node is created and an edge from it to the old start node is added,
   *   so the old entry point stays connected.
   *
   * The `start_node_id` of every repaired container is rewritten in place on the
   * given node objects, so callers should pass copies (`initialNodes` and
   * `initialEdges` pass `cloneDeep` results).
   *
   * @returns the input nodes plus the new start nodes, and the input edges plus the new connecting edges
   */
  export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
    const hasIterationNode = nodes.some(node => node.data.type === BlockEnum.Iteration)
    const hasLoopNode = nodes.some(node => node.data.type === BlockEnum.Loop)

    // Nothing to repair only when neither container kind is present. Bailing out
    // when just one kind was missing skipped iteration-only / loop-only workflows.
    if (!hasIterationNode && !hasLoopNode) {
      return {
        nodes,
        edges,
      }
    }

    const nodesMap = nodes.reduce((prev, next) => {
      prev[next.id] = next
      return prev
    }, {} as Record<string, Node>)

    type ContainerNode = Node<IterationNodeType | LoopNodeType>
    const iterationNodesWithStartNode: ContainerNode[] = []
    const iterationNodesWithoutStartNode: ContainerNode[] = []
    const loopNodesWithStartNode: ContainerNode[] = []
    const loopNodesWithoutStartNode: ContainerNode[] = []

    for (const node of nodes) {
      const currentNode = node as ContainerNode
      const isIteration = currentNode.data.type === BlockEnum.Iteration
      const isLoop = currentNode.data.type === BlockEnum.Loop
      if (!isIteration && !isLoop)
        continue

      const withStart = isIteration ? iterationNodesWithStartNode : loopNodesWithStartNode
      const withoutStart = isIteration ? iterationNodesWithoutStartNode : loopNodesWithoutStartNode
      const expectedStartType = isIteration ? CUSTOM_ITERATION_START_NODE : CUSTOM_LOOP_START_NODE
      const startNodeId = currentNode.data.start_node_id
      const existingStartNode = startNodeId ? nodesMap[startNodeId] : undefined

      if (!existingStartNode) {
        // Missing or dangling id: a new start node is enough (no legacy node to connect to).
        withoutStart.push(currentNode)
      }
      else if (existingStartNode.type !== expectedStartType) {
        // Legacy start: keep it connected behind a new dedicated start node.
        withStart.push(currentNode)
      }
    }

    const newIterationStartNodesMap = {} as Record<string, Node>
    const newIterationStartNodes = [...iterationNodesWithStartNode, ...iterationNodesWithoutStartNode].map((iterationNode, index) => {
      const newNode = getIterationStartNode(iterationNode.id)
      newNode.id = newNode.id + index
      newIterationStartNodesMap[iterationNode.id] = newNode
      return newNode
    })

    const newLoopStartNodesMap = {} as Record<string, Node>
    const newLoopStartNodes = [...loopNodesWithStartNode, ...loopNodesWithoutStartNode].map((loopNode, index) => {
      const newNode = getLoopStartNode(loopNode.id)
      newNode.id = newNode.id + index
      newLoopStartNodesMap[loopNode.id] = newNode
      return newNode
    })

    // Connect each new start node to the legacy start node it replaces.
    const newEdges = [...iterationNodesWithStartNode, ...loopNodesWithStartNode].map((nodeItem) => {
      const isIteration = nodeItem.data.type === BlockEnum.Iteration
      const newNode = (isIteration ? newIterationStartNodesMap : newLoopStartNodesMap)[nodeItem.id]
      // Guaranteed to exist: dangling ids were classified as "without start node" above.
      const startNode = nodesMap[nodeItem.data.start_node_id]
      const source = newNode.id
      const sourceHandle = 'source'
      const target = startNode.id
      const targetHandle = 'target'
      const parentNode = nodes.find(node => node.id === startNode.parentId) || null
      const isInIteration = !!parentNode && parentNode.data.type === BlockEnum.Iteration
      const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop
      return {
        id: `${source}-${sourceHandle}-${target}-${targetHandle}`,
        type: 'custom',
        source,
        sourceHandle,
        target,
        targetHandle,
        data: {
          sourceType: newNode.data.type,
          targetType: startNode.data.type,
          isInIteration,
          iteration_id: isInIteration ? startNode.parentId : undefined,
          isInLoop,
          loop_id: isInLoop ? startNode.parentId : undefined,
          _connectedNodeIsSelected: true,
        },
        zIndex: isIteration ? ITERATION_CHILDREN_Z_INDEX : LOOP_CHILDREN_Z_INDEX,
      }
    })

    // Point every repaired container at its new start node (mutates the input nodes).
    nodes.forEach((node) => {
      if (node.data.type === BlockEnum.Iteration && newIterationStartNodesMap[node.id])
        (node.data as IterationNodeType).start_node_id = newIterationStartNodesMap[node.id].id
      if (node.data.type === BlockEnum.Loop && newLoopStartNodesMap[node.id])
        (node.data as LoopNodeType).start_node_id = newLoopStartNodesMap[node.id].id
    })

    return {
      nodes: [...nodes, ...newIterationStartNodes, ...newLoopStartNodes],
      edges: [...edges, ...newEdges],
    }
  }
  /**
   * Prepares persisted workflow nodes for rendering.
   *
   * Runs `preprocessNodesAndEdges` on deep copies of the inputs, then for every node:
   * - defaults `type` to CUSTOM_NODE;
   * - records the source/target handle ids that have edges attached;
   * - normalises type-specific data: if/else cases and branches, classifier
   *   branches, container children and defaults, legacy model provider names,
   *   and the HTTP request retry config.
   *
   * @returns the normalised node list; the caller's arrays are not mutated
   */
  export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
    const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))

    // Data without coordinates (judged by the first node): spread nodes along one row.
    if (!nodes[0]?.position) {
      for (const [index, node] of nodes.entries()) {
        node.position = {
          x: START_INITIAL_POSITION.x + index * NODE_WIDTH_X_OFFSET,
          y: START_INITIAL_POSITION.y,
        }
      }
    }

    // parentId -> ids of the nodes nested directly inside that container.
    const childIdsByParentId = {} as Record<string, string[]>
    for (const node of nodes) {
      const { parentId } = node
      if (!parentId)
        continue
      if (childIdsByParentId[parentId])
        childIdsByParentId[parentId].push(node.id)
      else
        childIdsByParentId[parentId] = [node.id]
    }

    // legacy provider handle: normalise the provider name stored on `data.model`.
    const correctLegacyModelProvider = (node: Node) => {
      const legacyData = (node as any).data
      legacyData.model.provider = correctModelProvider(legacyData.model.provider)
    }

    return nodes.map((node) => {
      if (!node.type)
        node.type = CUSTOM_NODE

      const sourceHandleIds: string[] = []
      const targetHandleIds: string[] = []
      for (const edge of getConnectedEdges([node], edges)) {
        if (edge.source === node.id)
          sourceHandleIds.push(edge.sourceHandle || 'source')
        if (edge.target === node.id)
          targetHandleIds.push(edge.targetHandle || 'target')
      }
      node.data._connectedSourceHandleIds = sourceHandleIds
      node.data._connectedTargetHandleIds = targetHandleIds

      switch (node.data.type) {
        case BlockEnum.IfElse: {
          const ifElseData = node.data as IfElseNodeType
          // Legacy single-condition shape: wrap it into one `true` case.
          if (!ifElseData.cases && ifElseData.logical_operator && ifElseData.conditions) {
            ifElseData.cases = [
              {
                case_id: 'true',
                logical_operator: ifElseData.logical_operator,
                conditions: ifElseData.conditions,
              },
            ]
          }
          node.data._targetBranches = branchNameCorrect([
            ...ifElseData.cases.map(item => ({ id: item.case_id, name: '' })),
            { id: 'false', name: '' },
          ])
          break
        }
        case BlockEnum.QuestionClassifier:
          node.data._targetBranches = (node.data as QuestionClassifierNodeType).classes.map(topic => topic)
          correctLegacyModelProvider(node)
          break
        case BlockEnum.Iteration: {
          const iterationData = node.data as IterationNodeType
          iterationData._children = childIdsByParentId[node.id] || []
          iterationData.is_parallel = iterationData.is_parallel || false
          iterationData.parallel_nums = iterationData.parallel_nums || 10
          iterationData.error_handle_mode = iterationData.error_handle_mode || ErrorHandleMode.Terminated
          break
        }
        // TODO: loop error handle mode
        case BlockEnum.Loop: {
          const loopData = node.data as LoopNodeType
          loopData._children = childIdsByParentId[node.id] || []
          loopData.error_handle_mode = loopData.error_handle_mode || ErrorHandleMode.Terminated
          break
        }
        case BlockEnum.LLM:
        case BlockEnum.ParameterExtractor:
          correctLegacyModelProvider(node)
          break
        case BlockEnum.KnowledgeRetrieval: {
          const rerankingModel = (node as any).data.multiple_retrieval_config?.reranking_model
          if (rerankingModel)
            rerankingModel.provider = correctModelProvider(rerankingModel.provider)
          break
        }
        case BlockEnum.HttpRequest:
          if (!node.data.retry_config) {
            node.data.retry_config = {
              retry_enabled: true,
              max_retries: DEFAULT_RETRY_MAX,
              retry_interval: DEFAULT_RETRY_INTERVAL,
            }
          }
          break
        default:
          break
      }

      return node
    })
  }
  /**
   * Prepares persisted workflow edges for rendering.
   *
   * Works on deep copies (after `preprocessNodesAndEdges`). It drops edges that
   * `getCycleEdges` reports (matched by source/target pair), forces the `custom`
   * edge type, and fills missing handle ids and endpoint node types. When some
   * node has `data.selected` (the last one wins), it also marks whether each edge
   * touches that node.
   *
   * @returns the normalised edge list; the caller's arrays are not mutated
   */
  export const initialEdges = (originEdges: Edge[], originNodes: Node[]) => {
    const { nodes, edges } = preprocessNodesAndEdges(cloneDeep(originNodes), cloneDeep(originEdges))

    const nodesMap = {} as Record<string, Node>
    let selectedNode: Node | undefined
    for (const node of nodes) {
      nodesMap[node.id] = node
      if (node.data?.selected)
        selectedNode = node
    }

    const cycleEdges = getCycleEdges(nodes, edges)
    const isCycleEdge = (edge: Edge) =>
      cycleEdges.some(cycleEdge => cycleEdge.source === edge.source && cycleEdge.target === edge.target)

    const result: Edge[] = []
    for (const edge of edges) {
      if (isCycleEdge(edge))
        continue

      edge.type = 'custom'
      if (!edge.sourceHandle)
        edge.sourceHandle = 'source'
      if (!edge.targetHandle)
        edge.targetHandle = 'target'

      const sourceNode = edge.source ? nodesMap[edge.source] : undefined
      if (!edge.data?.sourceType && sourceNode) {
        edge.data = {
          ...edge.data,
          sourceType: sourceNode.data.type!,
        } as any
      }

      const targetNode = edge.target ? nodesMap[edge.target] : undefined
      if (!edge.data?.targetType && targetNode) {
        edge.data = {
          ...edge.data,
          targetType: targetNode.data.type!,
        } as any
      }

      if (selectedNode) {
        const selectedId = selectedNode.id
        edge.data = {
          ...edge.data,
          _connectedNodeIsSelected: edge.source === selectedId || edge.target === selectedId,
        } as any
      }

      result.push(edge)
    }
    return result
  }
  /**
   * Runs a left-to-right dagre layout over the top-level canvas.
   *
   * Only nodes without a parent and of type CUSTOM_NODE take part. Edges inside
   * iteration/loop containers are excluded. Node sizes come from the measured
   * `width` / `height`, which are asserted non-null as in the rest of the file.
   *
   * @returns the laid-out dagre graph (read positions with `graph.node(id)`)
   */
  export const getLayoutByDagre = (originNodes: Node[], originEdges: Edge[]) => {
    const topLevelNodes = cloneDeep(originNodes).filter(node => !node.parentId && node.type === CUSTOM_NODE)
    const topLevelEdges = cloneDeep(originEdges).filter(edge => !(edge.data?.isInIteration || edge.data?.isInLoop))

    const graph = new dagre.graphlib.Graph()
    graph.setDefaultEdgeLabel(() => ({}))
    graph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      ranker: 'tight-tree',
      marginx: 30,
      marginy: 200,
    })

    for (const { id, width, height } of topLevelNodes)
      graph.setNode(id, { width: width!, height: height! })
    for (const { source, target } of topLevelEdges)
      graph.setEdge(source, target)

    dagre.layout(graph)
    return graph
  }
  /**
   * Lays out the children of one iteration/loop container.
   *
   * Without a start node, all children get a plain left-to-right dagre layout.
   * With a start node, dagre lays out every other child. The start node is then
   * placed manually:
   * - at x = NODE_LAYOUT_HORIZONTAL_PADDING / 1.5;
   * - vertically at the mean y of the laid-out nodes it points to (100 if there are none).
   * If those first-layer nodes would sit closer than NODE_LAYOUT_MIN_DISTANCE to
   * the start node, all non-start children are shifted right. The start node's
   * outgoing edges are added to the graph after layout.
   *
   * @param parentNodeId id of the iteration/loop container
   * @returns the dagre graph holding the child positions
   */
  export const getLayoutForChildNodes = (parentNodeId: string, originNodes: Node[], originEdges: Edge[]) => {
    const graph = new dagre.graphlib.Graph()
    graph.setDefaultEdgeLabel(() => ({}))

    const childNodes = cloneDeep(originNodes).filter(node => node.parentId === parentNodeId)
    const childEdges = cloneDeep(originEdges).filter((edge) => {
      const belongsToIteration = edge.data?.isInIteration && edge.data?.iteration_id === parentNodeId
      const belongsToLoop = edge.data?.isInLoop && edge.data?.loop_id === parentNodeId
      return belongsToIteration || belongsToLoop
    })

    // Fallback size for children that have not been measured yet.
    const sizeOf = (node: Node) => ({ width: node.width || 244, height: node.height || 100 })

    const startNode = childNodes.find(node =>
      node.type === CUSTOM_ITERATION_START_NODE
      || node.type === CUSTOM_LOOP_START_NODE
      || node.data?.type === BlockEnum.LoopStart
      || node.data?.type === BlockEnum.IterationStart,
    )

    if (!startNode) {
      graph.setGraph({
        rankdir: 'LR',
        align: 'UL',
        nodesep: 40,
        ranksep: 60,
        marginx: NODE_LAYOUT_HORIZONTAL_PADDING,
        marginy: NODE_LAYOUT_VERTICAL_PADDING,
      })
      for (const node of childNodes)
        graph.setNode(node.id, sizeOf(node))
      for (const edge of childEdges)
        graph.setEdge(edge.source, edge.target)
      dagre.layout(graph)
      return graph
    }

    const startNodeId = startNode.id
    const startOutEdges = childEdges.filter(edge => edge.source === startNodeId)
    const firstLayerNodes = startOutEdges
      .map(edge => childNodes.find(node => node.id === edge.target))
      .filter(Boolean) as Node[]
    const bodyNodes = childNodes.filter(node => node.id !== startNodeId)
    const bodyEdges = childEdges.filter(edge => edge.source !== startNodeId && edge.target !== startNodeId)

    // Lay out everything except the start node first.
    graph.setGraph({
      rankdir: 'LR',
      align: 'UL',
      nodesep: 40,
      ranksep: 60,
      marginx: NODE_LAYOUT_HORIZONTAL_PADDING / 2,
      marginy: NODE_LAYOUT_VERTICAL_PADDING / 2,
    })
    for (const node of bodyNodes)
      graph.setNode(node.id, sizeOf(node))
    for (const edge of bodyEdges)
      graph.setEdge(edge.source, edge.target)
    dagre.layout(graph)

    const startWidth = startNode.width || 44
    const startHeight = startNode.height || 48
    const startX = NODE_LAYOUT_HORIZONTAL_PADDING / 1.5
    let startY = 100

    if (firstLayerNodes.length > 0) {
      const laidOut = firstLayerNodes
        .map(node => graph.node(node.id))
        .filter(Boolean)

      // Vertically centre the start node on the nodes it feeds.
      if (laidOut.length > 0)
        startY = laidOut.reduce((sum, pos) => sum + pos.y, 0) / laidOut.length

      const leftmostX = laidOut.reduce((min, pos) => Math.min(min, pos.x - pos.width / 2), Infinity)
      const requiredX = startX + startWidth + NODE_LAYOUT_MIN_DISTANCE
      if (leftmostX < requiredX) {
        const shiftX = requiredX - leftmostX
        for (const node of bodyNodes) {
          const pos = graph.node(node.id)
          if (pos) {
            graph.setNode(node.id, {
              x: pos.x + shiftX,
              y: pos.y,
              width: pos.width,
              height: pos.height,
            })
          }
        }
      }
    }

    graph.setNode(startNodeId, {
      x: startX + startWidth / 2,
      y: startY,
      width: startWidth,
      height: startHeight,
    })
    for (const edge of startOutEdges)
      graph.setEdge(edge.source, edge.target)

    return graph
  }
  /**
   * Whether a node of the given block type supports single-step runs.
   *
   * @param nodeType block type of the node
   * @returns true for the block types listed below, false otherwise
   */
  export const canRunBySingle = (nodeType: BlockEnum) => {
    const singleRunnableTypes: BlockEnum[] = [
      BlockEnum.LLM,
      BlockEnum.KnowledgeRetrieval,
      BlockEnum.Code,
      BlockEnum.TemplateTransform,
      BlockEnum.QuestionClassifier,
      BlockEnum.HttpRequest,
      BlockEnum.Tool,
      BlockEnum.ParameterExtractor,
      BlockEnum.Iteration,
      BlockEnum.Agent,
      BlockEnum.DocExtractor,
      BlockEnum.Loop,
    ]
    return singleRunnableTypes.includes(nodeType)
  }
  type ConnectedSourceOrTargetNodesChange = {
    type: string
    edge: Edge
  }[]
  /**
   * Computes updated connected-handle lists for nodes touched by edge changes.
   *
   * For each change, the source and target nodes (when present in `nodes`) get an
   * entry seeded from copies of their current `_connectedSourceHandleIds` /
   * `_connectedTargetHandleIds`:
   * - `'add'` appends the edge's handle id;
   * - `'remove'` deletes the first matching occurrence.
   * Handle ids default to `'source'` / `'target'` when the edge has none, which is
   * the same default used when handles are recorded. Removing a handle that is
   * not in the list is a no-op. The input nodes are not mutated.
   *
   * @param changes edge additions/removals to apply, in order
   * @param nodes current nodes
   * @returns node id -> `{ _connectedSourceHandleIds, _connectedTargetHandleIds }`
   */
  export const getNodesConnectedSourceOrTargetHandleIdsMap = (changes: ConnectedSourceOrTargetNodesChange, nodes: Node[]) => {
    const nodesConnectedSourceOrTargetHandleIdsMap = {} as Record<string, any>

    // Returns the node's working entry, creating it from a copy of the node's lists on first use.
    const ensureEntry = (node: Node) => {
      nodesConnectedSourceOrTargetHandleIdsMap[node.id] = nodesConnectedSourceOrTargetHandleIdsMap[node.id] || {
        _connectedSourceHandleIds: [...(node.data._connectedSourceHandleIds || [])],
        _connectedTargetHandleIds: [...(node.data._connectedTargetHandleIds || [])],
      }
      return nodesConnectedSourceOrTargetHandleIdsMap[node.id]
    }

    // Removes the first occurrence only; `splice(-1, 1)` would wrongly drop the last id.
    const removeHandleId = (handleIds: string[], handleId: string) => {
      const index = handleIds.indexOf(handleId)
      if (index !== -1)
        handleIds.splice(index, 1)
    }

    changes.forEach(({ edge, type }) => {
      const sourceNode = nodes.find(node => node.id === edge.source)
      const targetNode = nodes.find(node => node.id === edge.target)
      const sourceEntry = sourceNode ? ensureEntry(sourceNode) : undefined
      const targetEntry = targetNode ? ensureEntry(targetNode) : undefined
      const sourceHandleId = edge.sourceHandle || 'source'
      const targetHandleId = edge.targetHandle || 'target'

      if (sourceEntry) {
        if (type === 'remove')
          removeHandleId(sourceEntry._connectedSourceHandleIds, sourceHandleId)
        if (type === 'add')
          sourceEntry._connectedSourceHandleIds.push(sourceHandleId)
      }
      if (targetEntry) {
        if (type === 'remove')
          removeHandleId(targetEntry._connectedTargetHandleIds, targetHandleId)
        if (type === 'add')
          targetEntry._connectedTargetHandleIds.push(targetHandleId)
      }
    })

    return nodesConnectedSourceOrTargetHandleIdsMap
  }
  /**
   * Derives the title for a duplicated node.
   *
   * A trailing "(n)" counter is incremented ("LLM (2)" -> "LLM (3)"). Otherwise
   * " (1)" is appended ("LLM" -> "LLM (1)").
   *
   * @param oldTitle title of the node being copied
   * @returns the new title
   */
  export const genNewNodeTitleFromOld = (oldTitle: string) => {
    const match = /^(.+?)\s*\((\d+)\)\s*$/.exec(oldTitle)
    if (!match)
      return `${oldTitle} (1)`

    const [, baseTitle, counter] = match
    const nextCounter = Number.parseInt(counter, 10) + 1
    return `${baseTitle} (${nextCounter})`
  }
  /**
   * Collects the nodes reachable from the Start node and the workflow depth.
   *
   * Reachability follows outgoing edges. The children of every reached
   * iteration/loop container are included as well. `maxDepth` is the number of
   * nodes on the longest path starting at Start; the Start node alone has
   * depth 1.
   *
   * Each node's subtree is traversed once (results are memoized). The previous
   * version re-walked shared subgraphs once per path, which is exponential for
   * workflows whose branches merge, and recursed forever on a cycle. Edges that
   * close a cycle now contribute no extra depth.
   *
   * @returns `validNodes` (deduplicated by id, first-seen order) and `maxDepth` (0 when there is no Start node)
   */
  export const getValidTreeNodes = (nodes: Node[], edges: Edge[]) => {
    const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
    if (!startNode) {
      return {
        validNodes: [],
        maxDepth: 0,
      }
    }

    const list: Node[] = [startNode]
    const pushContainerChildren = (node: Node) => {
      if (node.data.type === BlockEnum.Iteration || node.data.type === BlockEnum.Loop)
        list.push(...nodes.filter(child => child.parentId === node.id))
    }

    // node id -> number of nodes on the longest path starting at that node.
    const longestPathFrom = new Map<string, number>()
    const inProgress = new Set<string>()

    const traverse = (root: Node): number => {
      const cached = longestPathFrom.get(root.id)
      if (cached !== undefined)
        return cached
      // Back edge of a cycle: stop here instead of recursing forever.
      if (inProgress.has(root.id))
        return 0
      inProgress.add(root.id)

      let longestChildPath = 0
      const outgoers = getOutgoers(root, nodes, edges)
      if (outgoers.length) {
        for (const outgoer of outgoers) {
          list.push(outgoer)
          pushContainerChildren(outgoer)
          longestChildPath = Math.max(longestChildPath, traverse(outgoer))
        }
      }
      else {
        list.push(root)
        pushContainerChildren(root)
      }

      inProgress.delete(root.id)
      const depth = longestChildPath + 1
      longestPathFrom.set(root.id, depth)
      return depth
    }

    const maxDepth = traverse(startNode)
    return {
      validNodes: uniqBy(list, 'id'),
      maxDepth,
    }
  }
  /**
   * Resolves the schema information needed to validate a tool node.
   *
   * The tool collection is looked up in the list matching `provider_type`
   * (built-in, custom, otherwise workflow). The tool's parameter form schemas are
   * then split into LLM-filled inputs (`form === 'llm'`) and settings.
   *
   * @returns the LLM inputs as `InputVar`s (label in `language`, falling back to
   *   `en_US`), `notAuthed` (built-in, deletable collection without team
   *   authorization), the setting schemas, and the echoed `language`
   */
  export const getToolCheckParams = (
    toolData: ToolNodeType,
    buildInTools: ToolWithProvider[],
    customTools: ToolWithProvider[],
    workflowTools: ToolWithProvider[],
    language: string,
  ) => {
    const { provider_id, provider_type, tool_name } = toolData
    const isBuiltIn = provider_type === CollectionType.builtIn

    let currentTools: ToolWithProvider[]
    if (isBuiltIn)
      currentTools = buildInTools
    else if (provider_type === CollectionType.custom)
      currentTools = customTools
    else
      currentTools = workflowTools

    const currCollection = currentTools.find(item => canFindTool(item.id, provider_id))
    const currTool = currCollection?.tools.find(tool => tool.name === tool_name)
    const formSchemas = currTool ? toolParametersToFormSchemas(currTool.parameters) : []

    const toolInputVarSchema = formSchemas.filter((item: any) => item.form === 'llm')
    const toolSettingSchema = formSchemas.filter((item: any) => item.form !== 'llm')

    const toolInputsSchema: InputVar[] = toolInputVarSchema.map((item: any) => ({
      label: item.label[language] || item.label.en_US,
      variable: item.variable,
      type: item.type,
      required: item.required,
    }))

    return {
      toolInputsSchema,
      notAuthed: isBuiltIn && !!currCollection?.allow_delete && !currCollection?.is_team_authorization,
      toolSettingSchema,
      language,
    }
  }
  /**
   * Returns copies of the nodes and edges with every node id replaced by a fresh UUID.
   *
   * Edge `source` / `target` are remapped through the same id table. Other fields,
   * including edge ids and `parentId`, are copied as-is (shallow copy).
   *
   * @returns `[newNodes, newEdges]`
   */
  export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
    const idMap = {} as Record<string, string>
    for (const node of nodes)
      idMap[node.id] = uuid4()

    const newNodes = nodes.map(node => ({ ...node, id: idMap[node.id] }))
    const newEdges = edges.map(edge => ({
      ...edge,
      source: idMap[edge.source],
      target: idMap[edge.target],
    }))

    return [newNodes, newEdges] as [Node[], Edge[]]
  }
  /** True when the browser user-agent string contains "MAC" (case-insensitive). */
  export const isMac = () => {
    const userAgent = navigator.userAgent.toUpperCase()
    return userAgent.indexOf('MAC') !== -1
  }
  // Display glyphs used for modifier keys on macOS.
  const specialKeysNameMap: Record<string, string | undefined> = {
    ctrl: '⌘',
    alt: '⌥',
    shift: '⇧',
  }
  /** Label to show for `key`: the macOS glyph when on a Mac and one exists, else `key` itself. */
  export const getKeyboardKeyNameBySystem = (key: string) => {
    return isMac() ? (specialKeysNameMap[key] || key) : key
  }
  // Key codes that differ on macOS (Ctrl shortcuts map to Meta/Cmd).
  const specialKeysCodeMap: Record<string, string | undefined> = {
    ctrl: 'meta',
  }
  /** Key code to listen for: the macOS replacement when on a Mac and one exists, else `key` itself. */
  export const getKeyboardKeyCodeBySystem = (key: string) => {
    return isMac() ? (specialKeysCodeMap[key] || key) : key
  }
  /**
   * Smallest x and smallest y over all node positions (they may come from different nodes).
   *
   * @returns `{ x, y }`; both are `Infinity` for an empty list
   */
  export const getTopLeftNodePosition = (nodes: Node[]) => {
    return nodes.reduce(
      (topLeft, { position }) => ({
        x: position.x < topLeft.x ? position.x : topLeft.x,
        y: position.y < topLeft.y ? position.y : topLeft.y,
      }),
      { x: Infinity, y: Infinity },
    )
  }
  /**
   * Whether a keyboard event target is a text-entry area, so shortcuts should be ignored.
   *
   * Matches `<input>`, `<textarea>`, and editable content. `isContentEditable`
   * also covers descendants of an editable root, whose own `contentEditable` is
   * `'inherit'`. The explicit `'true'` check is kept for environments that do not
   * implement `isContentEditable`.
   *
   * @param target the event target element
   * @returns true for text-entry areas, false otherwise (previously an implicit `undefined`)
   */
  export const isEventTargetInputArea = (target: HTMLElement) => {
    if (target.tagName === 'INPUT' || target.tagName === 'TEXTAREA')
      return true
    if (target.contentEditable === 'true' || target.isContentEditable)
      return true
    return false
  }
  /**
   * Converts between a value selector and its template form.
   *
   * - `['a', 'b']` -> `'{{#a.b#}}'`
   * - `'{{#a.b#}}'` -> `['a', 'b']` (the leading `{{#` / trailing `#}}` are stripped if present)
   */
  export const variableTransformer = (v: ValueSelector | string) => {
    if (typeof v !== 'string')
      return `{{#${v.join('.')}#}}`

    const stripped = v.replace(/^{{#|#}}$/g, '')
    return stripped.split('.')
  }
  // One parallel group found by the graph walk.
  // NOTE(review): `isBranch` semantics are decided by code below this view — confirm there.
  type ParallelInfoItem = {
    depth: number
    parallelNodeId: string
    isBranch?: boolean
  }
  // Per-node record of the parallel group and outgoing handle it was reached through.
  type NodeParallelInfo = {
    depth: number
    parallelNodeId: string
    edgeHandleId: string
  }
  // A node paired with one of its source handle ids (traversal work item).
  type NodeHandle = {
    handle: string
    node: Node
  }
  // Upstream node ids and downstream edge ids tracked per node during the walk.
  type NodeStreamInfo = {
    downstreamEdges: Set<string>
    upstreamNodes: Set<string>
  }
/**
 * Computes parallel-branch information for a workflow graph.
 *
 * The graph is walked breadth-first from a start node: the node whose
 * `data.type` is `BlockEnum.Start`, or — when `parentNodeId` is given — the
 * node referenced by that parent's `data.start_node_id` (iteration/loop
 * containers). The walk is split into successive "traversals". A traversal
 * stops when it reaches a node with several incomers at which the fan-out
 * edges it accumulated match all fan-out edges recorded so far (excluding
 * those below that node). A new traversal is then queued starting from that
 * node. Each traversal pushes one `ParallelInfoItem` to `parallelList`.
 *
 * @param nodes - all nodes of the graph
 * @param edges - all edges of the graph
 * @param parentNodeId - optional id of a container node whose inner graph is analysed
 * @returns `parallelList`: one entry per traversal, with the first node seen
 *   fanning out to more than one edge on a single handle (`parallelNodeId`, ''
 *   if none) and a depth value (the running maximum fan-out depth, overwritten
 *   with the converging node's depth when the traversal stops at a convergence
 *   point). Also `hasAbnormalEdges`: `true` if any node with more than one
 *   outgoer on a handle feeds a node that has more than one incomer.
 * @throws Error('Parent node not found') when `parentNodeId` matches no node
 * @throws Error('Start node not found') when no start node can be located
 */
export const getParallelInfo = (nodes: Node[], edges: Edge[], parentNodeId?: string) => {
  let startNode
  if (parentNodeId) {
    const parentNode = nodes.find(node => node.id === parentNodeId)
    if (!parentNode)
      throw new Error('Parent node not found')
    startNode = nodes.find(node => node.id === (parentNode.data as (IterationNodeType | LoopNodeType)).start_node_id)
  }
  else {
    startNode = nodes.find(node => node.data.type === BlockEnum.Start)
  }
  if (!startNode)
    throw new Error('Start node not found')
  const parallelList = [] as ParallelInfoItem[]
  // Queue of traversal starting points; convergence nodes are appended by `traverse`.
  const nextNodeHandles = [{ node: startNode, handle: 'source' }]
  let hasAbnormalEdges = false
  // Runs one breadth-first traversal from `firstNodeHandle`; all bookkeeping below is per traversal.
  // NOTE(review): getIncomers/getConnectedEdges are presumably React Flow helpers and
  // groupBy/isEqual presumably lodash — imported elsewhere in the file; confirm.
  const traverse = (firstNodeHandle: NodeHandle) => {
    // node id -> ids of fan-out edges on the paths that reached this node so far.
    const nodeEdgesSet = {} as Record<string, Set<string>>
    // Every fan-out edge id recorded in this traversal.
    const totalEdgesSet = new Set<string>()
    const nextHandles = [firstNodeHandle]
    // node id -> fan-out nodes above it / fan-out edges below it.
    const streamInfo = {} as Record<string, NodeStreamInfo>
    const parallelListItem = {
      parallelNodeId: '',
      depth: 0,
    } as ParallelInfoItem
    const nodeParallelInfoMap = {} as Record<string, NodeParallelInfo>
    nodeParallelInfoMap[firstNodeHandle.node.id] = {
      parallelNodeId: '',
      edgeHandleId: '',
      depth: 0,
    }
    while (nextHandles.length) {
      const currentNodeHandle = nextHandles.shift()!
      const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle
      // Despite the name, the key is the node id only (the handle is not part of it).
      const currentNodeHandleKey = currentNode.id
      // Edges leaving the current node through the current handle only.
      const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle)
      const connectedEdgesLength = connectedEdges.length
      const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id))
      const incomers = getIncomers(currentNode, nodes, edges)
      if (!streamInfo[currentNodeHandleKey]) {
        streamInfo[currentNodeHandleKey] = {
          upstreamNodes: new Set<string>(),
          downstreamEdges: new Set<string>(),
        }
      }
      // Convergence check: a multi-incomer node reached via fan-out edges ends this traversal
      // when the fan-out edges that reached it equal all recorded fan-out edges except those
      // that lie below it.
      if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) {
        const newSet = new Set<string>()
        for (const item of totalEdgesSet) {
          if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item))
            newSet.add(item)
        }
        if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) {
          // Overwrites (does not max) the depth with the converging node's depth.
          parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
          // Resume from this node in a fresh traversal with fresh bookkeeping.
          nextNodeHandles.push({ node: currentNode, handle: currentHandle })
          // Leaves the while loop: handles still queued in `nextHandles` are dropped.
          break
        }
      }
      if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth)
        parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
      outgoers.forEach((outgoer) => {
        const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
        const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
        // Shadows the outer `incomers`: these are the outgoer's incomers.
        const incomers = getIncomers(outgoer, nodes, edges)
        if (outgoers.length > 1 && incomers.length > 1)
          hasAbnormalEdges = true
        // Enqueue the outgoer once per distinct source handle it uses, or with 'source' if it has no outgoing edges.
        Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
          nextHandles.push({ node: outgoer, handle: sourceHandle })
        })
        if (!outgoerConnectedEdges.length)
          nextHandles.push({ node: outgoer, handle: 'source' })
        const outgoerKey = outgoer.id
        if (!nodeEdgesSet[outgoerKey])
          nodeEdgesSet[outgoerKey] = new Set<string>()
        // Propagate the fan-out edges seen on the way to the current node.
        if (nodeEdgesSet[currentNodeHandleKey]) {
          for (const item of nodeEdgesSet[currentNodeHandleKey])
            nodeEdgesSet[outgoerKey].add(item)
        }
        if (!streamInfo[outgoerKey]) {
          streamInfo[outgoerKey] = {
            upstreamNodes: new Set<string>(),
            downstreamEdges: new Set<string>(),
          }
        }
        if (!nodeParallelInfoMap[outgoer.id]) {
          nodeParallelInfoMap[outgoer.id] = {
            ...nodeParallelInfoMap[currentNode.id],
          }
        }
        if (connectedEdgesLength > 1) {
          // Fan-out: record the edge to this outgoer and bump its depth by one.
          // If several edges on this handle target the same outgoer, only the first is recorded.
          const edge = connectedEdges.find(edge => edge.target === outgoer.id)!
          nodeEdgesSet[outgoerKey].add(edge.id)
          totalEdgesSet.add(edge.id)
          streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id)
          streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey)
          // Every fan-out node above the current one also gains this edge as downstream.
          for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
            streamInfo[item].downstreamEdges.add(edge.id)
          if (!parallelListItem.parallelNodeId)
            parallelListItem.parallelNodeId = currentNode.id
          const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1
          const currentDepth = nodeParallelInfoMap[outgoer.id].depth
          nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth)
        }
        else {
          // Single edge: inherit the upstream fan-out nodes and the depth.
          for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
            streamInfo[outgoerKey].upstreamNodes.add(item)
          // NOTE(review): plain assignment, so a later single-edge arrival overwrites a depth set by an earlier path.
          nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth
        }
      })
    }
    parallelList.push(parallelListItem)
  }
  while (nextNodeHandles.length) {
    const nodeHandle = nextNodeHandles.shift()!
    traverse(nodeHandle)
  }
  return {
    parallelList,
    hasAbnormalEdges,
  }
}
  869. export const hasErrorHandleNode = (nodeType?: BlockEnum) => {
  870. return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code
  871. }
  872. export const getEdgeColor = (nodeRunningStatus?: NodeRunningStatus, isFailBranch?: boolean) => {
  873. if (nodeRunningStatus === NodeRunningStatus.Succeeded)
  874. return 'var(--color-workflow-link-line-success-handle)'
  875. if (nodeRunningStatus === NodeRunningStatus.Failed)
  876. return 'var(--color-workflow-link-line-error-handle)'
  877. if (nodeRunningStatus === NodeRunningStatus.Exception)
  878. return 'var(--color-workflow-link-line-failure-handle)'
  879. if (nodeRunningStatus === NodeRunningStatus.Running) {
  880. if (isFailBranch)
  881. return 'var(--color-workflow-link-line-failure-handle)'
  882. return 'var(--color-workflow-link-line-handle)'
  883. }
  884. return 'var(--color-workflow-link-line-normal)'
  885. }
  886. export const isExceptionVariable = (variable: string, nodeType?: BlockEnum) => {
  887. if ((variable === 'error_message' || variable === 'error_type') && hasErrorHandleNode(nodeType))
  888. return true
  889. return false
  890. }
  891. export const hasRetryNode = (nodeType?: BlockEnum) => {
  892. return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code
  893. }