import { useCallback } from 'react'
import { produce } from 'immer'
import { useStoreApi } from 'reactflow'
import { useNodeDataUpdate } from '@/app/components/workflow/hooks'
import type { ValueSelector } from '@/app/components/workflow/types'
import {
  ChunkStructureEnum,
  IndexMethodEnum,
  RetrievalSearchMethodEnum,
  WeightedScoreEnum,
} from '../types'
import type {
  KnowledgeBaseNodeType,
  RerankingModel,
} from '../types'
import {
  HybridSearchModeEnum,
} from '../types'
import { isHighQualitySearchMethod } from '../utils'
import { DEFAULT_WEIGHTED_SCORE, RerankingModeEnum } from '@/models/datasets'

/**
 * Chunk structures for which this hook forces the QUALIFIED indexing technique
 * (and therefore a high-quality retrieval search method).
 */
const isQualifiedOnlyChunkStructure = (chunkStructure: ChunkStructureEnum) =>
  chunkStructure === ChunkStructureEnum.parent_child
  || chunkStructure === ChunkStructureEnum.question_answer

/**
 * Change handlers for the knowledge-base node's configuration panel.
 *
 * Every handler reads the node's *current* data from the React Flow store at call
 * time and writes a partial update through `handleNodeDataUpdateWithSyncDraft`.
 *
 * @param id - id of the knowledge-base node being edited.
 * @returns the set of `handle*Change` callbacks used by the panel.
 */
