Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

6 месяцев назад
5 месяцев назад
6 месяцев назад
6 месяцев назад
5 месяцев назад
6 месяцев назад
5 месяцев назад
5 месяцев назад
6 месяцев назад
3 месяцев назад
6 месяцев назад
5 месяцев назад
3 месяцев назад
3 месяцев назад
3 месяцев назад
5 месяцев назад
3 месяцев назад
3 месяцев назад
3 месяцев назад
3 месяцев назад
5 месяцев назад
6 месяцев назад
5 месяцев назад
5 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
6 месяцев назад
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import {
  2. useCallback,
  3. } from 'react'
  4. import { produce } from 'immer'
  5. import { useStoreApi } from 'reactflow'
  6. import { useNodeDataUpdate } from '@/app/components/workflow/hooks'
  7. import type { ValueSelector } from '@/app/components/workflow/types'
  8. import {
  9. ChunkStructureEnum,
  10. IndexMethodEnum,
  11. RetrievalSearchMethodEnum,
  12. } from '../types'
  13. import type {
  14. HybridSearchModeEnum,
  15. KnowledgeBaseNodeType,
  16. RerankingModel,
  17. } from '../types'
  18. import { isHighQualitySearchMethod } from '../utils'
  19. export const useConfig = (id: string) => {
  20. const store = useStoreApi()
  21. const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate()
  22. const getNodeData = useCallback(() => {
  23. const { getNodes } = store.getState()
  24. const nodes = getNodes()
  25. return nodes.find(node => node.id === id)
  26. }, [store, id])
  27. const handleNodeDataUpdate = useCallback((data: Partial<KnowledgeBaseNodeType>) => {
  28. handleNodeDataUpdateWithSyncDraft({
  29. id,
  30. data,
  31. })
  32. }, [id, handleNodeDataUpdateWithSyncDraft])
  33. const handleChunkStructureChange = useCallback((chunkStructure: ChunkStructureEnum) => {
  34. const nodeData = getNodeData()
  35. const {
  36. indexing_technique,
  37. retrieval_model,
  38. chunk_structure,
  39. index_chunk_variable_selector,
  40. } = nodeData?.data
  41. const { search_method } = retrieval_model || {}
  42. handleNodeDataUpdate({
  43. chunk_structure: chunkStructure,
  44. indexing_technique: (chunkStructure === ChunkStructureEnum.parent_child || chunkStructure === ChunkStructureEnum.question_answer) ? IndexMethodEnum.QUALIFIED : indexing_technique,
  45. retrieval_model: {
  46. ...retrieval_model,
  47. search_method: ((chunkStructure === ChunkStructureEnum.parent_child || chunkStructure === ChunkStructureEnum.question_answer) && !isHighQualitySearchMethod(search_method)) ? RetrievalSearchMethodEnum.keywordSearch : search_method,
  48. },
  49. index_chunk_variable_selector: chunkStructure === chunk_structure ? index_chunk_variable_selector : [],
  50. })
  51. }, [handleNodeDataUpdate, getNodeData])
  52. const handleIndexMethodChange = useCallback((indexMethod: IndexMethodEnum) => {
  53. const nodeData = getNodeData()
  54. handleNodeDataUpdate(produce(nodeData?.data as KnowledgeBaseNodeType, (draft) => {
  55. draft.indexing_technique = indexMethod
  56. if (indexMethod === IndexMethodEnum.ECONOMICAL)
  57. draft.retrieval_model.search_method = RetrievalSearchMethodEnum.keywordSearch
  58. else if (indexMethod === IndexMethodEnum.QUALIFIED)
  59. draft.retrieval_model.search_method = RetrievalSearchMethodEnum.semantic
  60. }))
  61. }, [handleNodeDataUpdate, getNodeData])
  62. const handleKeywordNumberChange = useCallback((keywordNumber: number) => {
  63. handleNodeDataUpdate({ keyword_number: keywordNumber })
  64. }, [handleNodeDataUpdate])
  65. const handleEmbeddingModelChange = useCallback(({
  66. embeddingModel,
  67. embeddingModelProvider,
  68. }: {
  69. embeddingModel: string
  70. embeddingModelProvider: string
  71. }) => {
  72. const nodeData = getNodeData()
  73. handleNodeDataUpdate({
  74. embedding_model: embeddingModel,
  75. embedding_model_provider: embeddingModelProvider,
  76. retrieval_model: {
  77. ...nodeData?.data.retrieval_model,
  78. vector_setting: {
  79. ...nodeData?.data.retrieval_model.vector_setting,
  80. embedding_provider_name: embeddingModelProvider,
  81. embedding_model_name: embeddingModel,
  82. },
  83. },
  84. })
  85. }, [getNodeData, handleNodeDataUpdate])
  86. const handleRetrievalSearchMethodChange = useCallback((searchMethod: RetrievalSearchMethodEnum) => {
  87. const nodeData = getNodeData()
  88. handleNodeDataUpdate({
  89. retrieval_model: {
  90. ...nodeData?.data.retrieval_model,
  91. search_method: searchMethod,
  92. },
  93. })
  94. }, [getNodeData, handleNodeDataUpdate])
  95. const handleHybridSearchModeChange = useCallback((hybridSearchMode: HybridSearchModeEnum) => {
  96. const nodeData = getNodeData()
  97. handleNodeDataUpdate({
  98. retrieval_model: {
  99. ...nodeData?.data.retrieval_model,
  100. hybridSearchMode,
  101. },
  102. })
  103. }, [getNodeData, handleNodeDataUpdate])
  104. const handleWeighedScoreChange = useCallback((weightedScore: { value: number[] }) => {
  105. const nodeData = getNodeData()
  106. handleNodeDataUpdate({
  107. retrieval_model: {
  108. ...nodeData?.data.retrieval_model,
  109. weights: {
  110. weight_type: 'weighted_score',
  111. vector_setting: {
  112. vector_weight: weightedScore.value[0],
  113. embedding_provider_name: '',
  114. embedding_model_name: '',
  115. },
  116. keyword_setting: {
  117. keyword_weight: weightedScore.value[1],
  118. },
  119. },
  120. },
  121. })
  122. }, [getNodeData, handleNodeDataUpdate])
  123. const handleRerankingModelChange = useCallback((rerankingModel: RerankingModel) => {
  124. const nodeData = getNodeData()
  125. handleNodeDataUpdate({
  126. retrieval_model: {
  127. ...nodeData?.data.retrieval_model,
  128. reranking_model: {
  129. reranking_provider_name: rerankingModel.reranking_provider_name,
  130. reranking_model_name: rerankingModel.reranking_model_name,
  131. },
  132. },
  133. })
  134. }, [getNodeData, handleNodeDataUpdate])
  135. const handleTopKChange = useCallback((topK: number) => {
  136. const nodeData = getNodeData()
  137. handleNodeDataUpdate({
  138. retrieval_model: {
  139. ...nodeData?.data.retrieval_model,
  140. top_k: topK,
  141. },
  142. })
  143. }, [getNodeData, handleNodeDataUpdate])
  144. const handleScoreThresholdChange = useCallback((scoreThreshold: number) => {
  145. const nodeData = getNodeData()
  146. handleNodeDataUpdate({
  147. retrieval_model: {
  148. ...nodeData?.data.retrieval_model,
  149. score_threshold: scoreThreshold,
  150. },
  151. })
  152. }, [getNodeData, handleNodeDataUpdate])
  153. const handleScoreThresholdEnabledChange = useCallback((isEnabled: boolean) => {
  154. const nodeData = getNodeData()
  155. handleNodeDataUpdate({
  156. retrieval_model: {
  157. ...nodeData?.data.retrieval_model,
  158. score_threshold_enabled: isEnabled,
  159. },
  160. })
  161. }, [getNodeData, handleNodeDataUpdate])
  162. const handleInputVariableChange = useCallback((inputVariable: string | ValueSelector) => {
  163. handleNodeDataUpdate({
  164. index_chunk_variable_selector: Array.isArray(inputVariable) ? inputVariable : [],
  165. })
  166. }, [handleNodeDataUpdate])
  167. return {
  168. handleChunkStructureChange,
  169. handleIndexMethodChange,
  170. handleKeywordNumberChange,
  171. handleEmbeddingModelChange,
  172. handleRetrievalSearchMethodChange,
  173. handleHybridSearchModeChange,
  174. handleWeighedScoreChange,
  175. handleRerankingModelChange,
  176. handleTopKChange,
  177. handleScoreThresholdChange,
  178. handleScoreThresholdEnabledChange,
  179. handleInputVariableChange,
  180. }
  181. }