Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

use-config.ts 8.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import {
  2. useCallback,
  3. } from 'react'
  4. import { produce } from 'immer'
  5. import { useStoreApi } from 'reactflow'
  6. import { useNodeDataUpdate } from '@/app/components/workflow/hooks'
  7. import type { ValueSelector } from '@/app/components/workflow/types'
  8. import {
  9. ChunkStructureEnum,
  10. IndexMethodEnum,
  11. RetrievalSearchMethodEnum,
  12. WeightedScoreEnum,
  13. } from '../types'
  14. import type {
  15. KnowledgeBaseNodeType,
  16. RerankingModel,
  17. } from '../types'
  18. import {
  19. HybridSearchModeEnum,
  20. } from '../types'
  21. import { isHighQualitySearchMethod } from '../utils'
  22. import { DEFAULT_WEIGHTED_SCORE, RerankingModeEnum } from '@/models/datasets'
  /**
   * Hook exposing the change handlers for a knowledge-base workflow node.
   *
   * Every handler re-reads the node's current data from the React Flow store at call
   * time (via `store.getState()`), rather than closing over a render-time snapshot, and
   * writes a partial patch through `handleNodeDataUpdateWithSyncDraft`.
   *
   * @param id - id of the knowledge-base node whose data the handlers read and update.
   * @returns the set of change handlers for this node.
   */
  export const useConfig = (id: string) => {
    const store = useStoreApi()
    const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate()

    /** Looks up the node with `id` in the current React Flow store state (undefined if absent). */
    const getNodeData = useCallback(() => {
      const { getNodes } = store.getState()
      return getNodes().find(node => node.id === id)
    }, [store, id])

    /** Sends a partial data patch for this node and syncs the workflow draft. */
    const handleNodeDataUpdate = useCallback((data: Partial<KnowledgeBaseNodeType>) => {
      handleNodeDataUpdateWithSyncDraft({
        id,
        data,
      })
    }, [id, handleNodeDataUpdateWithSyncDraft])

    /**
     * Builds weighted-score settings for the given embedding model, using the
     * `DEFAULT_WEIGHTED_SCORE.other` semantic/keyword weights.
     */
    const getDefaultWeights = useCallback(({
      embeddingModel,
      embeddingModelProvider,
    }: {
      embeddingModel: string
      embeddingModelProvider: string
    }) => {
      return {
        vector_setting: {
          vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
          embedding_provider_name: embeddingModelProvider || '',
          embedding_model_name: embeddingModel,
        },
        keyword_setting: {
          keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
        },
      }
    }, [])

    /**
     * Switches the chunk structure.
     *
     * Parent-child and Q&A structures force `QUALIFIED` indexing. In that case a search
     * method rejected by `isHighQualitySearchMethod` is replaced by semantic search, the
     * same method `handleIndexMethodChange` selects for `QUALIFIED`. The input variable
     * selector is kept only if the structure did not actually change.
     */
    const handleChunkStructureChange = useCallback((chunkStructure: ChunkStructureEnum) => {
      const nodeData = getNodeData()
      const {
        indexing_technique,
        retrieval_model,
        chunk_structure,
        index_chunk_variable_selector,
      } = nodeData?.data || {}
      const { search_method } = retrieval_model || {}
      const requiresQualifiedIndex = chunkStructure === ChunkStructureEnum.parent_child
        || chunkStructure === ChunkStructureEnum.question_answer
      handleNodeDataUpdate({
        chunk_structure: chunkStructure,
        indexing_technique: requiresQualifiedIndex ? IndexMethodEnum.QUALIFIED : indexing_technique,
        retrieval_model: {
          ...retrieval_model,
          // Falling back to keywordSearch (the method paired with ECONOMICAL indexing) would
          // leave QUALIFIED indexing with a non-high-quality search method.
          search_method: (requiresQualifiedIndex && !isHighQualitySearchMethod(search_method))
            ? RetrievalSearchMethodEnum.semantic
            : search_method,
        },
        index_chunk_variable_selector: chunkStructure === chunk_structure ? index_chunk_variable_selector : [],
      })
    }, [handleNodeDataUpdate, getNodeData])

    /**
     * Sets the indexing technique and resets the search method to match it:
     * `ECONOMICAL` -> keyword search, `QUALIFIED` -> semantic search.
     */
    const handleIndexMethodChange = useCallback((indexMethod: IndexMethodEnum) => {
      const nodeData = getNodeData()
      // If the node is gone there is nothing to update; producing from an undefined base
      // would make the recipe assign onto `undefined` and throw.
      if (!nodeData?.data)
        return
      handleNodeDataUpdate(produce(nodeData.data as KnowledgeBaseNodeType, (draft) => {
        draft.indexing_technique = indexMethod
        if (indexMethod === IndexMethodEnum.ECONOMICAL)
          draft.retrieval_model.search_method = RetrievalSearchMethodEnum.keywordSearch
        else if (indexMethod === IndexMethodEnum.QUALIFIED)
          draft.retrieval_model.search_method = RetrievalSearchMethodEnum.semantic
      }))
    }, [handleNodeDataUpdate, getNodeData])

    /** Sets `keyword_number`. */
    const handleKeywordNumberChange = useCallback((keywordNumber: number) => {
      handleNodeDataUpdate({ keyword_number: keywordNumber })
    }, [handleNodeDataUpdate])

    /**
     * Sets the embedding model and provider. Existing weighted-score settings are kept and
     * re-pointed at the new model; if none exist, default weights for the model are created.
     */
    const handleEmbeddingModelChange = useCallback(({
      embeddingModel,
      embeddingModelProvider,
    }: {
      embeddingModel: string
      embeddingModelProvider: string
    }) => {
      const nodeData = getNodeData()
      const retrievalModel = { ...nodeData?.data.retrieval_model }
      const currentWeights = retrievalModel.weights
      const weights = currentWeights
        ? {
            ...currentWeights,
            vector_setting: {
              ...currentWeights.vector_setting,
              embedding_provider_name: embeddingModelProvider,
              embedding_model_name: embeddingModel,
            },
          }
        : getDefaultWeights({ embeddingModel, embeddingModelProvider })
      handleNodeDataUpdate({
        embedding_model: embeddingModel,
        embedding_model_provider: embeddingModelProvider,
        retrieval_model: {
          ...retrievalModel,
          weights,
        },
      })
    }, [getNodeData, getDefaultWeights, handleNodeDataUpdate])

    /**
     * Sets the search method, defaulting `reranking_mode` to `RerankingModel` when unset.
     * For hybrid search, `reranking_enable` is set to whether the mode is `RerankingModel`.
     */
    const handleRetrievalSearchMethodChange = useCallback((searchMethod: RetrievalSearchMethodEnum) => {
      const nodeData = getNodeData()
      const retrievalModel = nodeData?.data.retrieval_model
      const rerankingMode = retrievalModel?.reranking_mode || RerankingModeEnum.RerankingModel
      handleNodeDataUpdate({
        retrieval_model: {
          ...retrievalModel,
          search_method: searchMethod,
          reranking_mode: rerankingMode,
          ...(searchMethod === RetrievalSearchMethodEnum.hybrid
            ? { reranking_enable: rerankingMode === RerankingModeEnum.RerankingModel }
            : {}),
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /**
     * Sets the hybrid-search reranking mode, enables model reranking only for
     * `HybridSearchModeEnum.RerankingModel`, and ensures weights exist (defaults otherwise).
     */
    const handleHybridSearchModeChange = useCallback((hybridSearchMode: HybridSearchModeEnum) => {
      const nodeData = getNodeData()
      const retrievalModel = nodeData?.data.retrieval_model
      handleNodeDataUpdate({
        retrieval_model: {
          ...retrievalModel,
          reranking_mode: hybridSearchMode,
          reranking_enable: hybridSearchMode === HybridSearchModeEnum.RerankingModel,
          weights: retrievalModel?.weights || getDefaultWeights({
            embeddingModel: nodeData?.data.embedding_model || '',
            embeddingModelProvider: nodeData?.data.embedding_model_provider || '',
          }),
        },
      })
    }, [getNodeData, getDefaultWeights, handleNodeDataUpdate])

    /** Sets `retrieval_model.reranking_enable`. */
    const handleRerankingModelEnabledChange = useCallback((rerankingModelEnabled: boolean) => {
      const nodeData = getNodeData()
      handleNodeDataUpdate({
        retrieval_model: {
          ...nodeData?.data.retrieval_model,
          reranking_enable: rerankingModelEnabled,
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /**
     * Stores customized weights: `value[0]` becomes the vector weight and `value[1]` the
     * keyword weight. The existing vector settings (e.g. embedding model names) are kept.
     */
    const handleWeighedScoreChange = useCallback((weightedScore: { value: number[] }) => {
      const nodeData = getNodeData()
      handleNodeDataUpdate({
        retrieval_model: {
          ...nodeData?.data.retrieval_model,
          weights: {
            weight_type: WeightedScoreEnum.Customized,
            vector_setting: {
              ...nodeData?.data.retrieval_model?.weights?.vector_setting,
              vector_weight: weightedScore.value[0],
            },
            keyword_setting: {
              keyword_weight: weightedScore.value[1],
            },
          },
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /** Sets the reranking model's provider and model name. */
    const handleRerankingModelChange = useCallback((rerankingModel: RerankingModel) => {
      const nodeData = getNodeData()
      handleNodeDataUpdate({
        retrieval_model: {
          ...nodeData?.data.retrieval_model,
          reranking_model: {
            reranking_provider_name: rerankingModel.reranking_provider_name,
            reranking_model_name: rerankingModel.reranking_model_name,
          },
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /** Sets `retrieval_model.top_k`. */
    const handleTopKChange = useCallback((topK: number) => {
      const nodeData = getNodeData()
      handleNodeDataUpdate({
        retrieval_model: {
          ...nodeData?.data.retrieval_model,
          top_k: topK,
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /** Sets `retrieval_model.score_threshold`. */
    const handleScoreThresholdChange = useCallback((scoreThreshold: number) => {
      const nodeData = getNodeData()
      handleNodeDataUpdate({
        retrieval_model: {
          ...nodeData?.data.retrieval_model,
          score_threshold: scoreThreshold,
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /** Sets `retrieval_model.score_threshold_enabled`. */
    const handleScoreThresholdEnabledChange = useCallback((isEnabled: boolean) => {
      const nodeData = getNodeData()
      handleNodeDataUpdate({
        retrieval_model: {
          ...nodeData?.data.retrieval_model,
          score_threshold_enabled: isEnabled,
        },
      })
    }, [getNodeData, handleNodeDataUpdate])

    /** Stores the chunk input variable; a non-array (plain string) value clears the selector. */
    const handleInputVariableChange = useCallback((inputVariable: string | ValueSelector) => {
      handleNodeDataUpdate({
        index_chunk_variable_selector: Array.isArray(inputVariable) ? inputVariable : [],
      })
    }, [handleNodeDataUpdate])

    return {
      handleChunkStructureChange,
      handleIndexMethodChange,
      handleKeywordNumberChange,
      handleEmbeddingModelChange,
      handleRetrievalSearchMethodChange,
      handleHybridSearchModeChange,
      handleRerankingModelEnabledChange,
      handleWeighedScoreChange,
      handleRerankingModelChange,
      handleTopKChange,
      handleScoreThresholdChange,
      handleScoreThresholdEnabledChange,
      handleInputVariableChange,
    }
  }