您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

use-rag-pipeline-search.tsx 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. 'use client'
  2. import { useCallback, useEffect, useMemo } from 'react'
  3. import { useNodes } from 'reactflow'
  4. import { useNodesInteractions } from '@/app/components/workflow/hooks/use-nodes-interactions'
  5. import type { CommonNodeType } from '@/app/components/workflow/types'
  6. import { ragPipelineNodesAction } from '@/app/components/goto-anything/actions/rag-pipeline-nodes'
  7. import BlockIcon from '@/app/components/workflow/block-icon'
  8. import { setupNodeSelectionListener } from '@/app/components/workflow/utils/node-navigation'
  9. import { BlockEnum } from '@/app/components/workflow/types'
  10. import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types'
  11. import type { ToolNodeType } from '@/app/components/workflow/nodes/tool/types'
  12. import type { KnowledgeRetrievalNodeType } from '@/app/components/workflow/nodes/knowledge-retrieval/types'
  13. import { useGetToolIcon } from '@/app/components/workflow/hooks/use-tool-icon'
  /** Node fields that take part in relevance scoring. */
  type ScorableNode = {
    title: string
    type: string
    desc: string
    modelInfo: { provider?: string; name?: string; mode?: string }
  }

  /**
   * Appends optional detail text to a generated summary.
   * The " - " separator is only emitted when there is detail to show, so the
   * result never ends in a dangling dash.
   */
  const withDetail = (summary: string, detail?: string): string =>
    detail ? `${summary} - ${detail}` : summary

  /**
   * Builds the description shown (and searched) for a node.
   * The node title is left untouched elsewhere; only the description is
   * enriched with type-specific context:
   * - Tool nodes: tool description, then tool label, then the node's own desc.
   * - LLM nodes with both provider and name: "name (provider)" followed by the
   *   model mode, or by the node's desc when no mode is set.
   * - Knowledge Retrieval nodes with at least one dataset id: the dataset count
   *   followed by the node's desc.
   * Every other case falls back to the node's desc (or an empty string).
   */
  const describeNode = (nodeData: CommonNodeType): string => {
    const baseDesc = nodeData.desc || ''

    switch (nodeData.type) {
      case BlockEnum.Tool: {
        const toolData = nodeData as ToolNodeType
        return toolData.tool_description || toolData.tool_label || baseDesc
      }
      case BlockEnum.LLM: {
        const { model } = nodeData as LLMNodeType
        if (model?.provider && model?.name)
          return withDetail(`${model.name} (${model.provider})`, model.mode || baseDesc)
        return baseDesc
      }
      case BlockEnum.KnowledgeRetrieval: {
        const { dataset_ids: datasetIds } = nodeData as KnowledgeRetrievalNodeType
        if (datasetIds?.length)
          return withDetail(`Knowledge Retrieval with ${datasetIds.length} datasets`, baseDesc)
        return baseDesc
      }
      default:
        return baseDesc
    }
  }

  /**
   * Computes a relevance score for a node against a search term.
   *
   * @param node - The node fields to match against.
   * @param searchTerm - The user's query; an empty term matches every node.
   * @returns 1 for an empty term; otherwise the sum of the weights of all
   * fields containing the term case-insensitively (title 10, type 8,
   * description 5, model provider 6, model name 6, model mode 4). 0 means no match.
   */
  const calculateScore = (node: ScorableNode, searchTerm: string): number => {
    if (!searchTerm)
      return 1

    const term = searchTerm.toLowerCase()
    const weightedFields: Array<[string | undefined, number]> = [
      [node.title, 10],
      [node.type, 8],
      [node.desc, 5],
      [node.modelInfo.provider, 6],
      [node.modelInfo.name, 6],
      [node.modelInfo.mode, 4],
    ]

    // Optional chaining keeps a node with a missing field (e.g. no type at
    // runtime) from throwing and breaking the whole search.
    return weightedFields.reduce(
      (total, [value, weight]) => (value?.toLowerCase().includes(term) ? total + weight : total),
      0,
    )
  }

  /**
   * Hook that registers node search for the RAG pipeline.
   *
   * While at least one node exists, it assigns a search function to
   * `ragPipelineNodesAction.searchFn`; the assignment is cleared again when the
   * effect is cleaned up. It also passes `handleNodeSelect` to
   * `setupNodeSelectionListener` and uses that call's return value as the
   * effect cleanup.
   *
   * @returns Always `null`; the hook is used for its side effects only.
   */
  export const useRagPipelineSearch = () => {
    const nodes = useNodes()
    const { handleNodeSelect } = useNodesInteractions()
    const getToolIcon = useGetToolIcon()

    // Flatten every node into the fields needed for matching and rendering.
    const searchableNodes = useMemo(() => nodes.map((node) => {
      const nodeData = node.data as CommonNodeType
      // Model details only exist for LLM nodes; all fields stay undefined otherwise.
      const llmModel = nodeData.type === BlockEnum.LLM
        ? (nodeData as LLMNodeType).model
        : undefined

      return {
        id: node.id,
        title: nodeData.title || nodeData.type || 'Untitled Node',
        desc: describeNode(nodeData),
        type: nodeData.type,
        blockType: nodeData.type,
        nodeData,
        toolIcon: getToolIcon(nodeData),
        modelInfo: {
          provider: llmModel?.provider,
          name: llmModel?.name,
          mode: llmModel?.mode,
        },
      }
    }), [nodes, getToolIcon])

    // Search function handed to the goto-anything action.
    const searchRagPipelineNodes = useCallback((query: string) => {
      if (!searchableNodes.length)
        return []

      const searchTerm = query.toLowerCase().trim()

      return searchableNodes
        .map(node => ({ node, score: calculateScore(node, searchTerm) }))
        // A score of 0 means nothing matched.
        .filter(({ score }) => score > 0)
        // Without a search term list nodes alphabetically; otherwise best match first.
        .sort((a, b) => (searchTerm
          ? (b.score || 0) - (a.score || 0)
          : a.node.title.localeCompare(b.node.title)))
        .map(({ node, score }) => ({
          id: node.id,
          title: node.title,
          description: node.desc || node.type,
          type: 'workflow-node' as const,
          path: `#${node.id}`,
          icon: (
            <BlockIcon
              type={node.blockType}
              className="shrink-0"
              size="sm"
              toolIcon={node.toolIcon}
            />
          ),
          metadata: {
            nodeId: node.id,
            nodeData: node.nodeData,
          },
          data: node.nodeData,
          score,
        }))
    }, [searchableNodes])

    // Publish the search function on the shared action object while nodes exist.
    useEffect(() => {
      if (searchableNodes.length > 0)
        ragPipelineNodesAction.searchFn = searchRagPipelineNodes

      return () => {
        // Clear the registration on unmount or before the effect re-runs.
        ragPipelineNodesAction.searchFn = undefined
      }
    }, [searchableNodes, searchRagPipelineNodes])

    // Wire up the node selection listener; its return value serves as the cleanup.
    useEffect(() => {
      return setupNodeSelectionListener(handleNodeSelect)
    }, [handleNodeSelect])

    return null
  }