export const useConfig = (id: string) => {
  const store = useStoreApi()
  const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate()

  // Looked up on every call (not captured at render) so handlers never act on stale node data.
  const getNodeData = useCallback(() => {
    const { getNodes } = store.getState()
    const nodes = getNodes()

    return nodes.find(node => node.id === id)
  }, [store, id])

  const handleNodeDataUpdate = useCallback((data: Partial<KnowledgeBaseNodeType>) => {
    handleNodeDataUpdateWithSyncDraft({
      id,
      data,
    })
  }, [id, handleNodeDataUpdateWithSyncDraft])

  // Shallow-merge `patch` into the node's current retrieval_model (which may be absent) and persist it.
  const updateRetrievalModel = useCallback((patch: Record<string, unknown>) => {
    const nodeData = getNodeData()
    handleNodeDataUpdate({
      retrieval_model: {
        ...nodeData?.data.retrieval_model,
        ...patch,
      },
    })
  }, [getNodeData, handleNodeDataUpdate])

  /** Builds the default hybrid-search weights for the given embedding model. */
  const getDefaultWeights = useCallback(({
    embeddingModel,
    embeddingModelProvider,
  }: {
    embeddingModel: string
    embeddingModelProvider: string
  }) => {
    return {
      vector_setting: {
        vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
        embedding_provider_name: embeddingModelProvider || '',
        embedding_model_name: embeddingModel,
      },
      keyword_setting: {
        keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
      },
    }
  }, [])

  /**
   * Switches the chunk structure. Structures that require QUALIFIED indexing also
   * force a high-quality search method, and the chunk variable selector is reset
   * whenever the structure actually changes.
   */
  const handleChunkStructureChange = useCallback((chunkStructure: ChunkStructureEnum) => {
    const nodeData = getNodeData()
    const {
      indexing_technique,
      retrieval_model,
      chunk_structure,
      index_chunk_variable_selector,
    } = nodeData?.data || {}
    const { search_method } = retrieval_model || {}
    const requiresQualified = isQualifiedOnlyChunkStructure(chunkStructure)

    handleNodeDataUpdate({
      chunk_structure: chunkStructure,
      indexing_technique: requiresQualified ? IndexMethodEnum.QUALIFIED : indexing_technique,
      retrieval_model: {
        ...retrieval_model,
        // QUALIFIED indexing is being forced, so a non-high-quality search method must be
        // replaced by a high-quality one. Semantic search is the same default
        // handleIndexMethodChange applies for QUALIFIED (keyword search would be a no-op here).
        search_method: (requiresQualified && !isHighQualitySearchMethod(search_method))
          ? RetrievalSearchMethodEnum.semantic
          : search_method,
      },
      // Keep the selected chunk variable only if the structure did not actually change.
      index_chunk_variable_selector: chunkStructure === chunk_structure ? index_chunk_variable_selector : [],
    })
  }, [handleNodeDataUpdate, getNodeData])

  /**
   * Sets the indexing technique and the matching default search method
   * (ECONOMICAL -> keyword search, QUALIFIED -> semantic search).
   */
  const handleIndexMethodChange = useCallback((indexMethod: IndexMethodEnum) => {
    const nodeData = getNodeData()
    // produce() on an undefined base would hand the recipe `undefined` and throw.
    if (!nodeData)
      return

    handleNodeDataUpdate(produce(nodeData.data as KnowledgeBaseNodeType, (draft) => {
      draft.indexing_technique = indexMethod
      // Other handlers tolerate a missing retrieval_model; don't crash on one here.
      if (!draft.retrieval_model)
        return
      if (indexMethod === IndexMethodEnum.ECONOMICAL)
        draft.retrieval_model.search_method = RetrievalSearchMethodEnum.keywordSearch
      else if (indexMethod === IndexMethodEnum.QUALIFIED)
        draft.retrieval_model.search_method = RetrievalSearchMethodEnum.semantic
    }))
  }, [handleNodeDataUpdate, getNodeData])

  const handleKeywordNumberChange = useCallback((keywordNumber: number) => {
    handleNodeDataUpdate({ keyword_number: keywordNumber })
  }, [handleNodeDataUpdate])

  /**
   * Sets the embedding model and keeps the hybrid-search vector weights pointed at it:
   * existing weights get their embedding provider/model rewritten, otherwise defaults are created.
   */
  const handleEmbeddingModelChange = useCallback(({
    embeddingModel,
    embeddingModelProvider,
  }: {
    embeddingModel: string
    embeddingModelProvider: string
  }) => {
    const nodeData = getNodeData()
    const retrievalModel = nodeData?.data.retrieval_model
    const weights = retrievalModel?.weights
      ? {
        ...retrievalModel.weights,
        vector_setting: {
          ...retrievalModel.weights.vector_setting,
          embedding_provider_name: embeddingModelProvider,
          embedding_model_name: embeddingModel,
        },
      }
      : getDefaultWeights({ embeddingModel, embeddingModelProvider })

    handleNodeDataUpdate({
      embedding_model: embeddingModel,
      embedding_model_provider: embeddingModelProvider,
      retrieval_model: {
        ...retrievalModel,
        weights,
      },
    })
  }, [getNodeData, getDefaultWeights, handleNodeDataUpdate])

  /**
   * Sets the retrieval search method, defaulting reranking_mode to RerankingModel when unset.
   * For hybrid search, reranking_enable is derived from that reranking mode.
   */
  const handleRetrievalSearchMethodChange = useCallback((searchMethod: RetrievalSearchMethodEnum) => {
    const nodeData = getNodeData()
    const retrievalModel = nodeData?.data.retrieval_model
    const rerankingMode = retrievalModel?.reranking_mode || RerankingModeEnum.RerankingModel

    handleNodeDataUpdate({
      retrieval_model: {
        ...retrievalModel,
        search_method: searchMethod,
        reranking_mode: rerankingMode,
        ...(searchMethod === RetrievalSearchMethodEnum.hybrid
          ? { reranking_enable: rerankingMode === RerankingModeEnum.RerankingModel }
          : {}),
      },
    })
  }, [getNodeData, handleNodeDataUpdate])

  /**
   * Switches the hybrid-search mode (reranking model vs. weighted score), enabling reranking
   * only in RerankingModel mode and seeding default weights if none exist yet.
   */
  const handleHybridSearchModeChange = useCallback((hybridSearchMode: HybridSearchModeEnum) => {
    const nodeData = getNodeData()
    const defaultWeights = getDefaultWeights({
      embeddingModel: nodeData?.data.embedding_model || '',
      embeddingModelProvider: nodeData?.data.embedding_model_provider || '',
    })
    handleNodeDataUpdate({
      retrieval_model: {
        ...nodeData?.data.retrieval_model,
        reranking_mode: hybridSearchMode,
        reranking_enable: hybridSearchMode === HybridSearchModeEnum.RerankingModel,
        weights: nodeData?.data.retrieval_model?.weights || defaultWeights,
      },
    })
  }, [getNodeData, getDefaultWeights, handleNodeDataUpdate])

  const handleRerankingModelEnabledChange = useCallback((rerankingModelEnabled: boolean) => {
    updateRetrievalModel({ reranking_enable: rerankingModelEnabled })
  }, [updateRetrievalModel])

  /**
   * Stores user-customized weights: `value[0]` is the vector (semantic) weight and
   * `value[1]` the keyword weight. Existing vector_setting fields are preserved.
   */
  const handleWeighedScoreChange = useCallback((weightedScore: { value: number[] }) => {
    const nodeData = getNodeData()
    handleNodeDataUpdate({
      retrieval_model: {
        ...nodeData?.data.retrieval_model,
        weights: {
          weight_type: WeightedScoreEnum.Customized,
          vector_setting: {
            ...nodeData?.data.retrieval_model?.weights?.vector_setting,
            vector_weight: weightedScore.value[0],
          },
          keyword_setting: {
            keyword_weight: weightedScore.value[1],
          },
        },
      },
    })
  }, [getNodeData, handleNodeDataUpdate])

  const handleRerankingModelChange = useCallback((rerankingModel: RerankingModel) => {
    // Copy only the two identifying fields rather than the whole object.
    updateRetrievalModel({
      reranking_model: {
        reranking_provider_name: rerankingModel.reranking_provider_name,
        reranking_model_name: rerankingModel.reranking_model_name,
      },
    })
  }, [updateRetrievalModel])

  const handleTopKChange = useCallback((topK: number) => {
    updateRetrievalModel({ top_k: topK })
  }, [updateRetrievalModel])

  const handleScoreThresholdChange = useCallback((scoreThreshold: number) => {
    updateRetrievalModel({ score_threshold: scoreThreshold })
  }, [updateRetrievalModel])

  const handleScoreThresholdEnabledChange = useCallback((isEnabled: boolean) => {
    updateRetrievalModel({ score_threshold_enabled: isEnabled })
  }, [updateRetrievalModel])

  // Only a value selector (array) is stored; a plain string clears the selection.
  const handleInputVariableChange = useCallback((inputVariable: string | ValueSelector) => {
    handleNodeDataUpdate({
      index_chunk_variable_selector: Array.isArray(inputVariable) ? inputVariable : [],
    })
  }, [handleNodeDataUpdate])

  return {
    handleChunkStructureChange,
    handleIndexMethodChange,
    handleKeywordNumberChange,
    handleEmbeddingModelChange,
    handleRetrievalSearchMethodChange,
    handleHybridSearchModeChange,
    handleRerankingModelEnabledChange,
    handleWeighedScoreChange,
    handleRerankingModelChange,
    handleTopKChange,
    handleScoreThresholdChange,
    handleScoreThresholdEnabledChange,
    handleInputVariableChange,
  }
}